 # Tiny AES in Python

Recently for an Applied Cryptography class at KTH I had to write from scratch an implementation of the AES encryption. AES is a block cypher encryption that operates on blocks of different sizes. I implement AES-128 in ECB mode which takes 16 byte blocks in a matrix of 4 by 4. What follows is a small implementation in Python with some explanations around each AES step.

### Overall structure of AES AES is quite a big and detailed process. I won’t go into the details of how each round works as that would require a series of posts instead of one post. Maybe in the future I’ll expand on this. Here is the high-level structure of the algorithm that I got from Wikipedia –

• Key Expansion
• 9 rounds of –
• Sub Bytes
• Shift Rows
• Mix Columns
• Final round of –
• Sub Bytes
• Shift Rows

### Implementation

This particular implementation was written to work with Kattis automated testing so it reads stdin as binary data and outputs the encrypted values to stdout in binary data. It probably won’t be a big change to switch this to standard Hex values for better readability but I haven’t gotten around to doing this since the class ended.

### Formatting Inputs

The functions below help to convert the 16 byte strings into a list and eventually into a matrix. The transpose function is also needed later in the main encryption function.

``````    def space_input(text):
""" Converts a string into 2 character hex list  """
state = []
while text:
state.append(text[:2])
text = text[2:]
return state

def bytes_2_matrix(text):
""" Converts a 16-byte array into a 4x4 nested list matrix  """
state = [list(text[i:i+4]) for i in range(0, len(text), 4)]
state = transpose_matrix(state)
return state

def transpose_matrix(state):
""" bytes 2 matrix result needs to be transposed for correct order """
temp = []
for j in range(NUM_BLOCKS):
row = []
for i in range(NUM_BLOCKS):
val = state[i][j]
if len(val) == 3:
state[i][j] = '0x0' + val
row.append(state[i][j])
temp.append(row)
return temp
``````

### Initial Key Expansion

For key expansion, we take the 16 byte key that we have and expand it to 176 bytes. Each of these is then passed into each round of AES as a round key.

``````    def rotate_bytes(bytes):
temp = bytes
bytes = bytes
bytes = bytes
bytes = bytes
bytes = temp
return bytes

def sbox_lookup(bytes):
for i in range(NUM_BLOCKS):
srow = int(bytes[i],16) // 16
scol = int(bytes[i],16) % 16
hex_sbox = SBOX[srow][scol]
bytes[i] = hex(hex_sbox)
return bytes

def expand_key(master_key, expanded_key):
rcon = 1
last_four_col = [0 for x in range(NUM_BLOCKS)]

for i in range(0, len(master_key), 1):
expanded_key[i] = hex(int(str(master_key[i]),16))

done_with = 16  # counter for how many bytes we're done expanding
total_expanded_bytes = 176

while done_with < total_expanded_bytes:
for i in range(NUM_BLOCKS):
last_four_col[i] = expanded_key[i + done_with - 4]

if (done_with % 16 == 0):
last_four_col = rotate_bytes(last_four_col)
last_four_col = sbox_lookup(last_four_col)
last_four_col = hex(int(last_four_col,16) ^ RCON_TABLE[rcon])
rcon = rcon + 1
for i in range(NUM_BLOCKS):
expanded_key[done_with] = hex(int(str(expanded_key[done_with - 16]),16) ^ int(str(last_four_col[i]),16))
done_with = done_with + 1

return expanded_key
``````

Here we just xor the byte of the state with the byte of the key. State here refers to our 4×4 matrix that we are trying to encrypt.

``````    def add_round_key(state, key):
key_matrix = bytes_2_matrix(key) # convert list into matrix for easier xor order
for row in range(NUM_BLOCKS):
for col in range(NUM_BLOCKS):
state[row][col] = hex(int(str(state[row][col]), 16) ^ int(str(key_matrix[row][col]), 16))
return state
``````

### Sub Bytes

Replace each byte of the state with the corresponding in the lookup table (SBOX).

``````    def sub_bytes(state):
for i in range(NUM_BLOCKS):
for j in range(NUM_BLOCKS):
val = int(state[i][j], 16)
state[i][j] = hex(SBOX[val//16][val%16])
return state
``````

### Shift Rows

The first row is left as is. The second row is shifted left twice. The third row is shifted left thrice. The last row is shifted left three times.

``````    def shift_rows(state, row, num_shifts):
for _ in range(num_shifts):
temp = state[row]
for j in range(3):
state[row][j] = state[row][j+1]
state[row] = temp
return state

def shift_rows_wrapper(state):
for i in range(1, 4):
state = shift_rows(state, i, i)
return state
``````

### Mix Columns

Taking the last column at a time and using matrix multiplication on it with an xor function. This is explained better in many other places, including here.

``````    def mix_columns(state):
mixed_state = [
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]
]
col_pick = [x for x in range(NUM_BLOCKS)]
mix_col_val = [x for x in range(NUM_BLOCKS)]
for col in range(NUM_BLOCKS):
for row in range(NUM_BLOCKS):
col_pick[row] = int(state[row][col], 16)

mix_col_val = MULTIPLY2[col_pick//16][col_pick%16] ^ MULTIPLY3[col_pick//16][col_pick%16] ^ col_pick ^ col_pick
mix_col_val = col_pick ^ MULTIPLY2[col_pick//16][col_pick%16] ^ MULTIPLY3[col_pick//16][col_pick%16] ^ col_pick
mix_col_val = col_pick ^ col_pick ^ MULTIPLY2[col_pick//16][col_pick%16] ^ MULTIPLY3[col_pick//16][col_pick%16]
mix_col_val = MULTIPLY3[col_pick//16][col_pick%16] ^ col_pick ^ col_pick ^ MULTIPLY2[col_pick//16][col_pick%16]

for index in range(NUM_BLOCKS):
mixed_state[index][col] = hex(mix_col_val[index])

return mixed_state
``````

### Encryption

The code above is all the functions we need. The functions just repeat over each round with new round keys from our key expansaion. We bring it all together in our main encryption function below.

``````    def encrypt(message, keyinput):

key = keyinput
message = space_input(data)
state_matrix = bytes_2_matrix(message)

EXPANDED_KEY = [0 for x in range(176)]
expanded_key = expand_key(key, EXPANDED_KEY)

round_key = [0 for x in range(len(expanded_key)//16)]
for i in range(len(round_key)):
round_key[i] = expanded_key[i*16:i*16+16]

for i in range(1, NUM_ROUNDS):
state_matrix = sub_bytes(state_matrix)
state_matrix = shift_rows_wrapper(state_matrix)
state_matrix = mix_columns(state_matrix)

state_matrix = sub_bytes(state_matrix)
state_matrix = shift_rows_wrapper(state_matrix)

#check the encryption string
state_matrix = transpose_matrix(state_matrix)
cipher_list = [i for sub in state_matrix for i in sub] # flatten nested list into 1D list
cipher = "".join(cipher_list).replace('0x','')
return cipher
``````

### Main Function

The main function where we obtain the stdin and send the stdout as well as break down the input by 16 bytes at a time for better performance.

``````    if __name__ == '__main__':
#For Kattis
key = space_input(key)

while True: