# n1AES

看加密函数

def encrypt(self, plaintext):
    """Encrypt one 16-character block with the expanded round keys.

    AES-style layout: initial AddRoundKey with round_keys[0:4], nine full
    rounds (round_encrypt), then a final round without MixColumns.  The
    final state is also left in ``self.plain_state``.
    """
    self.plain_state = state = text2matrix(plaintext)
    self.add_round_key(state, self.round_keys[0:4])
    for rnd in range(1, 10):
        self.round_encrypt(state, self.round_keys[4 * rnd:4 * rnd + 4])
    # Last round: SubBytes + ShiftRows + AddRoundKey, no MixColumns.
    self.sub_bytes(state)
    self.shift_rows(state)
    self.add_round_key(state, self.round_keys[40:])
    return matrix2text(state)

那么解密函数就是

def decrypt(self, plaintext):
    """Decrypt one 16-character block; the step-by-step inverse of encrypt.

    Despite its name, *plaintext* is the ciphertext block (the name is kept
    so existing callers keep working).  The final state is also left in
    ``self.plain_state``.
    """
    self.plain_state = state = text2matrix(plaintext)
    # Undo the final round, which had no MixColumns.
    self.add_round_key(state, self.round_keys[40:])
    self._shift_rows(state)
    self._sub_bytes(state)
    # Undo rounds 9 .. 1, consuming the round keys in reverse order.
    for rnd in reversed(range(1, 10)):
        self.round_decrypt(state, self.round_keys[4 * rnd:4 * rnd + 4])
    # Undo the initial whitening key.
    self.add_round_key(state, self.round_keys[0:4])
    return matrix2text(state)

shift_rows 的逆函数

def _shift_rows(self, s):
    """Undo the challenge's shift_rows permutation on the 4x4 state *s*, in place.

    shift_rows sets flat[i] = old_flat[r[i]] (flat = state read row by row),
    so the inverse reads every byte back from the slot r sent it to.
    ``self`` is not used.
    """
    r = [14, 12, 10, 11, 3, 2, 0, 13, 9, 5, 4, 8, 15, 1, 6, 7]
    # inv[i] == r.index(i); built once instead of a linear search per byte.
    inv = [0] * 16
    for pos, src in enumerate(r):
        inv[src] = pos
    flat = [s[i][j] for i in range(4) for j in range(4)]
    # range and // keep this working on both Python 2 and Python 3.
    for i in range(16):
        s[i // 4][i % 4] = flat[inv[i]]

sub_bytes 的逆函数

def _sub_bytes(self, s):
    """Inverse SubBytes: map each byte of the 4x4 state *s* through InvSbox, in place."""
    for idx in range(16):
        row, col = divmod(idx, 4)
        s[row][col] = InvSbox[s[row][col]]

round_decrypt 函数

def round_decrypt(self, state_matrix, key_matrix):
    """Undo one full encryption round on *state_matrix*, in place.

    round_encrypt applies sub_bytes, shift_rows, mix_columns, add_round_key;
    this applies the inverses of those layers in the opposite order.
    """
    # AddRoundKey is an XOR, so applying it again removes it.
    self.add_round_key(state_matrix, key_matrix)
    for undo in (self._mix_columns, self._shift_rows, self._sub_bytes):
        undo(state_matrix)

主要是列混淆的逆函数

def _mix_columns(self, s):
    """Inverse MixColumns, applied in place to every column of the 4x4 state *s*.

    Row r of a column becomes XOR over k of xx(col[k], coeffs[(r + k) % 4]).
    The coefficients [99, 214, 252, 197] are the writeup's inverse of the
    forward coefficients [3, 10, 8, 6].
    """
    coeffs = [99, 214, 252, 197]
    for col in range(4):
        column = [s[row][col] for row in range(4)]
        mixed = []
        for out_row in range(4):
            acc = 0
            for k, byte in enumerate(column):
                acc ^= xx(byte, coeffs[(out_row + k) % 4])
            mixed.append(acc)
        for row, value in enumerate(mixed):
            s[row][col] = value

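关键是求出 `mix_columns` 的系数 [3, 10, 8, 6] 所构成矩阵 $M_{ij} = w_{(i+j) \bmod 4}$ 在 $GF(2^8)$ 上的逆矩阵。注意这里的约化多项式由 `xtime` 中的常数 99（0x63）决定，即 $P(x) = x^8 + x^6 + x^5 + x + 1$，而不是标准 AES 的 $x^8 + x^4 + x^3 + x + 1$。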

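不太懂怎么算。一个思路：这种矩阵的逆矩阵具有同样的结构 $N_{ij} = a_{(i+j) \bmod 4}$，所以只有 4 个未知数 $a_0, \dots, a_3$。由 $MN = I$ 可得 4 个方程 $\sum_{j=0}^{3} w_{(d+j) \bmod 4}\, a_j = \delta_{d,0}$（$d = 0, 1, 2, 3$）。用 `xx` 做乘法、异或做加法，在 $GF(2^8)$ 上高斯消元即可；也可以直接用 Sage 在该域上对矩阵求逆。exp 中用到的结果是 [99, 214, 252, 197]。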

exp

def xor(a, b):
    """Return the character-wise XOR of strings *a* and *b*, truncated to the shorter one."""
    out = []
    for left, right in zip(a, b):
        out.append(chr(ord(left) ^ ord(right)))
    return ''.join(out)
# Byte substitution table: sub_bytes maps every state byte through it, and the
# key schedule (change_key) applies it to the previous round-key word.
# NOTE(review): the values appear to be the standard AES (Rijndael) S-box.
Sbox = (
  0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
  0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
  0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
  0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
  0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
  0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
  0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
  0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
  0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
  0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
  0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
  0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
  0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
  0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
  0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
  0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
)
# Inverse substitution table used by _sub_bytes during decryption; decryption
# only works if InvSbox[Sbox[x]] == x for every byte x (the values appear to be
# the standard AES inverse S-box).
InvSbox = (
  0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB,
  0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB,
  0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E,
  0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25,
  0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92,
  0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84,
  0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06,
  0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B,
  0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73,
  0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E,
  0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B,
  0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4,
  0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F,
  0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF,
  0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61,
  0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D,
)
def xtime(a):
    """Multiply the byte *a* by x in the challenge's GF(2^8).

    Shift left by one; if bit 7 was set, fold the overflow back with the
    constant 99 (0x63) and keep 8 bits.  That constant corresponds to the
    reduction polynomial x^8 + x^6 + x^5 + x + 1, not AES's 0x1B.
    """
    doubled = a << 1
    if not a & 0x80:
        return doubled
    return (doubled ^ 99) & 0xFF


# Key-schedule round constants (the standard AES sequence); change_key reads
# Rcon[1] .. Rcon[10].
Rcon = [
    0x0, 0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40,
    0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a,
    0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a,
    0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39,
]
def xx(i, j):
    """Multiply *i* by *j* in the challenge's GF(2^8) (shift-and-add).

    For every set bit k of *j* (least significant first), the value *i*
    multiplied k times by x via xtime is XOR-ed into the product.
    """
    product = 0
    addend = i
    multiplier = j
    while multiplier:
        if multiplier & 1:
            product ^= addend
        addend = xtime(addend)
        multiplier >>= 1
    return product
def text2matrix(text):
    """Pack the first 16 characters of *text* into a 4x4 row-major matrix of ints.

    Row r holds items 4*r .. 4*r + 3.  Accepts a ``str`` (each character is
    converted with ``ord``) or, on Python 3, ``bytes``/``bytearray`` (whose
    items are already ints).  Extra items are ignored; raises IndexError if
    *text* has fewer than 16 items.
    """
    values = []
    for i in range(16):
        item = text[i]  # IndexError on short input, as before
        values.append(item if isinstance(item, int) else ord(item))
    # Slicing avoids the float index that i / 4 produced on Python 3.
    return [values[4 * row:4 * row + 4] for row in range(4)]
def matrix2text(matrix):
    """Flatten the 4x4 *matrix* row by row into a string of chr() characters."""
    return ''.join(chr(matrix[r][c]) for r in range(4) for c in range(4))
class N1AES:
    """Block cipher from the n1AES challenge, plus the inverse used to solve it.

    The layout mirrors AES-128:
    - a 4x4 byte state (the rows produced by ``text2matrix``);
    - 44 round-key words;
    - an initial AddRoundKey, nine full rounds, and a final round without
      MixColumns.

    It differs from AES in two layers:
    - ShiftRows is a fixed 16-byte permutation;
    - MixColumns uses the coefficients 3, 10, 8, 6, combined with ``xx``.

    Methods prefixed with ``_`` are the inverse layers.  The class depends on
    the module-level ``text2matrix``, ``matrix2text``, ``xx``, ``Sbox``,
    ``InvSbox`` and ``Rcon``.
    """

    # shift_rows sets flat[i] = old_flat[_SHIFT_PERM[i]], where "flat" is the
    # state read row by row.
    _SHIFT_PERM = (14, 12, 10, 11, 3, 2, 0, 13, 9, 5, 4, 8, 15, 1, 6, 7)
    # Inverse permutation, _INV_SHIFT_PERM[i] == _SHIFT_PERM.index(i), computed
    # once here instead of one linear .index() search per byte.
    _INV_SHIFT_PERM = tuple(sorted(range(16), key=_SHIFT_PERM.__getitem__))
    # First row of the MixColumns matrix M[r][k] = w[(r + k) % 4], and of its inverse.
    _MIX_WEIGHTS = (3, 10, 8, 6)
    _INV_MIX_WEIGHTS = (99, 214, 252, 197)

    def __init__(self, master_key=None):
        """Create a cipher; expand *master_key* right away when it is truthy."""
        if master_key:
            self.change_key(master_key)

    def change_key(self, master_key):
        """Install *master_key* and expand it into 44 four-byte round-key words.

        Only the first 16 characters feed the key schedule (via text2matrix);
        ``self.skey`` is the big-endian integer value of the whole key.
        """
        # Portable replacement for the Python-2-only int(key.encode('hex'), 16).
        skey = 0
        for ch in master_key:
            skey = (skey << 8) | (ch if isinstance(ch, int) else ord(ch))
        self.skey = skey
        self.round_keys = text2matrix(master_key)
        for i in range(4, 4 * 11):
            prev = self.round_keys[i - 1]
            back = self.round_keys[i - 4]
            if i % 4 == 0:
                # RotWord + SubWord of the previous word, Rcon on byte 0.
                word = [back[j] ^ Sbox[prev[(j + 1) % 4]] for j in range(4)]
                word[0] ^= Rcon[i // 4]
            else:
                word = [back[j] ^ prev[j] for j in range(4)]
            self.round_keys.append(word)

    def encrypt(self, plaintext):
        """Encrypt one 16-character block; the final state stays in ``self.plain_state``."""
        self.plain_state = text2matrix(plaintext)
        self.add_round_key(self.plain_state, self.round_keys[:4])
        for i in range(1, 10):
            self.round_encrypt(self.plain_state, self.round_keys[4 * i: 4 * (i + 1)])
        # Final round has no MixColumns.
        self.sub_bytes(self.plain_state)
        self.shift_rows(self.plain_state)
        self.add_round_key(self.plain_state, self.round_keys[40:])
        return matrix2text(self.plain_state)

    def add_round_key(self, s, k):
        """XOR the 4x4 key block *k* into the state *s*, in place (self-inverse)."""
        for i in range(4):
            for j in range(4):
                s[i][j] ^= k[i][j]

    def round_encrypt(self, state_matrix, key_matrix):
        """One full round: SubBytes, ShiftRows, MixColumns, AddRoundKey."""
        self.sub_bytes(state_matrix)
        self.shift_rows(state_matrix)
        self.mix_columns(state_matrix)
        self.add_round_key(state_matrix, key_matrix)

    def sub_bytes(self, s):
        """Map every state byte through Sbox, in place."""
        for i in range(4):
            for j in range(4):
                s[i][j] = Sbox[s[i][j]]

    @staticmethod
    def _permute(s, perm):
        """Set flat[i] = old_flat[perm[i]] on the row-major 4x4 state *s*, in place."""
        flat = [s[i][j] for i in range(4) for j in range(4)]
        # range and // (not xrange and i/4) so this runs on Python 2 and 3.
        for i in range(16):
            s[i // 4][i % 4] = flat[perm[i]]

    def shift_rows(self, s):
        """Apply the challenge's fixed byte permutation to the state, in place."""
        self._permute(s, self._SHIFT_PERM)

    @staticmethod
    def _mix(s, wt):
        """Replace every column of *s* by M * column, with M[r][k] = wt[(r + k) % 4].

        Multiplication is ``xx``, addition is XOR.  Works in place.
        """
        for i in range(4):
            col = [s[k][i] for k in range(4)]
            for r in range(4):
                acc = 0
                for k in range(4):
                    acc ^= xx(col[k], wt[(r + k) % 4])
                s[r][i] = acc

    def mix_columns(self, s):
        """Forward MixColumns with coefficients (3, 10, 8, 6)."""
        self._mix(s, self._MIX_WEIGHTS)

    def decrypt(self, plaintext):
        """Decrypt one 16-character block; the step-by-step inverse of encrypt.

        The parameter is named ``plaintext`` for backward compatibility only;
        it is the ciphertext block.  The final state stays in ``self.plain_state``.
        """
        self.plain_state = text2matrix(plaintext)
        self.add_round_key(self.plain_state, self.round_keys[40:])
        self._shift_rows(self.plain_state)
        self._sub_bytes(self.plain_state)
        for i in range(9, 0, -1):
            self.round_decrypt(self.plain_state, self.round_keys[4 * i: 4 * (i + 1)])
        self.add_round_key(self.plain_state, self.round_keys[:4])
        return matrix2text(self.plain_state)

    def _shift_rows(self, s):
        """Undo shift_rows, in place."""
        self._permute(s, self._INV_SHIFT_PERM)

    def _sub_bytes(self, s):
        """Map every state byte through InvSbox, in place."""
        for i in range(4):
            for j in range(4):
                s[i][j] = InvSbox[s[i][j]]

    def round_decrypt(self, state_matrix, key_matrix):
        """Undo round_encrypt: the inverse layers in reverse order."""
        self.add_round_key(state_matrix, key_matrix)
        self._mix_columns(state_matrix)
        self._shift_rows(state_matrix)
        self._sub_bytes(state_matrix)

    def _mix_columns(self, s):
        """Inverse MixColumns with coefficients (99, 214, 252, 197)."""
        self._mix(s, self._INV_MIX_WEIGHTS)
import binascii

a = N1AES("THEAESPARTN1BOOK")
# The challenge produced the ciphertext as
#   b = a.encrypt(flag[:16]) + a.encrypt(flag[16:])
# so decrypt it block by block with the same key.
b = '588aa4c53819273bd2cdd6a20de7453ca21ef63d75077daa42b30e7fad50b39f'
# unhexlify matches str.decode('hex') on Python 2 and also exists on Python 3.
b = binascii.unhexlify(b)
c = ''.join(a.decrypt(b[k:k + 16]) for k in range(0, len(b), 16))
print(c)

# [SWPU 2020]cbc1

# 题目

from Crypto.Cipher import AES
import os
flag = os.environ['FLAG']  # the server reads the flag from its environment
BLOCKSIZE = 16  # block length in bytes that pad() rounds up to
def pad(data):
    """Right-pad *data* with '=' up to a multiple of BLOCKSIZE (unchanged if already aligned)."""
    # -len % BLOCKSIZE is 0 for aligned input, else BLOCKSIZE - len % BLOCKSIZE.
    return data + "=" * (-len(data) % BLOCKSIZE)
def unpad(data):
    """Delete every '=' from *data*.

    Note: this removes '=' anywhere in the string, not only trailing padding.
    """
    return "".join(data.split("="))
def enc(data, key, iv):
    """AES-CBC encrypt pad(data) under *key* and *iv*; return the raw ciphertext."""
    return AES.new(key, AES.MODE_CBC, iv).encrypt(pad(data))
def dec(data, key, iv):
    """AES-CBC decrypt *data* under *key* and *iv*, then strip '=' with unpad().

    Any exception raised by AES.new, decrypt or unpad ends the session via exit().
    """
    try:
        cipher = AES.new(key, AES.MODE_CBC, iv)
        return unpad(cipher.decrypt(data))
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
        exit()
def task():
    """Serve one session of at most three menu turns.

    Each choice does the following:
    - '1' registers a name (anything except 'admin') and prints
      hex(IV || enc("yusa"*4 + name)).
    - '2' takes a hex token IV || ciphertext, prints hex(IV || decrypted
      name), greets the user if the plaintext starts with "yusa"*4, and
      prints the flag if the rest is 'admin'.
    - Any other input just uses up a turn.

    Any exception ends the session via exit().
    """
    try:
        key = os.urandom(16)
        iv = os.urandom(16)
        pre = "yusa"*4
        for _ in range(3):
            choice = raw_input(menu)
            if choice == '1':
                name = raw_input("What's your name?")
                if name == 'admin':
                    exit()
                token = enc(pre + name, key, iv)
                print("Here is your token(in hex): " + iv.encode('hex') + token.encode('hex'))
            elif choice == '2':
                token = raw_input("Your token(in hex): ").decode('hex')
                # The submitted IV replaces the session IV from here on.
                iv = token[:16]
                name = dec(token[16:], key, iv)
                print(iv.encode('hex') + name.encode('hex'))
                if name[:16] != "yusa"*4:
                    continue
                print("Hello, " + name[16:])
                if name[16:] == 'admin':
                    print(flag)
                    exit()
    except:
        exit()
# Prompt shown before every turn of task().
menu = (
    "\n"
    "1. register\n"
    "2. login\n"
    "3. exit\n"
)

if __name__ == "__main__":
    task()

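设第一块明文 $m_0$ = `"yusa"*4`，注册时提交的 name 为 $m_1$。exp 中 $m_1$ 为 `"haha"*4`，正好 16 字节，不会被填充。服务器返回 $IV \| c_0 \| c_1$，其中

$$c_0 = E(m_0 \oplus IV), \qquad c_1 = E(m_1 \oplus c_0)$$

目标是让第二块解密为 $m_1'$ = `===========admin`（11 个 `=` 加 `admin`，`unpad` 会删掉所有 `=`）。由于 $D(c_1) = m_1 \oplus c_0$，令

$$c_0' = c_0 \oplus m_1 \oplus m_1'$$

则 $D(c_1) \oplus c_0' = m_1 \oplus c_0 \oplus c_0 \oplus m_1 \oplus m_1' = m_1'$。

但这样第一块会变成乱码 $m' = D(c_0') \oplus IV$。先用 $IV \| c_0'$ 登录一次，服务器会回显 $m'$。若 $m'$ 中恰好含有 `=`，它会被 `unpad` 删掉，需要重来一次。再令

$$IV' = m' \oplus m_0 \oplus IV$$

则 $D(c_0') \oplus IV' = (m' \oplus IV) \oplus m' \oplus m_0 \oplus IV = m_0$。

最后提交 $IV' \| c_0' \| c_1$，解密得到 $m_0 \| m_1'$，去掉 `=` 后即 `"yusa"*4 + "admin"`。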

# exp

from pwn import *
from Crypto.Util.number import *
HOST, PORT = "1.14.71.254", 28645
PREFIX = b'yusa' * 4          # first plaintext block the server prepends
NAME = b'haha' * 4            # registered name: exactly one block, so no padding
TARGET = b'===========admin'  # forged 2nd block; unpad() deletes every '='


def xor_bytes(first, *rest):
    """Return the byte-wise XOR of equal-length byte strings.

    Raises ValueError on a length mismatch instead of silently truncating.
    """
    out = bytearray(first)
    for chunk in rest:
        if len(chunk) != len(out):
            raise ValueError("xor_bytes: length mismatch")
        for k, byte in enumerate(bytearray(chunk)):
            out[k] ^= byte
    return bytes(out)


def exploit(host=HOST, port=PORT):
    """CBC bit-flipping attack on the cbc1 service, using its three menu turns.

    1. Register NAME and receive IV || c0 || c1.
    2. Log in with IV || c0', where c0' = c0 ^ NAME ^ TARGET, to leak
       m' = D(c0') ^ IV.
    3. Log in with IV' || c0' || c1, where IV' = m' ^ PREFIX ^ IV; this
       decrypts to PREFIX || TARGET.

    Hex values go through bytes.fromhex / .hex(), which keeps every value at
    full width.  hex(int)[2:] would drop leading zero nibbles and break the token.
    """
    conn = remote(host, port)  # not "re", to avoid shadowing the regex module name

    # Turn 1: register and read IV || c0 || c1.
    conn.recvuntil(b'3. exit')
    conn.sendline(b'1')
    conn.recvuntil(b"What's your name?")
    conn.sendline(NAME)
    conn.recvuntil(b'Here is your token(in hex): ')
    token = bytes.fromhex(conn.recvline().decode().strip())
    iv, c0, c1 = token[:16], token[16:32], token[32:]

    # Flip c0 so that the second block decrypts to TARGET.
    c0_forged = xor_bytes(c0, NAME, TARGET)

    # Turn 2: log in with the single forged block to leak D(c0') ^ IV.
    conn.recvuntil(b'3. exit')
    conn.sendline(b'2')
    conn.recvuntil(b'Your token(in hex): ')
    conn.sendline((iv + c0_forged).hex().encode())
    leak = bytes.fromhex(conn.recvline().decode().strip())[16:]
    if len(leak) != 16:
        # unpad() removed '=' bytes from the garbage block; m' cannot be rebuilt.
        raise RuntimeError("leaked block lost %d '=' byte(s); rerun the exploit" % (16 - len(leak)))

    # Repair the first block: D(c0') ^ IV' == PREFIX.
    iv_forged = xor_bytes(leak, PREFIX, iv)

    # Turn 3: submit the full forged token and read the flag.
    conn.recvuntil(b'3. exit')
    conn.sendline(b'2')
    conn.recvuntil(b'Your token(in hex): ')
    conn.sendline((iv_forged + c0_forged + c1).hex().encode())
    conn.interactive()


if __name__ == "__main__":
    exploit()