#include #include "problem.h" #include "attack.h" #include "collision_prob.h" #include "random.h" using namespace std; static bool cacheinit = 0; static problem Ecached; static vector Pcached; static vector> publist; static vector> seclist; static bigfloat numinputs; static bigfloat numoutputs; int attack_handle(const problem &E,const vector &P,const attack &A,const vector &Q) { bigint maxcost = 1073741824; bigint maxnonbatchcost = 1024; bigint trialfactor = 1000; bigint probfactor = 10000; selection_constrain(attack_selection,"maxcost",maxcost,maxcost); selection_constrain(attack_selection,"maxnonbatchcost",maxnonbatchcost,maxnonbatchcost); selection_constrain(attack_selection,"trialfactor",trialfactor,trialfactor); selection_constrain(attack_selection,"probfactor",probfactor,probfactor); cout << "circuitprob"; cout << " problem=" << E.name; for (bigint j = 0;j < P.size();++j) cout << (j ? ',' : ' ') << E.paramnames.at(j) << "=" << P.at(j); cout << " attack="; cout << A.name; for (bigint j = 0;j < Q.size();++j) cout << (j ? ',' : ' ') << A.paramnames.at(j) << "=" << Q.at(j); bigint predictedcost = A.cost(P,Q); cout << " cost " << predictedcost; if (predictedcost > maxcost) { cout << " skipping\n" << flush; return 1; } bigfloat predictedprob = A.prob(P,Q); cout << " prob " << predictedprob; if (probfactor) if (ceil_as_bigint(predictedprob*bigfloat(probfactor)) <= 1) { cout << " skipping\n" << flush; return 1; } bit::clear_all(); if (cacheinit) { if (E.psgen != Ecached.psgen) cacheinit = 0; if (E.paramnames != Ecached.paramnames) cacheinit = 0; if (P != Pcached) cacheinit = 0; } if (!cacheinit) { Ecached = E; Pcached = P; publist.clear(); seclist.clear(); numinputs = E.numinputs(P); numoutputs = E.numoutputs(P); cacheinit = 1; } bigfloat predictedprob2 = collision_lastmatch_prob(predictedprob*numinputs,numinputs,numoutputs); cout << " prob2 " << predictedprob2; vector Pbigint; for (bigint j = 0;j < P.size();++j) Pbigint.push_back((bigint) (P.at(j))); vector Qbigint; for (bigint j = 0;j < Q.size();++j) Qbigint.push_back((bigint) (Q.at(j))); bigint bigtrials; // predicted average successes: trials*predictedprob2 // predicted deviation: sqrt(trials*predictedprob2*(1-predictedprob2)) // want deviation/successes <= X // i.e. trials*predictedprob2*(1-predictedprob2) <= X^2 trials^2 predictedprob2^2 // i.e. 1-predictedprob2 <= X^2 trials predictedprob2 if (predictedprob2 > 0.5) bigtrials = trialfactor; else { bigfloat floattrials = trialfactor*(1-predictedprob2)/predictedprob2; bigtrials = ceil_as_bigint(floattrials); } if (bigtrials < 1) bigtrials = 1; if (bigtrials > "1000000000000000000") bigtrials = "1000000000000000000"; bigint trials = bigtrials; while (publist.size() < trials) { random_seed(publist.size()); auto ps = E.psgen(P); publist.push_back(ps.first); seclist.push_back(ps.second); } bool checknonbatch = (predictedcost < maxnonbatchcost); bigint nonbatchsuccesses = 0; bigint successes = 0; if (checknonbatch) { for (bigint t = 0;t < trials;++t) { auto pub = publist.at(t); auto sec = seclist.at(t); vector pubbit; for (bigint j = 0;j < pub.size();++j) pubbit.push_back(bit(pub.at(j))); random_seed(); vector attackoutput = A.circuit(pubbit,Pbigint,Qbigint); bool success = 1; for (bigint j = 0;j < sec.size();++j) if (sec.at(j) != attackoutput.at(j).value()) success = 0; nonbatchsuccesses += success; } } if (1) { // always do batch for (bigint batch = 0;batch < trials;batch += bit_slicing) { bigint jbound = publist.at(batch).size(); vector pubbit; for (bigint j = 0;j < jbound;++j) { bitset pubj; for (bigint t = 0;t < bit_slicing && batch+t < trials;++t) pubj[t] = publist.at(batch+t).at(j); pubbit.push_back(bit(pubj)); } random_seed(); vector attackoutput = A.circuit(pubbit,Pbigint,Qbigint); for (bigint t = 0;t < bit_slicing && batch+t < trials;++t) { auto sec = seclist.at(batch+t); bool success = 1; for (bigint j = 0;j < sec.size();++j) if (sec.at(j) != attackoutput.at(j).value_vector()[t]) success = 0; successes += success; } } if (checknonbatch) assert(successes == nonbatchsuccesses); } bigfloat observed = bigfloat(successes)/bigfloat(trials); bigfloat ratio2 = observed/predictedprob2; cout << " trials " << trials; cout << " slicedops " << bit::ops(); cout << " succ " << observed; cout << " ratio2 " << ratio2; if (ratio2 > 1.1) cout << " ALERT"; if (ratio2 < 0.9) cout << " ALERT"; cout << '\n' << flush; return 1; }