-rw-r--r-- 5160 cryptattacktester-20231020/bit_matrix.h raw
#ifndef BIT_MATRIX_H #define BIT_MATRIX_H #include "bigint.h" #include "bit.h" #include "ram.h" #include "util.h" #include "random.h" #include "bit_vector.h" using namespace std; vector<vector<bit>> bit_matrix_transpose_and_identity(const vector<vector<bit>> &); bit bit_matrix_column_randompermutation(vector<bit> &,vector<vector<bit>> &,vector<vector<bit>> &); static inline const vector<vector<bit>> bit_matrix(bigint n, bigint m) { return vector<vector<bit>> (n, vector<bit>(m)); } static inline void bit_matrix_clear(vector<vector<bit>> &m) { for (bigint i = 0; i < m.size(); i++) bit_vector_clear(m.at(i)); } static inline void bit_matrix_mux(vector<vector<bit>> &dest, vector<vector<bit>> &src, bit b) { for (bigint i = 0; i < dest.size(); i++) for (bigint j = 0; j < dest.at(0).size(); j++) dest.at(i).at(j) = b.mux(dest.at(i).at(j), src.at(i).at(j)); } static inline void bit_matrix_mux(vector<vector<bit>> &dest, vector<vector<bit>> &src0, vector<vector<bit>> &src1, bit b) { for (bigint i = 0; i < dest.size(); i++) for (bigint j = 0; j < dest.at(0).size(); j++) dest.at(i).at(j) = b.mux(src0.at(i).at(j), src1.at(i).at(j)); } static inline void bit_matrix_queue_insert(vector<vector<vector<bit>>> &q, vector<vector<bit>> &m, bit b) { bigint i = q.size(); i -= 2; for (; i >= 0; i--) bit_matrix_mux(q.at(i+1), q.at(i), b); bit_matrix_mux(q.at(0), m, b); } static inline bit bit_matrix_reduced_echelon(vector<vector<bit>> &pivots, vector<vector<bit>> &m, bigint bound) { bit t; vector<bit> pivot_bits(0); bigint ncols = m.size(); bigint nrows = m.at(0).size(); for (bigint r = 0; r < min(nrows, ncols); r++) { vector<bit> v(bound - r); for (bigint c = 0; c < bound - r; c++) v.at(c) = bit_vector_or_bits(bit_vector_extract(m.at(r+c), r, nrows)); vector<bit> idx = bit_vector_first_one(v); pivots.push_back(idx); for (bigint i = r+1; i < nrows; i++) { t = ram_read(m, r, bound, idx, idx.size(), r); for (bigint c = r; c < ncols; c++) m.at(c).at(r) ^= m.at(c).at(i).andn(t); } vector<bit> u = ram_read(m, r, bound, idx, idx.size()); pivot_bits.push_back(u.at(r)); for (bigint i = 0; i < nrows; i++) { if (i == r) continue; for (bigint c = r; c < ncols; c++) m.at(c).at(i) ^= (u.at(i) & m.at(c).at(r)); } } return bit_vector_and_bits(pivot_bits); } static inline vector<bit> bit_matrix_vector_mul(vector<vector<bit>> &m, vector<bit> &v, bool flip = 0) { assert(v.size() == m.size()); vector<bit> ret(m.at(0).size()); for (bigint i = 0; i < v.size(); i++) { vector<bit> w(m.at(0).size()); bit vi = v.at(i); for (bigint j = 0;j < w.size();++j) if (flip) w.at(j) = m.at(i).at(j).andn(vi); else w.at(j) = m.at(i).at(j) & vi; if (i == 0) ret = w; else bit_vector_ixor(ret, w); } return ret; } static inline vector<vector<bit>> bit_matrix_transpose(vector<vector<bit>> &m) { vector<vector<bit>> ret(m.at(0).size(), vector<bit> (m.size())); for (bigint i = 0; i < ret.size(); i++) for (bigint j = 0; j < ret.at(0).size(); j++) ret.at(i).at(j) = m.at(j).at(i); return ret; } static inline void bit_matrix_cswap(const bit c, vector<vector<bit>> &m0, vector<vector<bit>> &m1) { assert(m0.size() == m1.size()); assert(m0[0].size() == m1[0].size()); for (bigint i = 0; i < m0.size(); i++) for (bigint j = 0; j < m0[0].size(); j++) c.cswap(m0.at(i).at(j), m1.at(i).at(j)); } static inline void bit_matrix_randomize_rows(vector<vector<bit>> &m, vector<bit> &s, bigint L) { for (bigint r = 0; r < L; ++r) for (bigint i = 0; i < m.at(0).size(); ++i) { if (i == r) continue; bit b = random_bool(); // XXX: can also do b as a per-circuit constant, saving about half of the operations for (bigint c = 0; c < m.size(); ++c) m.at(c).at(i) ^= b & m.at(c).at(r); s.at(i) ^= b & s.at(r); } } vector<bit> bit_matrix_sum_of_cols_straightforward(vector<vector<bit>> &, vector<vector<bit>> &); vector<bit> bit_matrix_sum_of_cols_viasorting(vector<vector<bit>> &, vector<vector<bit>> &); vector<bit> bit_matrix_sum_of_cols(vector<vector<bit>> &, vector<vector<bit>> &); #endif