-rw-r--r-- 6554 cryptattacktester-20230614/isd2_cost.cpp raw
#include <cassert> #include "ram_cost.h" #include "bit_cost.h" #include "bit_vector_cost.h" #include "bit_matrix_cost.h" #include "subset_cost.h" #include "index_cost.h" #include "sorting_cost.h" #include "parity_cost.h" #include "isd2_cost.h" using namespace std; bigint isd2_cost(const vector<bigint> &params,const vector<bigint> &attackparams) { bigint N = params.at(0); bigint K = params.at(1); bigint W = params.at(2); bigint pos = 0; bigint ITERS = attackparams.at(pos++); bigint RESET = attackparams.at(pos++); bigint X = attackparams.at(pos++); bigint YX = attackparams.at(pos++); auto Y = X+YX; bigint PIJ = attackparams.at(pos++); bigint PI = attackparams.at(pos++); bigint L0 = attackparams.at(pos++); bigint L1 = attackparams.at(pos++); bigint CHECKPI = attackparams.at(pos++); bigint CHECKSUM = attackparams.at(pos++); bigint D = attackparams.at(pos++); bigint Z = attackparams.at(pos++); bigint QU0 = attackparams.at(pos++); bigint QF0 = attackparams.at(pos++); auto PE0 = QF0*QU0; bigint WI0 = attackparams.at(pos++); bigint QU1 = attackparams.at(pos++); bigint QF1 = attackparams.at(pos++); auto PE1 = QF1*QU1; bigint WI1 = attackparams.at(pos++); bigint FW = attackparams.at(pos++); bigint fwcost = 0; if (FW) { fwcost = parity_known_cost(N,K); --K; } bigint L = L0+L1; bigint R = N - K; bigint KK = K + L; bigint RR = N - KK; bigint left = (KK-Z)/2; bigint right = KK-Z-left; bigint idx_bits = nbits(right-1); bigint result = 0; bigint listsize0 = binomial(left,PIJ); bigint listsize1 = binomial(right,PIJ); bigint listsize = listsize0+listsize1; result += 2*sorting_cost(listsize,L0+1,L1+PIJ*idx_bits); // sorting(L_01, L_sum, L_set, L0); WI0 = min(WI0,bigint(listsize-1)); bigint pool = (2*listsize-WI0-1)*WI0/2; bigint persum = 0; persum += 1; // bit check = L_01.at(i) ^ L_01.at(i+offset); persum += 1+bit_vector_compare_cost(L0); // check = check.andn(bit_vector_compare(bit_vector_extract(L_sum.at(i+0), 0, L0), bit_vector_extract(L_sum.at(i+offset), 0, L0))); persum += bit_queue1_insert_cost(QU0); // bit_queue1_insert(queue_valid, check); persum += L1; // v = bit_vector_xor(v0, v1); persum += QU0*L1*bit_mux_cost; // bit_vector_queue_insert(queue_sum, v, check); persum += QU0*2*PIJ*idx_bits*bit_mux_cost; // bit_matrix_queue_insert(queue_set, set, check); result += 2*pool*persum; bigint queue_clears = (pool+PE0-1)/PE0; bigint rootlistsize = 2*queue_clears*QU0; WI1 = min(WI1,bigint(rootlistsize-1)); result += sorting_cost(rootlistsize,L1+2,2*PIJ*idx_bits); // sorting(L_root_01, L_root_sum, L_root_set, L_root_valid); bigint rootpool = (2*rootlistsize-WI1-1)*WI1/2; bigint perrootsum = 0; perrootsum += 1; // check = L_root_01.at(i) ^ L_root_01.at(i+1); perrootsum += 2; // check &= L_root_valid.at(i) & L_root_valid.at(i+1); perrootsum += 1+bit_vector_compare_cost(L1); // check.andn(bit_vector_compare(L_root_sum.at(i+0), L_root_sum.at(i+offset))); if (CHECKPI) perrootsum += 2*(1+set_size_check_cost(2*PIJ,idx_bits,PI)); // check &= set_size_check(set_check, PI); perrootsum += bit_queue1_insert_cost(QU1); // bit_queue1_insert(queue_valid, check); perrootsum += QU1*4*PIJ*idx_bits*bit_mux_cost; // bit_matrix_queue_insert(queue_set, set, check); result += rootpool*perrootsum; bigint postrootqueue = 0; postrootqueue += 2*(R-L+bit_matrix_sum_of_cols_cost(R-L,left,PIJ)); // for b=0, b=2: bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs2.at(b & 1), set_p)); postrootqueue += 2*(R-L+bit_matrix_sum_of_cols_cost(R-L,right,PIJ)); // for b=1, b=3: bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs2.at(b & 1), set_p)); postrootqueue += bit_vector_hamming_weight_cost(R-L); // bit_vector_hamming_weight(sum); if (CHECKSUM == 0) { postrootqueue += 2+bit_vector_integer_compare_cost(nbits(R-L),nbits(W-PI*2)); // check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_sum, tp4)); } else { postrootqueue += 2*set_size_cost(2*PIJ,idx_bits); // weight_list.push_back(set_size(set)); postrootqueue += bit_vector_add_cost(nbits(2*PIJ),nbits(2*PIJ)); // bit_vector_add(w_tmp, weight_list.at(0), weight_list.at(1)); postrootqueue += bit_vector_add_cost(nbits(R-L),nbits(4*PIJ)); // bit_vector_add(w_final, weight_list.at(2), w_tmp); postrootqueue += 2+bit_vector_integer_compare_cost(nbits(R-L+4*PIJ),nbits(W)); // check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_final, tp4)); } postrootqueue += (R-L)*bit_mux_cost; // bit_vector_mux(s_ret, sum, check_w); postrootqueue += 4*PIJ*idx_bits*bit_mux_cost; // bit_matrix_mux(set_ret, queue_set.at(j), check_w); postrootqueue += N*nbits(N-1)*bit_mux_cost; // bit_matrix_mux(map_ret, column_map, check_w); bigint root_queue_clears = (rootpool+PE1-1)/PE1; result += root_queue_clears*QU1*postrootqueue; result *= D; result += subset_cost(left,PIJ,L); // subset(L0_sum, L0_set, Hs01.at(0).size(), PIJ, idx_bits, zz, Hs01.at(0)); result += subset_cost(right,PIJ,L); // subset(L1_sum[0], L1_set, Hs01.at(1).size(), PIJ, idx_bits, zz, Hs01.at(1)); result += listsize1*L; // bit_vector_xor(L1_sum[0].at(i), s01)); result += (D-1)*2*listsize1; // L1_sum[t].at(i).at(flip_idx) = ~L1_sum[t].at(i).at(flip_idx); bigint column_swaps = column_swaps_cost(N,K,L,X,Y); // column_swaps(s, H, column_map, N, K, L, X, Y); result += column_swaps; result += 1; // alwayssystematic &= swapssucceeded; result *= ITERS; bigint perresetexceptfirst = 0; perresetexceptfirst += bit_matrix_column_randompermutation_cost(N,K); if (FW) perresetexceptfirst += 1; // alwayssystematic &= initial_alwayssystematic; bigint perreset = perresetexceptfirst; perreset -= column_swaps; // skipped on reset perreset -= 1; // skipped on reset perreset += 2*L*(N-K-1)*(N+1); // bit_matrix_randomize_rows result += perreset*(ITERS/RESET); result -= perresetexceptfirst; // skipped on iter == 0 result += indices_to_vector_cost(left,PIJ); // indices_to_vector(indices0, (KK-Z)/2); result += indices_to_vector_cost(left,PIJ); // indices_to_vector(indices1, (KK-Z)/2); result += left; // bit_vector_xor(v0, v1); result += indices_to_vector_cost(right,PIJ); // indices_to_vector(indices2, (KK-Z+1)/2); result += indices_to_vector_cost(right,PIJ); // indices_to_vector(indices3, (KK-Z+1)/2); result += right; // bit_vector_xor(v2, v3); result += N*ram_write_cost(N,nbits(N-1),1); // ram_write(e_ret, map_ret.at(i), e.at(i)); result += fwcost; return result; }