// cryptattacktester-20230614/isd2_cost.cpp
#include <cassert>
#include "ram_cost.h"
#include "bit_cost.h"
#include "bit_vector_cost.h"
#include "bit_matrix_cost.h"
#include "subset_cost.h"
#include "index_cost.h"
#include "sorting_cost.h"
#include "parity_cost.h"
#include "isd2_cost.h"
using namespace std;
// Cost estimate for the "isd2" information-set-decoding attack.
//
// The trailing comment on each term appears to quote the statement of the
// attack implementation that the term accounts for.
//
// params       : read as N = params.at(0), K = params.at(1), W = params.at(2)
//                (presumably code length, dimension and target weight -- TODO
//                confirm against the caller).
// attackparams : 19 entries, read in exactly this order:
//                ITERS, RESET, X, YX, PIJ, PI, L0, L1, CHECKPI, CHECKSUM, D, Z,
//                QU0, QF0, WI0, QU1, QF1, WI1, FW.
//                Both vectors are accessed with .at(), so a vector that is too
//                short throws std::out_of_range.
// Returns the summed cost in the unit used by the *_cost helpers
// (NOTE(review): presumably bit operations -- confirm against bit_cost.h).
bigint isd2_cost(const vector<bigint> &params,const vector<bigint> &attackparams)
{
  bigint N = params.at(0);
  bigint K = params.at(1);
  bigint W = params.at(2);

  // Unpack the attack parameters in their fixed positional order.
  bigint pos = 0;
  bigint ITERS = attackparams.at(pos++);
  bigint RESET = attackparams.at(pos++);
  bigint X = attackparams.at(pos++);
  bigint YX = attackparams.at(pos++);
  auto Y = X+YX;
  bigint PIJ = attackparams.at(pos++);
  bigint PI = attackparams.at(pos++);
  bigint L0 = attackparams.at(pos++);
  bigint L1 = attackparams.at(pos++);
  bigint CHECKPI = attackparams.at(pos++);
  bigint CHECKSUM = attackparams.at(pos++);
  bigint D = attackparams.at(pos++);
  bigint Z = attackparams.at(pos++);
  bigint QU0 = attackparams.at(pos++);
  bigint QF0 = attackparams.at(pos++);
  auto PE0 = QF0*QU0; // first-level queue period: QF0 * QU0
  bigint WI0 = attackparams.at(pos++);
  bigint QU1 = attackparams.at(pos++);
  bigint QF1 = attackparams.at(pos++);
  auto PE1 = QF1*QU1; // root queue period: QF1 * QU1
  bigint WI1 = attackparams.at(pos++);
  bigint FW = attackparams.at(pos++);

  // With FW set, charge parity_known_cost(N,K) once (added at the very end)
  // and cost everything else with K reduced by one.
  bigint fwcost = 0;
  if (FW) {
    fwcost = parity_known_cost(N,K);
    --K;
  }

  // Derived sizes. The KK-Z columns are split into a left half of
  // floor((KK-Z)/2) columns and a right half holding the rest (right >= left).
  // idx_bits is the width of an index into the larger (right) half.
  bigint L = L0+L1;
  bigint R = N - K;
  bigint KK = K + L;
  bigint left = (KK-Z)/2;
  bigint right = KK-Z-left;
  bigint idx_bits = nbits(right-1);

  bigint result = 0;

  // ---- First-level merge: sort the combined PIJ-subset list, then compare
  // ---- entries at offsets 1..WI0.
  bigint listsize0 = binomial(left,PIJ);
  bigint listsize1 = binomial(right,PIJ);
  bigint listsize = listsize0+listsize1;
  result += 2*sorting_cost(listsize,L0+1,L1+PIJ*idx_bits); // sorting(L_01, L_sum, L_set, L0);
  // The comparison window cannot exceed listsize-1 offsets.
  WI0 = min(WI0,bigint(listsize-1));
  // Number of pairs (i, i+offset) with 1 <= offset <= WI0 in a list of
  // listsize entries: sum_{o=1..WI0} (listsize-o) = (2*listsize-WI0-1)*WI0/2.
  bigint pool = (2*listsize-WI0-1)*WI0/2;
  // Cost charged per compared pair.
  bigint persum = 0;
  persum += 1; // bit check = L_01.at(i) ^ L_01.at(i+offset);
  persum += 1+bit_vector_compare_cost(L0); // check = check.andn(bit_vector_compare(bit_vector_extract(L_sum.at(i+0), 0, L0), bit_vector_extract(L_sum.at(i+offset), 0, L0)));
  persum += bit_queue1_insert_cost(QU0); // bit_queue1_insert(queue_valid, check);
  persum += L1; // v = bit_vector_xor(v0, v1);
  persum += QU0*L1*bit_mux_cost; // bit_vector_queue_insert(queue_sum, v, check);
  persum += QU0*2*PIJ*idx_bits*bit_mux_cost; // bit_matrix_queue_insert(queue_set, set, check);
  result += 2*pool*persum;

  // ---- Root merge: one queue clear every PE0 pairs (rounded up). Each clear
  // ---- contributes QU0 entries, doubled, to the root list.
  bigint queue_clears = (pool+PE0-1)/PE0;
  bigint rootlistsize = 2*queue_clears*QU0;
  WI1 = min(WI1,bigint(rootlistsize-1));
  result += sorting_cost(rootlistsize,L1+2,2*PIJ*idx_bits); // sorting(L_root_01, L_root_sum, L_root_set, L_root_valid);
  // Same pair-count formula as `pool`, applied to the root list and window WI1.
  bigint rootpool = (2*rootlistsize-WI1-1)*WI1/2;
  bigint perrootsum = 0;
  perrootsum += 1; // check = L_root_01.at(i) ^ L_root_01.at(i+1);
  perrootsum += 2; // check &= L_root_valid.at(i) & L_root_valid.at(i+1);
  perrootsum += 1+bit_vector_compare_cost(L1); // check.andn(bit_vector_compare(L_root_sum.at(i+0), L_root_sum.at(i+offset)));
  if (CHECKPI) // optional set-size filter, charged only when CHECKPI is nonzero
    perrootsum += 2*(1+set_size_check_cost(2*PIJ,idx_bits,PI)); // check &= set_size_check(set_check, PI);
  perrootsum += bit_queue1_insert_cost(QU1); // bit_queue1_insert(queue_valid, check);
  perrootsum += QU1*4*PIJ*idx_bits*bit_mux_cost; // bit_matrix_queue_insert(queue_set, set, check);
  result += rootpool*perrootsum;

  // ---- Post-root processing, charged per root-queue slot.
  bigint postrootqueue = 0;
  postrootqueue += 2*(R-L+bit_matrix_sum_of_cols_cost(R-L,left,PIJ)); // for b=0, b=2: bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs2.at(b & 1), set_p));
  postrootqueue += 2*(R-L+bit_matrix_sum_of_cols_cost(R-L,right,PIJ)); // for b=1, b=3: bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs2.at(b & 1), set_p));
  postrootqueue += bit_vector_hamming_weight_cost(R-L); // bit_vector_hamming_weight(sum);
  if (CHECKSUM == 0) {
    postrootqueue += 2+bit_vector_integer_compare_cost(nbits(R-L),nbits(W-PI*2)); // check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_sum, tp4));
  } else {
    postrootqueue += 2*set_size_cost(2*PIJ,idx_bits); // weight_list.push_back(set_size(set));
    postrootqueue += bit_vector_add_cost(nbits(2*PIJ),nbits(2*PIJ)); // bit_vector_add(w_tmp, weight_list.at(0), weight_list.at(1));
    postrootqueue += bit_vector_add_cost(nbits(R-L),nbits(4*PIJ)); // bit_vector_add(w_final, weight_list.at(2), w_tmp);
    postrootqueue += 2+bit_vector_integer_compare_cost(nbits(R-L+4*PIJ),nbits(W)); // check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_final, tp4));
  }
  postrootqueue += (R-L)*bit_mux_cost; // bit_vector_mux(s_ret, sum, check_w);
  postrootqueue += 4*PIJ*idx_bits*bit_mux_cost; // bit_matrix_mux(set_ret, queue_set.at(j), check_w);
  postrootqueue += N*nbits(N-1)*bit_mux_cost; // bit_matrix_mux(map_ret, column_map, check_w);
  // One root-queue clear every PE1 root pairs (rounded up), QU1 slots each.
  bigint root_queue_clears = (rootpool+PE1-1)/PE1;
  result += root_queue_clears*QU1*postrootqueue;

  // Everything accumulated so far is repeated D times per iteration
  // (presumably once per flip index t; see the (D-1) term below).
  result *= D;

  // ---- Per-iteration work that is not repeated D times.
  result += subset_cost(left,PIJ,L); // subset(L0_sum, L0_set, Hs01.at(0).size(), PIJ, idx_bits, zz, Hs01.at(0));
  result += subset_cost(right,PIJ,L); // subset(L1_sum[0], L1_set, Hs01.at(1).size(), PIJ, idx_bits, zz, Hs01.at(1));
  result += listsize1*L; // bit_vector_xor(L1_sum[0].at(i), s01));
  result += (D-1)*2*listsize1; // L1_sum[t].at(i).at(flip_idx) = ~L1_sum[t].at(i).at(flip_idx);
  bigint column_swaps = column_swaps_cost(N,K,L,X,Y); // column_swaps(s, H, column_map, N, K, L, X, Y);
  result += column_swaps;
  result += 1; // alwayssystematic &= swapssucceeded;
  result *= ITERS;

  // ---- Reset accounting. A reset happens ITERS/RESET times (integer
  // ---- division). It replaces that iteration's column swaps and the
  // ---- swapssucceeded bit with a full row randomization. The very first
  // ---- iteration skips the random permutation, so its cost is subtracted once.
  // NOTE(review): RESET == 0 would divide by zero here; presumably callers
  // guarantee RESET >= 1 -- confirm.
  bigint perresetexceptfirst = 0;
  perresetexceptfirst += bit_matrix_column_randompermutation_cost(N,K);
  if (FW) perresetexceptfirst += 1; // alwayssystematic &= initial_alwayssystematic;
  bigint perreset = perresetexceptfirst;
  perreset -= column_swaps; // skipped on reset
  perreset -= 1; // skipped on reset
  perreset += 2*L*(N-K-1)*(N+1); // bit_matrix_randomize_rows
  result += perreset*(ITERS/RESET);
  result -= perresetexceptfirst; // skipped on iter == 0

  // ---- One-time reconstruction of the output vector from the returned sets.
  result += indices_to_vector_cost(left,PIJ); // indices_to_vector(indices0, (KK-Z)/2);
  result += indices_to_vector_cost(left,PIJ); // indices_to_vector(indices1, (KK-Z)/2);
  result += left; // bit_vector_xor(v0, v1);
  result += indices_to_vector_cost(right,PIJ); // indices_to_vector(indices2, (KK-Z+1)/2);
  result += indices_to_vector_cost(right,PIJ); // indices_to_vector(indices3, (KK-Z+1)/2);
  result += right; // bit_vector_xor(v2, v3);
  result += N*ram_write_cost(N,nbits(N-1),1); // ram_write(e_ret, map_ret.at(i), e.at(i));

  // One-time FW adjustment computed at the top.
  result += fwcost;
  return result;
}