""" crypto.cipher.rijndael

    Rijndael encryption algorithm

    This byte oriented implementation is intended to closely
    match FIPS specification for readability.  It is not implemented
    for performance.

    Copyright (c) 2002 by Paul A. Lambert
    Read LICENSE.txt for license information.

    2002-06-01
"""

from crypto.cipher.base import BlockCipher, padWithPadLen, noPadding

class Rijndael(BlockCipher):
    """ Rijndael encryption algorithm """
    def __init__(self, key = None, padding = padWithPadLen(), keySize=16, blockSize=16 ):
        self.name       = 'RIJNDAEL'
        self.keySize    = keySize
        self.strength   = keySize*8
        self.blockSize  = blockSize  # blockSize is in bytes
        self.padding    = padding    # change default to noPadding() to get normal ECB behavior

        assert( keySize%4==0 and NrTable[4].has_key(keySize/4)),'key size must be 16,20,24,29 or 32 bytes'
        assert( blockSize%4==0 and NrTable.has_key(blockSize/4)), 'block size must be 16,20,24,29 or 32 bytes'

        self.Nb = self.blockSize/4          # Nb is number of columns of 32 bit words
        self.Nk = keySize/4                 # Nk is the key length in 32-bit words
        self.Nr = NrTable[self.Nb][self.Nk] # The number of rounds (Nr) is a function of
                                            # the block (Nb) and key (Nk) sizes.
        if key != None:
            self.setKey(key)

    def setKey(self, key):
        """ Set a key and generate the expanded key """
        assert( len(key) == (self.Nk*4) ), 'Key length must be same as keySize parameter'
        self.__expandedKey = keyExpansion(self, key)
        self.reset()                   # BlockCipher.reset()

    def encryptBlock(self, plainTextBlock):
        """ Encrypt a block, plainTextBlock must be a array of bytes [Nb by 4] """
        self.state = self._toBlock(plainTextBlock)
        AddRoundKey(self, self.__expandedKey[0:self.Nb])
        for round in range(1,self.Nr):          #for round = 1 step 1 to Nr1
            SubBytes(self)
            ShiftRows(self)
            MixColumns(self)
            AddRoundKey(self, self.__expandedKey[round*self.Nb:(round+1)*self.Nb])
        SubBytes(self)
        ShiftRows(self)
        AddRoundKey(self, self.__expandedKey[self.Nr*self.Nb:(self.Nr+1)*self.Nb])
        return self._toBString(self.state)


    def decryptBlock(self, encryptedBlock):
        """ decrypt a block (array of bytes) """
        self.state = self._toBlock(encryptedBlock)
        AddRoundKey(self, self.__expandedKey[self.Nr*self.Nb:(self.Nr+1)*self.Nb])
        for round in range(self.Nr-1,0,-1):
            InvShiftRows(self)
            InvSubBytes(self)
            AddRoundKey(self, self.__expandedKey[round*self.Nb:(round+1)*self.Nb])
            InvMixColumns(self)
        InvShiftRows(self)
        InvSubBytes(self)
        AddRoundKey(self, self.__expandedKey[0:self.Nb])
        return self._toBString(self.state)

    def _toBlock(self, bs):
        """ Convert binary string to array of bytes, state[col][row]"""
        assert ( len(bs) == 4*self.Nb ), 'Rijndarl blocks must be of size blockSize'
        return [[ord(bs[4*i]),ord(bs[4*i+1]),ord(bs[4*i+2]),ord(bs[4*i+3])] for i in range(self.Nb)]

    def _toBString(self, block):
        """ Convert block (array of bytes) to binary string """
        l = []
        for col in block:
            for rowElement in col:
                l.append(chr(rowElement))
        return ''.join(l)
#-------------------------------------
"""    Number of rounds Nr = NrTable[Nb][Nk]

            Nb  Nk=4   Nk=5   Nk=6   Nk=7   Nk=8
            -------------------------------------   """
NrTable =  {4: {4:10,  5:11,  6:12,  7:13,  8:14},
            5: {4:11,  5:11,  6:12,  7:13,  8:14},
            6: {4:12,  5:12,  6:12,  7:13,  8:14},
            7: {4:13,  5:13,  6:13,  7:13,  8:14},
            8: {4:14,  5:14,  6:14,  7:14,  8:14}}
#-------------------------------------
def keyExpansion(algInstance, keyString):
    """ Expand a string of size keySize into a larger array """
    Nk, Nb, Nr = algInstance.Nk, algInstance.Nb, algInstance.Nr # for readability
    key = [ord(byte) for byte in keyString]  # convert string to list
    w = [[key[4*i],key[4*i+1],key[4*i+2],key[4*i+3]] for i in range(Nk)]
    for i in range(Nk,Nb*(Nr+1)):
        temp = w[i-1]        # a four byte column
        if (i%Nk) == 0 :
            temp     = temp[1:]+[temp[0]]  # RotWord(temp)
            temp     = [ Sbox[byte] for byte in temp ]
            temp[0] ^= Rcon[i/Nk]
        elif Nk > 6 and  i%Nk == 4 :
            temp     = [ Sbox[byte] for byte in temp ]  # SubWord(temp)
        w.append( [ w[i-Nk][byte]^temp[byte] for byte in range(4) ] )
    return w

Rcon = (0,0x01,0x02,0x04,0x08,0x10,0x20,0x40,0x80,0x1b,0x36,     # note extra '0' !!!
        0x6c,0xd8,0xab,0x4d,0x9a,0x2f,0x5e,0xbc,0x63,0xc6,
        0x97,0x35,0x6a,0xd4,0xb3,0x7d,0xfa,0xef,0xc5,0x91)

#-------------------------------------
def AddRoundKey(algInstance, keyBlock):
    """ XOR the algorithm state with a block of key material """
    for column in range(algInstance.Nb):
        for row in range(4):
            algInstance.state[column][row] ^= keyBlock[column][row]
#-------------------------------------

def SubBytes(algInstance):
    for column in range(algInstance.Nb):
        for row in range(4):
            algInstance.state[column][row] = Sbox[algInstance.state[column][row]]

def InvSubBytes(algInstance):
    for column in range(algInstance.Nb):
        for row in range(4):
            algInstance.state[column][row] = InvSbox[algInstance.state[column][row]]

Sbox =    (0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,
           0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
           0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,
           0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
           0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,
           0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
           0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,
           0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
           0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,
           0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
           0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,
           0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
           0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,
           0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
           0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,
           0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
           0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,
           0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
           0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,
           0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
           0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,
           0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
           0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,
           0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
           0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,
           0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
           0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,
           0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
           0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,
           0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
           0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,
           0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16)

InvSbox = (0x52,0x09,0x6a,0xd5,0x30,0x36,0xa5,0x38,
           0xbf,0x40,0xa3,0x9e,0x81,0xf3,0xd7,0xfb,
           0x7c,0xe3,0x39,0x82,0x9b,0x2f,0xff,0x87,
           0x34,0x8e,0x43,0x44,0xc4,0xde,0xe9,0xcb,
           0x54,0x7b,0x94,0x32,0xa6,0xc2,0x23,0x3d,
           0xee,0x4c,0x95,0x0b,0x42,0xfa,0xc3,0x4e,
           0x08,0x2e,0xa1,0x66,0x28,0xd9,0x24,0xb2,
           0x76,0x5b,0xa2,0x49,0x6d,0x8b,0xd1,0x25,
           0x72,0xf8,0xf6,0x64,0x86,0x68,0x98,0x16,
           0xd4,0xa4,0x5c,0xcc,0x5d,0x65,0xb6,0x92,
           0x6c,0x70,0x48,0x50,0xfd,0xed,0xb9,0xda,
           0x5e,0x15,0x46,0x57,0xa7,0x8d,0x9d,0x84,
           0x90,0xd8,0xab,0x00,0x8c,0xbc,0xd3,0x0a,
           0xf7,0xe4,0x58,0x05,0xb8,0xb3,0x45,0x06,
           0xd0,0x2c,0x1e,0x8f,0xca,0x3f,0x0f,0x02,
           0xc1,0xaf,0xbd,0x03,0x01,0x13,0x8a,0x6b,
           0x3a,0x91,0x11,0x41,0x4f,0x67,0xdc,0xea,
           0x97,0xf2,0xcf,0xce,0xf0,0xb4,0xe6,0x73,
           0x96,0xac,0x74,0x22,0xe7,0xad,0x35,0x85,
           0xe2,0xf9,0x37,0xe8,0x1c,0x75,0xdf,0x6e,
           0x47,0xf1,0x1a,0x71,0x1d,0x29,0xc5,0x89,
           0x6f,0xb7,0x62,0x0e,0xaa,0x18,0xbe,0x1b,
           0xfc,0x56,0x3e,0x4b,0xc6,0xd2,0x79,0x20,
           0x9a,0xdb,0xc0,0xfe,0x78,0xcd,0x5a,0xf4,
           0x1f,0xdd,0xa8,0x33,0x88,0x07,0xc7,0x31,
           0xb1,0x12,0x10,0x59,0x27,0x80,0xec,0x5f,
           0x60,0x51,0x7f,0xa9,0x19,0xb5,0x4a,0x0d,
           0x2d,0xe5,0x7a,0x9f,0x93,0xc9,0x9c,0xef,
           0xa0,0xe0,0x3b,0x4d,0xae,0x2a,0xf5,0xb0,
           0xc8,0xeb,0xbb,0x3c,0x83,0x53,0x99,0x61,
           0x17,0x2b,0x04,0x7e,0xba,0x77,0xd6,0x26,
           0xe1,0x69,0x14,0x63,0x55,0x21,0x0c,0x7d)

#-------------------------------------
""" For each block size (Nb), the ShiftRow operation shifts row i
    by the amount Ci.  Note that row 0 is not shifted.
                 Nb      C1 C2 C3
               -------------------  """
shiftOffset  = { 4 : ( 0, 1, 2, 3),
                 5 : ( 0, 1, 2, 3),
                 6 : ( 0, 1, 2, 3),
                 7 : ( 0, 1, 2, 4),
                 8 : ( 0, 1, 3, 4) }
def ShiftRows(algInstance):
    tmp = [0]*algInstance.Nb   # list of size Nb
    for r in range(1,4):       # row 0 reamains unchanged and can be skipped
        for c in range(algInstance.Nb):
            tmp[c] = algInstance.state[(c+shiftOffset[algInstance.Nb][r]) % algInstance.Nb][r]
        for c in range(algInstance.Nb):
            algInstance.state[c][r] = tmp[c]
def InvShiftRows(algInstance):
    tmp = [0]*algInstance.Nb   # list of size Nb
    for r in range(1,4):       # row 0 reamains unchanged and can be skipped
        for c in range(algInstance.Nb):
            tmp[c] = algInstance.state[(c+algInstance.Nb-shiftOffset[algInstance.Nb][r]) % algInstance.Nb][r]
        for c in range(algInstance.Nb):
            algInstance.state[c][r] = tmp[c]
#-------------------------------------
def MixColumns(a):
    Sprime = [0,0,0,0]
    for j in range(a.Nb):    # for each column
        Sprime[0] = mul(2,a.state[j][0])^mul(3,a.state[j][1])^mul(1,a.state[j][2])^mul(1,a.state[j][3])
        Sprime[1] = mul(1,a.state[j][0])^mul(2,a.state[j][1])^mul(3,a.state[j][2])^mul(1,a.state[j][3])
        Sprime[2] = mul(1,a.state[j][0])^mul(1,a.state[j][1])^mul(2,a.state[j][2])^mul(3,a.state[j][3])
        Sprime[3] = mul(3,a.state[j][0])^mul(1,a.state[j][1])^mul(1,a.state[j][2])^mul(2,a.state[j][3])
        for i in range(4):
            a.state[j][i] = Sprime[i]

def InvMixColumns(a):
    """ Mix the four bytes of every column in a linear way
        This is the opposite operation of Mixcolumn """
    Sprime = [0,0,0,0]
    for j in range(a.Nb):    # for each column
        Sprime[0] = mul(0x0E,a.state[j][0])^mul(0x0B,a.state[j][1])^mul(0x0D,a.state[j][2])^mul(0x09,a.state[j][3])
        Sprime[1] = mul(0x09,a.state[j][0])^mul(0x0E,a.state[j][1])^mul(0x0B,a.state[j][2])^mul(0x0D,a.state[j][3])
        Sprime[2] = mul(0x0D,a.state[j][0])^mul(0x09,a.state[j][1])^mul(0x0E,a.state[j][2])^mul(0x0B,a.state[j][3])
        Sprime[3] = mul(0x0B,a.state[j][0])^mul(0x0D,a.state[j][1])^mul(0x09,a.state[j][2])^mul(0x0E,a.state[j][3])
        for i in range(4):
            a.state[j][i] = Sprime[i]

#-------------------------------------
def mul(a, b):
    """ Multiply two elements of GF(2^m)
        needed for MixColumn and InvMixColumn """
    if (a !=0 and  b!=0):
        return Alogtable[(Logtable[a] + Logtable[b])%255]
    else:
        return 0

Logtable = ( 0,   0,  25,   1,  50,   2,  26, 198,  75, 199,  27, 104,  51, 238, 223,   3,
           100,   4, 224,  14,  52, 141, 129, 239,  76, 113,   8, 200, 248, 105,  28, 193,
           125, 194,  29, 181, 249, 185,  39, 106,  77, 228, 166, 114, 154, 201,   9, 120,
           101,  47, 138,   5,  33,  15, 225,  36,  18, 240, 130,  69,  53, 147, 218, 142,
           150, 143, 219, 189,  54, 208, 206, 148,  19,  92, 210, 241,  64,  70, 131,  56,
           102, 221, 253,  48, 191,   6, 139,  98, 179,  37, 226, 152,  34, 136, 145,  16,
           126, 110,  72, 195, 163, 182,  30,  66,  58, 107,  40,  84, 250, 133,  61, 186,
            43, 121,  10,  21, 155, 159,  94, 202,  78, 212, 172, 229, 243, 115, 167,  87,
           175,  88, 168,  80, 244, 234, 214, 116,  79, 174, 233, 213, 231, 230, 173, 232,
            44, 215, 117, 122, 235,  22,  11, 245,  89, 203,  95, 176, 156, 169,  81, 160,
           127,  12, 246, 111,  23, 196,  73, 236, 216,  67,  31,  45, 164, 118, 123, 183,
           204, 187,  62,  90, 251,  96, 177, 134,  59,  82, 161, 108, 170,  85,  41, 157,
           151, 178, 135, 144,  97, 190, 220, 252, 188, 149, 207, 205,  55,  63,  91, 209,
            83,  57, 132,  60,  65, 162, 109,  71,  20,  42, 158,  93,  86, 242, 211, 171,
            68,  17, 146, 217,  35,  32,  46, 137, 180, 124, 184,  38, 119, 153, 227, 165,
           103,  74, 237, 222, 197,  49, 254,  24,  13,  99, 140, 128, 192, 247, 112,   7)

Alogtable= ( 1,   3,   5,  15,  17,  51,  85, 255,  26,  46, 114, 150, 161, 248,  19,  53,
            95, 225,  56,  72, 216, 115, 149, 164, 247,   2,   6,  10,  30,  34, 102, 170,
           229,  52,  92, 228,  55,  89, 235,  38, 106, 190, 217, 112, 144, 171, 230,  49,
            83, 245,   4,  12,  20,  60,  68, 204,  79, 209, 104, 184, 211, 110, 178, 205,
            76, 212, 103, 169, 224,  59,  77, 215,  98, 166, 241,   8,  24,  40, 120, 136,
           131, 158, 185, 208, 107, 189, 220, 127, 129, 152, 179, 206,  73, 219, 118, 154,
           181, 196,  87, 249,  16,  48,  80, 240,  11,  29,  39, 105, 187, 214,  97, 163,
           254,  25,  43, 125, 135, 146, 173, 236,  47, 113, 147, 174, 233,  32,  96, 160,
           251,  22,  58,  78, 210, 109, 183, 194,  93, 231,  50,  86, 250,  21,  63,  65,
           195,  94, 226,  61,  71, 201,  64, 192,  91, 237,  44, 116, 156, 191, 218, 117,
           159, 186, 213, 100, 172, 239,  42, 126, 130, 157, 188, 223, 122, 142, 137, 128,
           155, 182, 193,  88, 232,  35, 101, 175, 234,  37, 111, 177, 200,  67, 197,  84,
           252,  31,  33,  99, 165, 244,   7,   9,  27,  45, 119, 153, 176, 203,  70, 202,
            69, 207,  74, 222, 121, 139, 134, 145, 168, 227,  62,  66, 198,  81, 243,  14,
            18,  54,  90, 238,  41, 123, 141, 140, 143, 138, 133, 148, 167, 242,  13,  23,
            57,  75, 221, 124, 132, 151, 162, 253,  28,  36, 108, 180, 199,  82, 246,   1)


