-rw-r--r-- 3873 cryptattacktester-20231020/sorting.cpp raw
#include "bit_matrix.h"
#include "index.h"
#include "sorting.h"
using namespace std;
// Sort the rows of m in place using Batcher's merge-exchange network
// (Knuth, TAOCP Vol. 3, Algorithm 5.2.2M; p, q, r, d follow Knuth's names).
//
// The sequence of compare-exchange positions depends only on n = m.size(),
// never on the contents of m. Each step computes
// c = bit_vector_gt(m[i], m[i+d]) and passes it to bit_vector_cswap on the
// two rows. Assuming bit_vector_gt is a strict "greater than" and cswap
// swaps when c is set, rows end up in non-decreasing order.
void sorting(vector<vector<bit>> &m)
{
	bigint n = m.size();

	// t = ceil(log2(n)): the bit length of n, minus one when n is an exact
	// power of two. n == 0 and n == 1 both give t == 0 (no passes run).
	bigint n_tmp = n;
	bigint t = 0;
	while (n_tmp > 0)
	{
		n_tmp >>= 1;
		t += 1;
	}
	if (t > 0 && n == (bigint(1) << (t-1)))
		t -= 1;

	for (bigint j = t-1; j >= 0; j--)
	{
		// Shift a bigint 1 (as the power-of-two test above does) rather than
		// the int literal 1, so the shift cannot be evaluated at int width.
		bigint p = bigint(1) << j;
		bigint q = bigint(1) << (t-1);
		bigint r = 0;
		bigint d = p;
		while (true)
		{
			// Compare-exchange each pair (i, i+d) whose bit p of i equals r.
			for (bigint i = 0; i < n-d; i++)
			{
				if ((i & p) == r)
				{
					bit c = bit_vector_gt(m.at(i), m.at(i+d));
					bit_vector_cswap(c, m.at(i), m.at(i+d));
				}
			}
			if (q != p)
			{
				d = q - p;
				q = q / 2;
				r = p;
			}
			else
				break;
		}
	}
}
// Sort three parallel lists in place with Batcher's merge-exchange network
// (Knuth, TAOCP Vol. 3, Algorithm 5.2.2M). Entry k of L_01, L_sum and L_set
// is treated as one record.
//
// The comparator schedule depends only on n = L_01.size(). At each step:
// - The key of entry k is L_sum[k] with L_01[k] appended at the end.
// - Two keys are compared with bit_vector_gt_rev, whose bit ordering is
//   defined elsewhere.
// - The resulting bit c conditionally swaps the L_01 bits, the L_sum rows
//   and the L_set matrices of the two entries together.
//
// All three lists must have the same length. L_sum and L_set are indexed
// without bounds checks, so this is asserted up front, matching the other
// overloads.
void sorting(vector<bit> &L_01,
                    vector<vector<bit>> &L_sum,
                    vector<vector<vector<bit>>> &L_set)
{
	assert (L_01.size() == L_sum.size());
	assert (L_sum.size() == L_set.size());
	bigint n = L_01.size();

	// t = ceil(log2(n)); n == 0 and n == 1 give t == 0 (no passes run).
	bigint n_tmp = n;
	bigint t = 0;
	while (n_tmp > 0)
	{
		n_tmp >>= 1;
		t += 1;
	}
	if (t > 0 && n == (bigint(1) << (t-1)))
		t -= 1;

	for (bigint j = t-1; j >= 0; j--)
	{
		// bigint(1) rather than int 1, so the shift cannot be done at int width.
		bigint p = bigint(1) << j;
		bigint q = bigint(1) << (t-1);
		bigint r = 0;
		bigint d = p;
		while (true)
		{
			// Compare-exchange each pair (i, i+d) whose bit p of i equals r.
			for (bigint i = 0; i < n-d; i++)
			{
				if ((i & p) == r)
				{
					vector<bit> v0 = L_sum[i];
					vector<bit> v1 = L_sum[i+d];
					v0.push_back(L_01.at(i));
					v1.push_back(L_01.at(i+d));
					bit c = bit_vector_gt_rev(v0, v1);
					c.cswap(L_01[i], L_01[i+d]);
					bit_vector_cswap(c, L_sum[i], L_sum[i+d]);
					bit_matrix_cswap(c, L_set[i], L_set[i+d]);
				}
			}
			if (q != p)
			{
				d = q - p;
				q = q / 2;
				r = p;
			}
			else
				break;
		}
	}
}
// Sort three parallel lists in place with Batcher's merge-exchange network
// (Knuth, TAOCP Vol. 3, Algorithm 5.2.2M), comparing only a prefix of L_sum.
// Entry k of L_01, L_sum and L_set is treated as one record.
//
// The comparator schedule depends only on n = L_01.size(). At each step:
// - The key of entry k is bit_vector_extract(L_sum[k], 0, L0) with L_01[k]
//   appended at the end. That is presumably the first L0 bits of L_sum[k];
//   NOTE(review): confirm bit_vector_extract's (start, end/length) convention.
// - Keys are compared with bit_vector_gt_rev.
// - The resulting bit c conditionally swaps the L_01 bits, the whole L_sum
//   rows (not just the prefix) and the L_set matrices of the two entries.
void sorting(vector<bit> &L_01,
          vector<vector<bit>> &L_sum,
          vector<vector<vector<bit>>> &L_set,
          bigint L0)
{
	assert (L_01.size() == L_sum.size());
	assert (L_sum.size() == L_set.size());
	bigint n = L_01.size();

	// t = ceil(log2(n)); n == 0 and n == 1 give t == 0 (no passes run).
	bigint n_tmp = n;
	bigint t = 0;
	while (n_tmp > 0)
	{
		n_tmp >>= 1;
		t += 1;
	}
	if (t > 0 && n == (bigint(1) << (t-1)))
		t -= 1;

	for (bigint j = t-1; j >= 0; j--)
	{
		// bigint(1) rather than int 1, so the shift cannot be done at int width.
		bigint p = bigint(1) << j;
		bigint q = bigint(1) << (t-1);
		bigint r = 0;
		bigint d = p;
		while (true)
		{
			// Compare-exchange each pair (i, i+d) whose bit p of i equals r.
			for (bigint i = 0; i < n-d; i++)
			{
				if ((i & p) == r)
				{
					vector<bit> v0 = bit_vector_extract(L_sum[i+0], 0, L0);
					vector<bit> v1 = bit_vector_extract(L_sum[i+d], 0, L0);
					v0.push_back(L_01.at(i+0));
					v1.push_back(L_01.at(i+d));
					bit c = bit_vector_gt_rev(v0, v1);
					c.cswap(L_01[i], L_01[i+d]);
					bit_vector_cswap(c, L_sum[i], L_sum[i+d]);
					bit_matrix_cswap(c, L_set[i], L_set[i+d]);
				}
			}
			if (q != p)
			{
				d = q - p;
				q = q / 2;
				r = p;
			}
			else
				break;
		}
	}
}
// Sort four parallel lists in place with Batcher's merge-exchange network
// (Knuth, TAOCP Vol. 3, Algorithm 5.2.2M). Entry k of L_01, L_sum, L_set and
// L_valid is treated as one record.
//
// The comparator schedule depends only on n = L_01.size(). At each step:
// - The key of entry k is L_valid[k], then the bits of L_sum[k], then
//   L_01[k].
// - Keys are compared with bit_vector_gt_rev, whose bit ordering is defined
//   elsewhere.
// - The resulting bit c conditionally swaps the L_01 bits, L_sum rows,
//   L_set matrices and L_valid bits of the two entries together.
void sorting(vector<bit> &L_01,
          vector<vector<bit>> &L_sum,
          vector<vector<vector<bit>>> &L_set,
          vector<bit> &L_valid)
{
	assert (L_01.size() == L_sum.size());
	assert (L_sum.size() == L_set.size());
	assert (L_set.size() == L_valid.size());
	bigint n = L_01.size();

	// t = ceil(log2(n)); n == 0 and n == 1 give t == 0 (no passes run).
	bigint n_tmp = n;
	bigint t = 0;
	while (n_tmp > 0)
	{
		n_tmp >>= 1;
		t += 1;
	}
	if (t > 0 && n == (bigint(1) << (t-1)))
		t -= 1;

	for (bigint j = t-1; j >= 0; j--)
	{
		// bigint(1) rather than int 1, so the shift cannot be done at int width.
		bigint p = bigint(1) << j;
		bigint q = bigint(1) << (t-1);
		bigint r = 0;
		bigint d = p;
		while (true)
		{
			// Compare-exchange each pair (i, i+d) whose bit p of i equals r.
			for (bigint i = 0; i < n-d; i++)
			{
				if ((i & p) == r)
				{
					vector<bit> v0 = L_sum[i+0];
					vector<bit> v1 = L_sum[i+d];
					v0.push_back(L_01.at(i+0));
					v1.push_back(L_01.at(i+d));
					v0.insert(v0.begin(), L_valid.at(i+0));
					v1.insert(v1.begin(), L_valid.at(i+d));
					bit c = bit_vector_gt_rev(v0, v1);
					c.cswap(L_01[i], L_01[i+d]);
					bit_vector_cswap(c, L_sum[i], L_sum[i+d]);
					bit_matrix_cswap(c, L_set[i], L_set[i+d]);
					// L_valid.at(i+d) above already succeeded, so L_valid is
					// non-empty here and no size check is needed.
					c.cswap(L_valid[i], L_valid[i+d]);
				}
			}
			if (q != p)
			{
				d = q - p;
				q = q / 2;
				r = p;
			}
			else
				break;
		}
	}
}