-rw-r--r-- 16584 cryptattacktester-20231020/isd2.cpp raw
#include <cassert> #include <vector> #include <random> #include "decoding.h" #include "bit.h" #include "ram.h" #include "util.h" #include "subset.h" #include "bit_vector.h" #include "index.h" #include "bit_matrix.h" #include "bit_cube.h" #include "column_swaps.h" #include "parity.h" #include "sorting.h" #include "isd0.h" using namespace std; /* Let the partial weights of a solution be (w0, w1, 0, w2) CHECKPI = 1: ensures that w0 = w1 = PI CHECKPI = 0: do not check w0 and w1 CHECKSUM = 1: ensures that w0 + w1 + w2 = T CHECKSUM = 0: ensures that w2 = T - 2PI PI == PIJ*2 | CHECKPI | CHECKSUM | partial weights of solutions --------------------------------------------------------- Y | 1 | 1 | (2PIJ, 2PIJ, 0, T - 4PIJ) --------------------------------------------------------- Y | 1 | 0 | (2PIJ, 2PIJ, 0, T - 4PIJ) --------------------------------------------------------- Y | 0 | 1 | (2PIJ - 2x, 2PIJ - 2y, 0, T - 4PIJ + 2x + 2y) --------------------------------------------------------- Y | 0 | 0 | (2PIJ - 2x, 2PIJ - 2y, 0, T - 4PIJ) <= can be useful if we know that there is no solution of weight < T !!! --------------------------------------------------------- N | 1 | 1 | (PI, PI, 0, T - 2PI) --------------------------------------------------------- N | 1 | 0 | (PI, PI, 0, T - 2PI) --------------------------------------------------------- N | 0 | 1 | (2PIJ - 2x, 2PIJ - 2y, 0, T - 2PIJ + 2x + 2y) <= PI is ignored in this case --------------------------------------------------------- N | 0 | 0 | (2PIJ - 2x, 2PIJ - 2y, 0, T - 2PI) <= totally useless? Since CHECKSUM = 1 is more expensive than CHECKSUM = 0, the useful cases are 1. (Y, 1, 0) 2. (N, 1, 0) 3. (Y, 0, 1) = (N, 0, 1) 4. (Y, 0, 0) */ template<class AT,class BT> static void shuffle(vector<AT> &A,vector<BT> &B) { bigint n = A.size(); assert(n == B.size()); permutation pi(n); pi.permute(A); pi.permute(B); } template<class AT,class BT,class CT> static void shuffle(vector<AT> &A,vector<BT> &B,vector<CT> &C) { bigint n = A.size(); assert(n == B.size()); assert(n == C.size()); permutation pi(n); pi.permute(A); pi.permute(B); pi.permute(C); } template<class AT,class BT,class CT,class DT> static void shuffle(vector<AT> &A,vector<BT> &B,vector<CT> &C,vector<DT> &D) { bigint n = A.size(); assert(n == B.size()); assert(n == C.size()); assert(n == D.size()); permutation pi(n); pi.permute(A); pi.permute(B); pi.permute(C); pi.permute(D); } vector<bit> isd2( const vector<bit> &bits, const vector<bigint> &params, const vector<bigint> &attackparams ) { bigint N = params.at(0); bigint K_orig = params.at(1); bigint T = params.at(2); bigint pos = 0; bigint ITERS = attackparams.at(pos++); bigint RESET = attackparams.at(pos++); bigint X = attackparams.at(pos++); bigint YX = attackparams.at(pos++); auto Y = X+YX; bigint PIJ = attackparams.at(pos++); bigint PI = attackparams.at(pos++); bigint L0 = attackparams.at(pos++); bigint L1 = attackparams.at(pos++); bigint CHECKPI = attackparams.at(pos++); bigint CHECKSUM = attackparams.at(pos++); bigint D = attackparams.at(pos++); bigint Z = attackparams.at(pos++); bigint QU0 = attackparams.at(pos++); bigint QF0 = attackparams.at(pos++); auto PE0 = QF0*QU0; bigint WI0 = attackparams.at(pos++); bigint QU1 = attackparams.at(pos++); bigint QF1 = attackparams.at(pos++); auto PE1 = QF1*QU1; bigint WI1 = attackparams.at(pos++); bigint FW = attackparams.at(pos++); bigint L = L0 + L1; assert(PI <= 2*PIJ); assert(PI%2 == 0); assert(D >= 1); assert(!((D-1)>>L0)); // D <= 2^L0 auto inputs = decoding_deserialize(bits,params); auto pk = inputs.first; auto s = inputs.second; vector<vector<bit>> H = bit_matrix_transpose_and_identity(pk); vector<vector<bit>> column_map; for (bigint i = 0; i < N; i++) column_map.push_back(bit_vector_from_integer(i, nbits(N-1))); bit alwayssystematic = 1; bigint K = K_orig; if (FW) { alwayssystematic = parity_known(s,H,column_map,bit(T.bit(0))); K -= 1; } vector<vector<bit>> initial_H = H; vector<bit> initial_s = s; vector<vector<bit>> initial_column_map = column_map; bit initial_alwayssystematic = alwayssystematic; bigint R = N - K; bigint KK = K + L; const bigint idx_bits = nbits((KK-Z+1)/2-1); vector<bit> s_ret(N-K-L); vector<vector<bit>> set_ret = bit_matrix(PIJ*4, idx_bits); vector<vector<bit>> map_ret = bit_matrix(N, nbits(N-1)); bigint untilreset = 0; for (bigint iter = 0; iter < ITERS; iter++) { // if alwayssystematic: H.at(i).at(j) == (i-KK == j-L) for KK <= i < N, 0 <= j < R if (untilreset > 0) { alwayssystematic &= column_swaps(s, H, column_map, N, K, L, X, Y); } else { untilreset = RESET; H = initial_H; s = initial_s; column_map = initial_column_map; if (iter == 0) alwayssystematic = initial_alwayssystematic; else { alwayssystematic = bit_matrix_column_randompermutation(s,H,column_map); if (FW) alwayssystematic &= initial_alwayssystematic; } bit_matrix_randomize_rows(H, s, L); } --untilreset; // partitioning s and H vector<bit> s01 = bit_vector_extract(s, 0, L); vector<bit> s2 = bit_vector_extract(s, L, R); vector<vector<vector<bit>>> Hs01(2); vector<vector<vector<bit>>> Hs2(2); for (bigint i = 0; i < KK-Z; i++) { Hs01.at( (i < (KK-Z)/2) ? 0 : 1 ).push_back(bit_vector_extract(H.at(i), 0, L)); Hs2.at( (i < (KK-Z)/2) ? 0 : 1 ).push_back(bit_vector_extract(H.at(i), L, R)); } // search for solution bigint flip_idx; vector<bigint> q_gray(0); vector<vector<bit>> L0_sum(0), L1_sum[2]; vector<vector<vector<bit>>> L0_set(0), L1_set(0); bigint lens[2] = {0,0}; for (bigint d = 0; d < D; d++) // randomizing search tree { vector<bit> L_root_01(0); vector<bit> L_root_valid(0); vector<vector<bit>> L_root_sum(0); vector<vector<vector<bit>>> L_root_set(0); for (bigint t = 0; t < 2; t++) { vector<bit> L_01(0); vector<vector<bit>> L_sum(0); vector<vector<vector<bit>>> L_set(0); vector<bit> zz(L); if (d == 0 and t == 0) { subset(L0_sum, L0_set, Hs01.at(0).size(), PIJ, idx_bits, zz, Hs01.at(0)); subset(L1_sum[0], L1_set, Hs01.at(1).size(), PIJ, idx_bits, zz, Hs01.at(1)); lens[0] = L0_sum.size(); lens[1] = L1_sum[0].size(); for (bigint i = 0; i < lens[1]; i++) L1_sum[1].push_back(bit_vector_xor(L1_sum[0].at(i), s01)); } if (d > 0) // making use of gray code { if (t == 0) flip_idx = gray_idx(q_gray); for (bigint i = 0; i < lens[1]; i++) L1_sum[t].at(i).at(flip_idx) = ~L1_sum[t].at(i).at(flip_idx); } for (bigint i = 0; i < lens[0]; i++) { L_01.push_back(bit(0)); L_sum.push_back(L0_sum.at(i)); L_set.push_back(L0_set.at(i)); } for (bigint i = 0; i < lens[1]; i++) { L_01.push_back(bit(1)); L_sum.push_back(L1_sum[t].at(i)); L_set.push_back(L1_set.at(i)); } shuffle(L_01, L_sum, L_set); sorting(L_01, L_sum, L_set, L0); // vector<bit> todo_check; vector<vector<bit>> todo_sum; vector<vector<vector<bit>>> todo_set; for (bigint i = 0; i < L_sum.size()-1; i++) { for (bigint offset = 1;offset <= WI0;++offset) { if (i+offset >= L_sum.size()) continue; bit check = L_01.at(i) ^ L_01.at(i+offset); check = check.andn(bit_vector_compare(bit_vector_extract(L_sum.at(i+0), 0, L0), bit_vector_extract(L_sum.at(i+offset), 0, L0))); vector<bit> v0 = bit_vector_extract(L_sum.at(i+0), L0, L); vector<bit> v1 = bit_vector_extract(L_sum.at(i+offset), L0, L); vector<bit> v = bit_vector_xor(v0, v1); vector<vector<bit>> set(0); for (bigint j = 0; j < PIJ; j++) set.push_back(L_set.at(i+0).at(j)); for (bigint j = 0; j < PIJ; j++) set.push_back(L_set.at(i+offset).at(j)); todo_check.push_back(check); todo_sum.push_back(v); todo_set.push_back(set); } } shuffle(todo_check,todo_sum,todo_set); vector<bit> queue_valid(QU0); vector<vector<bit>> queue_sum(QU0, vector<bit>(L1)); vector<vector<vector<bit>>> queue_set = bit_cube(QU0, PIJ*2, idx_bits); bigint timer = 0; for (bigint z = 0;z < todo_check.size();++z) { timer = (timer + 1) % PE0; if (z == todo_check.size()-1) timer = 0; auto check = todo_check.at(z); auto sum = todo_sum.at(z); auto set = todo_set.at(z); bit_queue1_insert(queue_valid, check); bit_vector_queue_insert(queue_sum, sum, check); bit_matrix_queue_insert(queue_set, set, check); // processing elements in the queue if (timer == 0) // { for (bigint j = 0; j < QU0; j++) { L_root_01.push_back(bit(t)); L_root_valid.push_back(queue_valid.at(j)); L_root_sum.push_back(queue_sum.at(j)); L_root_set.push_back(queue_set.at(j)); // clear the queue elements queue_valid.at(j) = bit(0); bit_vector_clear(queue_sum.at(j)); bit_matrix_clear(queue_set.at(j)); } } } } // t shuffle(L_root_01, L_root_sum, L_root_set, L_root_valid); sorting(L_root_01, L_root_sum, L_root_set, L_root_valid); vector<bit> todo_check; vector<vector<vector<bit>>> todo_set; for (bigint i = 0; i < L_root_sum.size()-1; i++) { for (bigint offset = 1;offset <= WI1;++offset) { if (i+offset >= L_root_sum.size()) continue; bit check; check = L_root_01.at(i) ^ L_root_01.at(i+offset); check &= L_root_valid.at(i) & L_root_valid.at(i+offset); check = check.andn(bit_vector_compare(L_root_sum.at(i+0), L_root_sum.at(i+offset))); // do weight check if CHECKPI if (CHECKPI) { for (bigint k = 0; k < 2; k++) { vector<vector<bit>> set_check(0); for (bigint j = PIJ*k; j < PIJ*(k+1); j++) set_check.push_back(L_root_set.at(i+0).at(j)); for (bigint j = PIJ*k; j < PIJ*(k+1); j++) set_check.push_back(L_root_set.at(i+offset).at(j)); check &= set_size_check(set_check, PI); } } vector<vector<bit>> set(0); for (bigint j = 0; j < PIJ*2; j++) set.push_back(L_root_set.at(i+0).at(j)); for (bigint j = 0; j < PIJ*2; j++) set.push_back(L_root_set.at(i+offset).at(j)); todo_check.push_back(check); todo_set.push_back(set); } } shuffle(todo_check,todo_set); vector<bit> queue_valid(QU1); vector<vector<vector<bit>>> queue_set = bit_cube(QU1, PIJ*4, idx_bits); bigint timer = 0; for (bigint z = 0;z < todo_check.size();++z) { timer = (timer + 1) % PE1; if (z == todo_check.size()-1) timer = 0; // conditionally pushing pairs into the queue auto check = todo_check.at(z); auto set = todo_set.at(z); bit_queue1_insert(queue_valid, check); bit_matrix_queue_insert(queue_set, set, check); // processing elements in the queue if (timer == 0) { for (bigint j = 0; j < QU1; j++) { vector<bit> sum = s2; for (bigint b = 0; b < 4; b++) { vector<vector<bit>> set_p(0); for (bigint p = PIJ*b; p < PIJ*(b+1); p++) set_p.push_back(queue_set.at(j).at(p)); bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs2.at(b & 1), set_p)); } // final check const vector<bit> tp4 = bit_vector_from_integer((CHECKSUM) ? T : T-PI*2); vector<bit> w_sum = bit_vector_hamming_weight(sum); bit check_w = alwayssystematic; if (CHECKSUM == 0) // make sure that the solution has partial weights (*, *, 0, T-PI*2) check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_sum, tp4)); else // make sure that the solution has partial weights (w0, w1, 0, w2) with w0 + w1 + w2 = T { vector<vector<bit>> weight_list(0); for (bigint b = 0; b < 2; b++) { vector<vector<bit>> set(0); for (bigint p = PIJ*(b+0); p < PIJ*(b+1); p++) set.push_back(queue_set.at(j).at(p)); for (bigint p = PIJ*(b+2); p < PIJ*(b+3); p++) set.push_back(queue_set.at(j).at(p)); weight_list.push_back(set_size(set)); } weight_list.push_back(w_sum); vector<bit> w_tmp(nbits(PIJ*4)), w_final(nbits(PIJ*4 + R-L)); bit_vector_add(w_tmp, weight_list.at(0), weight_list.at(1)); bit_vector_add(w_final, weight_list.at(2), w_tmp); check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_final, tp4)); } // store solution bit_vector_mux(s_ret, sum, check_w); bit_matrix_mux(set_ret, queue_set.at(j), check_w); bit_matrix_mux(map_ret, column_map, check_w); // clear the queue elements queue_valid.at(j) = bit(0); bit_matrix_clear(queue_set.at(j)); } } } } } // iter vector<bit> e_ret(N); vector<bit> e(0); vector<vector<bit>> indices0; vector<vector<bit>> indices1; vector<vector<bit>> indices2; vector<vector<bit>> indices3; for (bigint i = PIJ*0; i < PIJ*1; i++) indices0.push_back(set_ret.at(i)); vector<bit> v0 = indices_to_vector(indices0, (KK-Z)/2); for (bigint i = PIJ*2; i < PIJ*3; i++) indices1.push_back(set_ret.at(i)); vector<bit> v1 = indices_to_vector(indices1, (KK-Z)/2); vector<bit> v01 = bit_vector_xor(v0, v1); for (bigint i = PIJ*1; i < PIJ*2; i++) indices2.push_back(set_ret.at(i)); vector<bit> v2 = indices_to_vector(indices2, (KK-Z+1)/2); for (bigint i = PIJ*3; i < PIJ*4; i++) indices3.push_back(set_ret.at(i)); vector<bit> v3 = indices_to_vector(indices3, (KK-Z+1)/2); vector<bit> v23 = bit_vector_xor(v2, v3); for (bigint i = 0; i < v01.size(); i++) e.push_back(v01.at(i)); for (bigint i = 0; i < v23.size(); i++) e.push_back(v23.at(i)); for (bigint i = 0; i < Z; i++) e.push_back(bit(0)); for (bigint i = 0; i < s_ret.size(); i++) e.push_back(s_ret.at(i)); assert(e.size() == N); for (bigint i = 0; i < N; i++) ram_write(e_ret, map_ret.at(i), e.at(i)); // pk has identity implicitly on left, H has it on right // so change convention for output ordering vector<bit> e_ret_swap; for (bigint i = K_orig;i < N;++i) e_ret_swap.push_back(e_ret.at(i)); for (bigint i = 0;i < K_orig;++i) e_ret_swap.push_back(e_ret.at(i)); return e_ret_swap; }