// (file-listing residue) -rw-r--r-- 3186 cryptattacktester-20231020/bit_matrix.cpp raw
#include <cassert>
#include "ram.h"
#include "sorting.h"
#include "bit_vector.h"
#include "index.h"
#include "permutation.h"
#include "bit_matrix.h"
#include "bit_matrix_cost.h"
using namespace std;
// Build the augmented matrix [ m | I ] stored as a list of columns.
//
// m is given as a list of rows: m.size() rows, each of m.at(0).size() entries
// (both must be nonzero, checked by assert). The result holds
// m.at(0).size() + m.size() columns, each of length m.size():
//   - the first m.at(0).size() columns are the columns of m (so, as stored,
//     this part is m transposed);
//   - the remaining m.size() columns are the columns of the identity matrix.
// Uses .at(), so a ragged m throws std::out_of_range.
vector<vector<bit>> bit_matrix_transpose_and_identity(const vector<vector<bit>> &m)
{
  bigint nrows = m.size();
  assert(nrows > 0);
  bigint ncols = m.at(0).size();
  assert(ncols > 0);

  vector<vector<bit>> out;

  // left block: copy column c of m, top to bottom
  for (bigint c = 0;c < ncols;++c) {
    vector<bit> col;
    for (bigint r = 0;r < nrows;++r)
      col.push_back(m.at(r).at(c));
    out.push_back(col);
  }

  // right block: column k of the identity has a single 1 in row k
  for (bigint k = 0;k < nrows;++k) {
    vector<bit> col;
    for (bigint r = 0;r < nrows;++r)
      col.push_back(bit(r == k));
    out.push_back(col);
  }

  return out;
}
// XOR together the entries of m selected by each address in `indices`,
// reading every selected entry with ram_read (one RAM read per index).
// The first selected entry seeds the accumulator; the rest are XORed in,
// in order. indices must be nonempty: indices.at(0) throws
// std::out_of_range otherwise.
vector<bit> bit_matrix_sum_of_cols_straightforward(vector<vector<bit>> &m, vector<vector<bit>> &indices)
{
  vector<bit> acc = ram_read(m,indices.at(0));
  for (bigint k = 1;k < indices.size();++k) {
    vector<bit> picked = ram_read(m,indices.at(k));
    bit_vector_ixor(acc,picked);
  }
  return acc;
}
// Sum (XOR) of the columns of m selected by `indices`, computed with one
// sort plus a matrix-vector product instead of one RAM read per index.
//
// m is a list of columns (m.size() columns). Each element of `indices` is a
// bit-vector column address; presumably all have the width of indices.at(0)
// (bit_vector_add below writes into a vector of that width) — confirm at callers.
//
// Outline of what the code does:
//  1. For each index i, append an entry with leading tag bit 0 whose payload is
//     indices.at(i) plus the constant (indices.size()-1-i), added with carry-in 1.
//  2. For each column position i of m, append an entry with leading tag bit 1
//     whose payload is the integer i.
//  3. Sort all entries, take the tag bits of the first m.size() sorted entries
//     as the vector e, and return bit_matrix_vector_mul(m, e, 1).
// NOTE(review): why the per-index offset and carry-in make e select exactly the
// requested columns depends on the bit order used by bit_vector_from_integer /
// sorting and on the meaning of the final argument of bit_vector_from_integer
// and bit_matrix_vector_mul; none of these are visible here — confirm against
// their definitions before modifying.
// Throws std::out_of_range (via indices.at(0)) if indices is empty.
vector<bit> bit_matrix_sum_of_cols_viasorting(vector<vector<bit>> &m, vector<vector<bit>> &indices)
{
	vector<bit> e(0);
	bigint idx_size = indices.at(0).size();
	// L collects the tagged entries that will be sorted together
	vector<vector<bit>> L;
	// step 1: one tag-0 entry per requested index, offset by (p-1-i) with carry-in 1
	for (bigint i = 0; i < indices.size(); i++)
	{
		bigint j = indices.size()-1-i;
		vector<bit> v_j = bit_vector_from_integer(j, idx_size, 1);
		vector<bit> idx = indices.at(i);
		vector<bit> v(idx_size);
		bit_vector_add(v, v_j, idx, bit(1));
		v.insert(v.begin(), bit(0));
		L.push_back(v);
	}
	// step 2: one tag-1 entry per column position of m
	for (bigint i = 0; i < m.size(); i++)
	{
		vector<bit> v = bit_vector_from_integer(i, idx_size);
		v.insert(v.begin(), bit(1));
		L.push_back(v);
	}
	// step 3: sort, then keep the tag bits of the first m.size() entries as e
	sorting(L);
	for (bigint i = 0; i < m.size(); i++)
		e.push_back(L.at(i).at(0));
	// combine the selected columns of m according to e
	vector<bit> ret = bit_matrix_vector_mul(m, e, 1);	
	return ret;
}
// Sum (XOR) of the columns of m addressed by `indices`, dispatching to
// whichever implementation the cost model rates cheaper.
//
// m is a list of columns (m.size() > 0, checked by assert); each column has
// m.at(0).size() rows. Both cost estimates are computed (sorting first, then
// straightforward). The straightforward RAM-read version is used only when
// it is strictly cheaper; on a tie the sorting version is used.
vector<bit> bit_matrix_sum_of_cols(vector<vector<bit>> &m, vector<vector<bit>> &indices)
{
  bigint numcols = m.size();
  assert(numcols > 0);
  bigint numrows = m.at(0).size();
  bigint numindices = indices.size();

  bigint sortingcost = bit_matrix_sum_of_cols_viasorting_cost(numrows,numcols,numindices);
  bigint directcost = bit_matrix_sum_of_cols_straightforward_cost(numrows,numcols,numindices);

  return (directcost < sortingcost)
    ? bit_matrix_sum_of_cols_straightforward(m,indices)
    : bit_matrix_sum_of_cols_viasorting(m,indices);
}
// Permute the columns of H (and column_map identically), bring H to reduced
// echelon form with s carried along, move the pivot columns to the front,
// then permute the columns of H and column_map again with a second permutation.
//
// H is a list of columns: H.size() columns (> 0, asserted), each of
// H.at(0).size() rows. s must have one entry per row and column_map one entry
// per column (both asserted). H, s and column_map are all modified in place.
// Returns the success bit produced by bit_matrix_reduced_echelon.
// NOTE(review): the function name suggests the permutations constructed here
// are random; that depends on the permutation class — confirm. Their
// construction order is significant if they consume randomness, so do not
// reorder these statements.
bit bit_matrix_column_randompermutation(
  vector<bit> &s,
  vector<vector<bit>> &H,
  vector<vector<bit>> &column_map
)
{
  bigint cols = H.size();
  assert(cols > 0);
  bigint rows = H.at(0).size();
  assert(s.size() == rows);
  assert(column_map.size() == cols);
  permutation pi(cols);
  vector<vector<bit>> pivots;
  // append s as an extra trailing entry so the echelon step transforms it too
  H.push_back(s);
  // NOTE(review): H now has cols+1 entries while pi has size cols; presumably
  // permute() only touches the first cols entries so s stays last — confirm.
  pi.permute(H);
  pi.permute(column_map);
  // cols is passed separately, presumably so the trailing s entry is not used
  // as a pivot candidate — confirm against bit_matrix_reduced_echelon
  bit success = bit_matrix_reduced_echelon(pivots,H,cols);
  // detach the transformed s again; H is back to cols entries
  s = H.back(); H.pop_back();
  // presumably moves the column at pivot position pivots.at(i) into position i
  // (and the old entry i into the pivot's slot) — confirm ram_read_write semantics
  for (bigint i = 0;i < rows;++i) {
    H.at(i) = ram_read_write(H, i, cols, pivots.at(i), H.at(i));
    column_map.at(i) = ram_read_write(column_map, i, cols, pivots.at(i), column_map.at(i));
  }
  // second permutation, constructed from (cols,rows); presumably it leaves the
  // first rows (pivot) columns in place — confirm
  permutation pi_N(cols,rows);
  pi_N.permute(H);
  pi_N.permute(column_map);
  
  return success;
}