-rw-r--r-- 3186 cryptattacktester-20230614/bit_matrix.cpp raw
#include <cassert> #include "ram.h" #include "sorting.h" #include "bit_vector.h" #include "index.h" #include "permutation.h" #include "bit_matrix.h" #include "bit_matrix_cost.h" using namespace std; // transposed input matrix on left, identity matrix on right vector<vector<bit>> bit_matrix_transpose_and_identity(const vector<vector<bit>> &m) { bigint rows = m.size(); assert(rows > 0); bigint cols = m.at(0).size(); assert(cols > 0); bigint bigcols = cols+rows; vector<vector<bit>> result; for (bigint i = 0;i < bigcols;++i) { vector<bit> column; if (i < cols) for (bigint j = 0;j < rows;++j) column.push_back(m.at(j).at(i)); else for (bigint j = 0;j < rows;++j) column.push_back(bit(j == i-cols)); result.push_back(column); } return result; } vector<bit> bit_matrix_sum_of_cols_straightforward(vector<vector<bit>> &m, vector<vector<bit>> &indices) { vector<bit> result = ram_read(m,indices.at(0)); for (bigint j = 1;j < indices.size();++j) bit_vector_ixor(result,ram_read(m,indices.at(j))); return result; } vector<bit> bit_matrix_sum_of_cols_viasorting(vector<vector<bit>> &m, vector<vector<bit>> &indices) { vector<bit> e(0); bigint idx_size = indices.at(0).size(); vector<vector<bit>> L; for (bigint i = 0; i < indices.size(); i++) { bigint j = indices.size()-1-i; vector<bit> v_j = bit_vector_from_integer(j, idx_size, 1); vector<bit> idx = indices.at(i); vector<bit> v(idx_size); bit_vector_add(v, v_j, idx, bit(1)); v.insert(v.begin(), bit(0)); L.push_back(v); } for (bigint i = 0; i < m.size(); i++) { vector<bit> v = bit_vector_from_integer(i, idx_size); v.insert(v.begin(), bit(1)); L.push_back(v); } sorting(L); for (bigint i = 0; i < m.size(); i++) e.push_back(L.at(i).at(0)); vector<bit> ret = bit_matrix_vector_mul(m, e, 1); return ret; } vector<bit> bit_matrix_sum_of_cols(vector<vector<bit>> &m, vector<vector<bit>> &indices) { bigint cols = m.size(); assert(cols > 0); bigint rows = m.at(0).size(); bigint p = indices.size(); bigint cost_viasorting = bit_matrix_sum_of_cols_viasorting_cost(rows,cols,p); bigint cost_straightforward = bit_matrix_sum_of_cols_straightforward_cost(rows,cols,p); if (cost_straightforward < cost_viasorting) return bit_matrix_sum_of_cols_straightforward(m,indices); return bit_matrix_sum_of_cols_viasorting(m,indices); } bit bit_matrix_column_randompermutation( vector<bit> &s, vector<vector<bit>> &H, vector<vector<bit>> &column_map ) { bigint cols = H.size(); assert(cols > 0); bigint rows = H.at(0).size(); assert(s.size() == rows); assert(column_map.size() == cols); permutation pi(cols); vector<vector<bit>> pivots; H.push_back(s); pi.permute(H); pi.permute(column_map); bit success = bit_matrix_reduced_echelon(pivots,H,cols); s = H.back(); H.pop_back(); for (bigint i = 0;i < rows;++i) { H.at(i) = ram_read_write(H, i, cols, pivots.at(i), H.at(i)); column_map.at(i) = ram_read_write(column_map, i, cols, pivots.at(i), column_map.at(i)); } permutation pi_N(cols,rows); pi_N.permute(H); pi_N.permute(column_map); return success; }