-rw-r--r-- 3193 cryptattacktester-20231020/bittest.cpp raw
#include <cassert>
#include <iostream>
#include "bit.h"
using namespace std;
// Self-test for the bit type. It checks three things:
//   1. each operation selector in bit_ops_selectors advances bit's operation
//      counters by exactly that operation's declared cost;
//   2. each of the 16 two-input boolean functions, produced by one expression
//      over x and y (constants, x, y, ~, &, |, ^, andn, orn, nand, nor, xnor),
//      has the expected truth table;
//   3. a handful of concrete values for construction, ~, &, | and ^.
// Lines prefixed "bittest " report the counters and the truth tables.
// NOTE(review): every check is an assert(), so the test verifies nothing
// when compiled with NDEBUG defined.
int main()
{
  // The final asserts expect a0 == 0 (default construction) and a1 == 1.
  bit a0;
  bit a1(1);
  // b0/b1 are built from 2 and 3, yet the final asserts expect 0 and 1:
  // presumably construction/assignment from an integer keeps only the low
  // bit -- TODO confirm in bit.h.
  bit b0 = 2;
  bit b1; b1 = 3;
  // Running sum of the declared costs of the operations performed so far.
  // Starting at 0 assumes no counted operation happened before this point
  // (presumably constructing a0..b1 is not counted -- confirm).
  bigint expected = 0;
  // Perform each selectable operation once, in selector order. The results
  // are discarded on purpose: only the effect on bit's counters matters.
  for (auto op: bit_ops_selectors) {
    switch(op) {
      case bit_ops_not: ~a0; expected += bit_not_cost; break;
      case bit_ops_xor: a0 ^ a1; expected += bit_xor_cost; break;
      case bit_ops_and: a0 & a1; expected += bit_and_cost; break;
      case bit_ops_or: a0 | a1; expected += bit_or_cost; break;
      case bit_ops_xnor: a0.xnor(a1); expected += bit_xnor_cost; break;
      case bit_ops_andn: a0.andn(a1); expected += bit_andn_cost; break;
      case bit_ops_nand: a0.nand(a1); expected += bit_nand_cost; break;
      case bit_ops_orn: a0.orn(a1); expected += bit_orn_cost; break;
      case bit_ops_nor: a0.nor(a1); expected += bit_nor_cost; break;
      case bit_ops_mux: a0.mux(b0,b1); expected += bit_mux_cost; break;
      // cswap takes b0/b1 and may modify them; the final value asserts
      // (run after this loop) still expect b0 == 0 and b1 == 1.
      case bit_ops_cswap: a0.cswap(b0,b1); expected += bit_cswap_cost; break;
      // bit_ops_cost is the only selector with no operation to run; this
      // assert catches any selector the switch does not handle.
      default: assert(op == bit_ops_cost);
    }
    // NOTE(review): presumably bit::ops() with no argument returns the total
    // cost counted so far -- confirm in bit.h.
    assert(expected == bit::ops());
    // Print every per-selector counter after each step.
    cout << "bittest ";
    for (auto i: bit_ops_selectors)
      cout << bit::opsname(i) << " " << bit::ops(i) << " ";
    cout << "\n";
  }
  // Enumerate all 16 two-input boolean functions f(x,y). op is f's truth
  // table read as a 4-bit number: bit 3 = f(0,0), bit 2 = f(0,1),
  // bit 1 = f(1,0), bit 0 = f(1,1), as the asserts after the inner loops
  // check. These operations also advance bit's counters, but the counts are
  // not checked again.
  for (bigint op = 0;op < 16;++op) {
    cout << "bittest ";
    // Every op in 0..15 overwrites desc below; "fail" is only a fallback.
    const char *desc = "fail";
    // result[i*2+j] holds f(i,j).
    bit result[4];
    for (bigint i = 0;i < 2;++i)
      for (bigint j = 0;j < 2;++j) {
        bit x(i);
        bit y(j);
        bit r;
        if (op == 0) { r = bit(0); desc = "bit(0)"; }
        if (op == 1) { r = x & y; desc = "x & y"; }
        if (op == 2) { r = x.andn(y); desc = "x.andn(y)"; }
        if (op == 3) { r = x; desc = "x"; }
        if (op == 4) { r = y.andn(x); desc = "y.andn(x)"; }
        if (op == 5) { r = y; desc = "y"; }
        if (op == 6) { r = x ^ y; desc = "x ^ y"; }
        if (op == 7) { r = x | y; desc = "x | y"; }
        if (op == 8) { r = x.nor(y); desc = "x.nor(y)"; }
        if (op == 9) { r = x.xnor(y); desc = "x.xnor(y)"; }
        if (op == 10) { r = ~y; desc = "~y"; }
        if (op == 11) { r = x.orn(y); desc = "x.orn(y)"; }
        if (op == 12) { r = ~x; desc = "~x"; }
        if (op == 13) { r = y.orn(x); desc = "y.orn(x)"; }
        if (op == 14) { r = x.nand(y); desc = "x.nand(y)"; }
        if (op == 15) { r = bit(1); desc = "bit(1)"; }
        result[i*2+j] = r;
        cout << " " << r.value() << " ";
      }
    cout << desc << "\n";
    assert(result[0].value() == op.bit(3));
    assert(result[1].value() == op.bit(2));
    assert(result[2].value() == op.bit(1));
    assert(result[3].value() == op.bit(0));
  }
  // Concrete values: the constructed bits, ~ on b0/b1, and &, |, ^ over all
  // four input combinations. Because these run after the cost loop, b0 and
  // b1 here are whatever a0.cswap(b0,b1) left behind.
  assert(a0.value() == 0);
  assert(a1.value() == 1);
  assert(b0.value() == 0);
  assert(b1.value() == 1);
  assert((~b0).value() == 1);
  assert((~b1).value() == 0);
  assert((a0&b0).value() == 0);
  assert((a0&b1).value() == 0);
  assert((a1&b0).value() == 0);
  assert((a1&b1).value() == 1);
  assert((a0|b0).value() == 0);
  assert((a0|b1).value() == 1);
  assert((a1|b0).value() == 1);
  assert((a1|b1).value() == 1);
  assert((a0^b0).value() == 0);
  assert((a0^b1).value() == 1);
  assert((a1^b0).value() == 1);
  assert((a1^b1).value() == 0);
  // XXX: also test bitslicing
  return 0;
}