-rw-r--r-- 8888 cryptattacktester-20230614/aes128_enum.cpp raw
#include <cassert> #include "bit_vector.h" #include "aes128_enum.h" using namespace std; typedef vector<bit> byte; static byte byte_xor(byte c,byte d) { byte result; for (bigint i = 0;i < 8;++i) result.push_back(c.at(i)^d.at(i)); return result; } static vector<bit> two{0,1,0,0,0,0,0,0}; static byte xtime(byte c) { bit c0 = c.at(0); bit c1 = c.at(1); bit c2 = c.at(2); bit c3 = c.at(3); bit c4 = c.at(4); bit c5 = c.at(5); bit c6 = c.at(6); bit c7 = c.at(7); bit h0 = c7; bit h1 = c0^c7; bit h2 = c1; bit h3 = c2^c7; bit h4 = c3^c7; bit h5 = c4; bit h6 = c5; bit h7 = c6; byte result; result.push_back(h0); result.push_back(h1); result.push_back(h2); result.push_back(h3); result.push_back(h4); result.push_back(h5); result.push_back(h6); result.push_back(h7); return result; } static byte byte_sub(byte c) { bit U0 = c.at(7); bit U1 = c.at(6); bit U2 = c.at(5); bit U3 = c.at(4); bit U4 = c.at(3); bit U5 = c.at(2); bit U6 = c.at(1); bit U7 = c.at(0); bit y14 = U3 ^ U5; bit y13 = U0 ^ U6; bit y9 = U0 ^ U3; bit y8 = U0 ^ U5; bit t0 = U1 ^ U2; bit y1 = t0 ^ U7; bit y4 = y1 ^ U3; bit y12 = y13 ^ y14; bit y2 = y1 ^ U0; bit y5 = y1 ^ U6; bit y3 = y5 ^ y8; bit t1 = U4 ^ y12; bit y15 = t1 ^ U5; bit y20 = t1 ^ U1; bit y6 = y15 ^ U7; bit y10 = y15 ^ t0; bit y11 = y20 ^ y9; bit y7 = U7 ^ y11; bit y17 = y10 ^ y11; bit y19 = y10 ^ y8; bit y16 = t0 ^ y11; bit y21 = y13 ^ y16; bit y18 = U0 ^ y16; bit t2 = y12 & y15; bit t3 = y3 & y6; bit t4 = t3 ^ t2; bit t5 = y4 & U7; bit t6 = t5 ^ t2; bit t7 = y13 & y16; bit t8 = y5 & y1; bit t9 = t8 ^ t7; bit t10 = y2 & y7; bit t11 = t10 ^ t7; bit t12 = y9 & y11; bit t13 = y14 & y17; bit t14 = t13 ^ t12; bit t15 = y8 & y10; bit t16 = t15 ^ t12; bit t17 = t4 ^ y20; bit t18 = t6 ^ t16; bit t19 = t9 ^ t14; bit t20 = t11 ^ t16; bit t21 = t17 ^ t14; bit t22 = t18 ^ y19; bit t23 = t19 ^ y21; bit t24 = t20 ^ y18; bit t25 = t21 ^ t22; bit t26 = t21 & t23; bit t27 = t24 ^ t26; bit t28 = t25 & t27; bit t29 = t28 ^ t22; bit t30 = t23 ^ t24; bit t31 = t22 ^ t26; bit t32 = t31 & t30; bit t33 = t32 ^ t24; bit t34 = t23 ^ t33; bit t35 = t27 ^ t33; bit t36 = t24 & t35; bit t37 = t36 ^ t34; bit t38 = t27 ^ t36; bit t39 = t29 & t38; bit t40 = t25 ^ t39; bit t41 = t40 ^ t37; bit t42 = t29 ^ t33; bit t43 = t29 ^ t40; bit t44 = t33 ^ t37; bit t45 = t42 ^ t41; bit z0 = t44 & y15; bit z1 = t37 & y6; bit z2 = t33 & U7; bit z3 = t43 & y16; bit z4 = t40 & y1; bit z5 = t29 & y7; bit z6 = t42 & y11; bit z7 = t45 & y17; bit z8 = t41 & y10; bit z9 = t44 & y12; bit z10 = t37 & y3; bit z11 = t33 & y4; bit z12 = t43 & y13; bit z13 = t40 & y5; bit z14 = t29 & y2; bit z15 = t42 & y9; bit z16 = t45 & y14; bit z17 = t41 & y8; bit tc1 = z15 ^ z16; bit tc2 = z10 ^ tc1; bit tc3 = z9 ^ tc2; bit tc4 = z0 ^ z2; bit tc5 = z1 ^ z0; bit tc6 = z3 ^ z4; bit tc7 = z12 ^ tc4; bit tc8 = z7 ^ tc6; bit tc9 = z8 ^ tc7; bit tc10 = tc8 ^ tc9; bit tc11 = tc6 ^ tc5; bit tc12 = z3 ^ z5; bit tc13 = z13 ^ tc1; bit tc14 = tc4 ^ tc12; bit S3 = tc3 ^ tc11; bit tc16 = z6 ^ tc8; bit tc17 = z14 ^ tc10; bit tc18 = tc13 ^ tc14; bit S7 = z12.xnor(tc18); bit tc20 = z15 ^ tc16; bit tc21 = tc2 ^ z11; bit S0 = tc3 ^ tc16; bit S6 = tc10.xnor(tc18); bit S4 = tc14 ^ S3; bit S1 = S3.xnor(tc16); bit tc26 = tc17 ^ tc20; bit S2 = tc26.xnor(z17); bit S5 = tc21 ^ tc17; byte result; result.push_back(S7); result.push_back(S6); result.push_back(S5); result.push_back(S4); result.push_back(S3); result.push_back(S2); result.push_back(S1); result.push_back(S0); return result; } static vector<bit> initialroundconstant{1,0,0,0,0,0,0,0}; static vector<bit> encrypt(const vector<bit> &in,const vector<bit> &k) { vector<vector<byte>> expanded(4,vector<byte> (44)); vector<vector<byte>> state(4,vector<byte> (4)); vector<vector<byte>> newstate(4,vector<byte> (4)); byte roundconstant; bigint i; bigint j; bigint r; bigint bitpos; for (j = 0;j < 4;++j) for (i = 0;i < 4;++i) for (bitpos = 0;bitpos < 8;++bitpos) expanded.at(i).at(j).push_back(k.at((j*4+i)*8+bitpos)); roundconstant = initialroundconstant; for (j = 4;j < 44;++j) { vector<byte> temp(4); if (j % 4) for (i = 0;i < 4;++i) temp.at(i) = expanded.at(i).at(j - 1); else { for (i = 0;i < 4;++i) temp.at(i) = byte_sub(expanded.at((i + 1) % 4).at(j - 1)); temp.at(0) = byte_xor(temp.at(0),roundconstant); roundconstant = xtime(roundconstant); } for (i = 0;i < 4;++i) expanded.at(i).at(j) = byte_xor(temp.at(i),expanded.at(i).at(j - 4)); } for (j = 0;j < 4;++j) for (i = 0;i < 4;++i) for (bitpos = 0;bitpos < 8;++bitpos) state.at(i).at(j).push_back(in.at((j*4+i)*8+bitpos)); for (j = 0;j < 4;++j) for (i = 0;i < 4;++i) state.at(i).at(j) = byte_xor(state.at(i).at(j),expanded.at(i).at(j)); for (r = 0;r < 10;++r) { for (i = 0;i < 4;++i) for (j = 0;j < 4;++j) newstate.at(i).at(j) = byte_sub(state.at(i).at(j)); for (i = 0;i < 4;++i) for (j = 0;j < 4;++j) state.at(i).at(j) = newstate.at(i).at((j + i) % 4); if (r < 9) for (j = 0;j < 4;++j) { byte a0 = state.at(0).at(j); byte a1 = state.at(1).at(j); byte a2 = state.at(2).at(j); byte a3 = state.at(3).at(j); byte a01 = byte_xor(a0,a1); byte a12 = byte_xor(a1,a2); byte a23 = byte_xor(a2,a3); byte a30 = byte_xor(a3,a0); state.at(0).at(j) = byte_xor(xtime(a01),byte_xor(a1,a23)); state.at(1).at(j) = byte_xor(xtime(a12),byte_xor(a2,a30)); state.at(2).at(j) = byte_xor(xtime(a23),byte_xor(a3,a01)); state.at(3).at(j) = byte_xor(xtime(a30),byte_xor(a0,a12)); } for (i = 0;i < 4;++i) for (j = 0;j < 4;++j) state.at(i).at(j) = byte_xor(state.at(i).at(j),expanded.at(i).at(r * 4 + 4 + j)); } vector<bit> result; for (j = 0;j < 4;++j) for (i = 0;i < 4;++i) for (bitpos = 0;bitpos < 8;++bitpos) result.push_back(state.at(i).at(j).at(bitpos)); return result; } vector<bit> aes128_enum( const vector<bit> &bits, const vector<bigint> &params, const vector<bigint> &attackparams ) { bigint K = params.at(0); bigint C = params.at(1); bigint pos = 0; bigint I = attackparams.at(pos++); bigint QX = attackparams.at(pos++); bigint QUEUE_SIZE = attackparams.at(pos++); bigint QF = attackparams.at(pos++); auto PERIOD = QF*QUEUE_SIZE; vector<bit> plaintext0(128); vector<bit> ciphertext0(C); vector<bit> plaintext1(128); vector<bit> ciphertext1(C); { bigint pos = 0; for (bigint j = 0;j < 128;++j) plaintext0.at(j) = bits.at(pos++); for (bigint j = 0;j < C;++j) ciphertext0.at(j) = bits.at(pos++); for (bigint j = 0;j < 128;++j) plaintext1.at(j) = bits.at(pos++); for (bigint j = 0;j < C;++j) ciphertext1.at(j) = bits.at(pos++); assert(pos == bits.size()); } vector<bit> result(K,bit(1)); // note that queue is not used if QX == 0 vector<bit> queue_valid(QUEUE_SIZE); vector<vector<bit>> queue(QUEUE_SIZE,vector<bit>(K)); bigint timer = 0; vector<bit> guess(128); for (bigint iter = 0;iter < I;++iter) { auto guessct0 = encrypt(plaintext0,guess); bit mismatch; for (bigint j = 0;j < C;++j) mismatch |= guessct0.at(j)^ciphertext0.at(j); if (QX == 0) { auto guessct1 = encrypt(plaintext1,guess); for (bigint j = 0;j < C;++j) mismatch |= guessct1.at(j)^ciphertext1.at(j); for (bigint j = 0;j < K;++j) result.at(j) = mismatch.mux(guess.at(j),result.at(j)); } else { vector<bit> guessprefix(K); for (bigint j = 0;j < K;++j) guessprefix.at(j) = guess.at(j); bit match = ~mismatch; bit_queue1_insert(queue_valid,match); bit_vector_queue_insert(queue,guessprefix,match); ++timer; if (timer == PERIOD || iter == I-1) { timer = 0; for (bigint q = 0;q < QUEUE_SIZE;++q) { vector<bit> queueguess(128); for (bigint j = 0;j < K;++j) queueguess.at(j) = queue.at(q).at(j); auto guessct1 = encrypt(plaintext1,queueguess); bit mismatch1 = ~queue_valid.at(q); for (bigint j = 0;j < C;++j) mismatch1 |= guessct1.at(j)^ciphertext1.at(j); for (bigint j = 0;j < K;++j) result.at(j) = mismatch1.mux(queueguess.at(j),result.at(j)); queue_valid.at(q) = bit(0); } } } bit incrementing(1); for (bigint j = 0;j < K;++j) { bit old = guess.at(j); guess.at(j) = old^incrementing; incrementing &= old; } } return result; }