#include #include #include #include #include #include //========================================== // 4096ビットの符号なし整数型: uint4096_t //========================================== class uint4096_t { private: static const size_t SIZE = 64; // 64 * 64 = 4096ビット std::array data; // [0] が下位64ビット、[SIZE-1] が上位64ビット public: uint4096_t() { data.fill(0ULL); } uint4096_t(uint64_t low) { data.fill(0ULL); data[0] = low; } //=== highestNonzeroBlock() === static int highestNonzeroBlock(const uint4096_t &x) { for (int i = SIZE - 1; i >= 0; i--) { if (x.data[i] != 0ULL) return i; } return -1; // all zero } //=== mul_auto() === static void mul_auto(const uint4096_t &x, const uint4096_t &y, uint4096_t &dest) { constexpr int thresholdBlocks = 32; // 例: 32ブロック=2048ビット int hiX = highestNonzeroBlock(x); int hiY = highestNonzeroBlock(y); // どちらかが 0 なら結果は 0 if (hiX < 0 || hiY < 0) { for (size_t i = 0; i < SIZE; i++) { dest.data[i] = 0ULL; } return; } if (hiX >= thresholdBlocks || hiY >= thresholdBlocks) { // Karatsuba karatsuba_4096(x, y, dest); } else { // ナイーブ mul_4096_naive(x, y, dest); } } // ビット単位の反転演算子 uint4096_t operator~() const { uint4096_t result; for (size_t i = 0; i < SIZE; ++i) { result.data[i] = ~data[i]; // 各64ビットブロックを反転 } return result; } // 最大値を返す静的メソッド static uint4096_t max_value() { uint4096_t result; for (size_t i = 0; i < SIZE; ++i) { result.data[i] = ~0ULL; // 全ビットを1に設定 } return result; } //------------------------------------------- // 加算 (ヘルパー): add_4096 //------------------------------------------- static void add_4096(const uint4096_t &a, const uint4096_t &b, uint4096_t &dest) { __uint128_t carry = 0; for (size_t i = 0; i < SIZE; i++) { __uint128_t tmp = (__uint128_t)a.data[i] + (__uint128_t)b.data[i] + carry; dest.data[i] = (uint64_t)tmp; carry = (tmp >> 64); } if (carry != 0) { throw std::runtime_error("Overflow in add_4096!"); } } //------------------------------------------- // 減算 (ヘルパー): sub_4096 (a>=b前提) //------------------------------------------- static void sub_4096(const uint4096_t &a, const uint4096_t &b, uint4096_t &dest) { int64_t borrow = 0; for (size_t i = 0; i < SIZE; i++) { __int128 tmp = (__int128)a.data[i] - (__int128)b.data[i] - borrow; dest.data[i] = (uint64_t)tmp; borrow = (tmp < 0) ? 1 : 0; } if (borrow != 0) { throw std::runtime_error("Underflow in sub_4096!"); } } //------------------------------------------- // ナイーブ乗算 (ヘルパー): mul_4096_naive //------------------------------------------- static void mul_4096_naive(const uint4096_t &x, const uint4096_t &y, uint4096_t &dest) { for (size_t i = 0; i < SIZE; i++) { dest.data[i] = 0ULL; } for (size_t i = 0; i < SIZE; i++) { __uint128_t carry = 0; for (size_t j = 0; j < SIZE - i; j++) { __uint128_t mul = (__uint128_t)x.data[i] * (__uint128_t)y.data[j] + (__uint128_t)dest.data[i + j] + carry; dest.data[i + j] = (uint64_t)mul; carry = (mul >> 64); } if (carry != 0) { throw std::runtime_error("Overflow in mul_4096_naive!"); } } } //------------------------------------------- // Karatsuba乗算(2分割) (一段のみの例) // x, y を 下位/上位 2048ビットに分割して // A0B0, A1B1, cross の3回乗算で合成する // // 実装メモ: // - 下位32ブロック + 上位32ブロック // - (A1B1 << 4096) は下位4096ビットには残らない → // A1B1≠0 なら overflow とする //------------------------------------------- static void karatsuba_4096(const uint4096_t &x, const uint4096_t &y, uint4096_t &dest) { // 1) 分割 uint4096_t xLow, xHigh; uint4096_t yLow, yHigh; splitInHalf(x, xLow, xHigh); splitInHalf(y, yLow, yHigh); // 2) A0B0 = xLow*yLow (再帰 or ナイーブ) uint4096_t A0B0; // ここではデモとしてナイーブ呼び出し mul_4096_naive(xLow, yLow, A0B0); // 3) A1B1 = xHigh*yHigh uint4096_t A1B1; mul_4096_naive(xHigh, yHigh, A1B1); // 4) sumX = xLow + xHigh uint4096_t sumX; add_4096(xLow, xHigh, sumX); // sumY = yLow + yHigh uint4096_t sumY; add_4096(yLow, yHigh, sumY); // 5) sumXY = sumX * sumY // (ここも本来は karatsuba再帰 or ナイーブ呼び出し) uint4096_t sumXY; mul_4096_naive(sumX, sumY, sumXY); // 6) cross = sumXY - A0B0 - A1B1 uint4096_t cross; { // sumXY - A0B0 uint4096_t tmp; sub_4096(sumXY, A0B0, tmp); // tmp - A1B1 sub_4096(tmp, A1B1, cross); } // 7) 合成 // result = A0B0 + (cross << 2048) + (A1B1 << 4096) // ただし A1B1 が非0なら => (A1B1 << 4096) で下位4096ビット以外にビットが出る → overflow // なので A1B1 ≠ 0 なら例外にする // cross << 2048 は「下位2048ビットを空ける」= 32ブロック分左シフト // 出力用のバッファを 0クリア for (size_t i = 0; i < SIZE; i++) { dest.data[i] = 0ULL; } // dest = A0B0 for (size_t i = 0; i < SIZE; i++) { dest.data[i] = A0B0.data[i]; } // dest += (cross << 2048) { uint4096_t crossShifted = cross << (64 * 32); // 2048 = 64*32 uint4096_t tmp; add_4096(dest, crossShifted, tmp); dest = tmp; } // A1B1 が非0 であれば overflow とする ( (A1B1 << 4096) は下位4096ビットに残らない ) bool nonzero = false; for (size_t i = 0; i < SIZE; i++) { if (A1B1.data[i] != 0ULL) { nonzero = true; break; } } if (nonzero) { // (A1B1 << 4096) は必ず下位4096ビットを超える → overflow throw std::runtime_error("Overflow in karatsuba_4096! (top 2048 bits were nonzero)"); } } static void splitInHalf(const uint4096_t &src, uint4096_t &lowPart, uint4096_t &highPart) { constexpr size_t HALF = SIZE / 2; lowPart.data.fill(0); highPart.data.fill(0); for (size_t i = 0; i < HALF; i++) { lowPart.data[i] = src.data[i]; } for (size_t i = HALF; i < SIZE; i++) { highPart.data[i - HALF] = src.data[i]; } } static void mergeInHalf(uint4096_t &dest, const uint4096_t &lowPart, const uint4096_t &highPart) { constexpr size_t HALF = SIZE / 2; dest.data.fill(0); for (size_t i = 0; i < HALF; i++) { dest.data[i] = lowPart.data[i]; } for (size_t i = HALF; i < SIZE; i++) { dest.data[i] = highPart.data[i - HALF]; } } // 加算 (オーバーフロー時は例外) uint4096_t operator+(const uint4096_t &other) const { uint4096_t result; add_4096(*this, other, result); return result; } // 加算の複合代入 uint4096_t &operator+=(const uint4096_t &other) { *this = *this + other; return *this; } // 減算 uint4096_t operator-(const uint4096_t &other) const { // 簡易: this >= other 前提チェック if (*this < other) throw std::runtime_error("Underflow in subtraction"); uint4096_t r; sub_4096(*this, other, r); return r; } // 減算の複合代入 uint4096_t &operator-=(const uint4096_t &other) { *this = *this - other; return *this; } uint4096_t operator*(const uint4096_t &other) const { uint4096_t r; mul_auto(*this, other, r); return r; } // 乗算の複合代入 uint4096_t &operator*=(const uint4096_t &other) { *this = *this * other; return *this; } // ビットシフト演算(左) // count は 0~4096 まで対応すると良い uint4096_t operator<<(unsigned int count) const { if (count >= 4096) { // 全ビットが0になる return uint4096_t(0); } uint4096_t result(*this); // ブロックシフト unsigned blockShift = count / 64; if (blockShift > 0) { // 上から下へシフト for (int i = SIZE - 1; i >= (int)blockShift; --i) { result.data[i] = result.data[i - blockShift]; } for (unsigned i = 0; i < blockShift; ++i) { result.data[i] = 0; } } // ビットシフト unsigned bitShift = count % 64; if (bitShift > 0) { for (int i = SIZE - 1; i >= 0; --i) { uint64_t highPart = (i > 0) ? result.data[i - 1] : 0; uint64_t val = result.data[i]; result.data[i] = (val << bitShift) | (highPart >> (64 - bitShift)); } } return result; } // ビットシフト演算(右) uint4096_t operator>>(unsigned int count) const { if (count >= 4096) { return uint4096_t(0); } uint4096_t result(*this); // ブロックシフト unsigned blockShift = count / 64; if (blockShift > 0) { for (unsigned i = 0; i < SIZE - blockShift; ++i) { result.data[i] = result.data[i + blockShift]; } for (int i = SIZE - blockShift; i < (int)SIZE; ++i) { result.data[i] = 0; } } // ビットシフト unsigned bitShift = count % 64; if (bitShift > 0) { for (size_t i = 0; i < SIZE; ++i) { uint64_t lowPart = (i + 1 < SIZE) ? result.data[i + 1] : 0; uint64_t val = result.data[i]; result.data[i] = (val >> bitShift) | (lowPart << (64 - bitShift)); } } return result; } // ビットシフトの複合代入 uint4096_t &operator<<=(unsigned int count) { *this = *this << count; return *this; } uint4096_t &operator>>=(unsigned int count) { *this = *this >> count; return *this; } // 除算・剰余をまとめて計算するヘルパー (2進法の long division) static void div_mod(const uint4096_t ÷nd, const uint4096_t &divisor, uint4096_t "ient, uint4096_t &remainder) { if (divisor == uint4096_t(0)) { throw std::runtime_error("Division by zero!"); } quotient = uint4096_t(0); remainder = uint4096_t(0); // 最高位ビットから下位ビットへ走査 (計 4096 bit) for (int i = 4095; i >= 0; --i) { // remainder を1ビット左へシフト remainder <<= 1; // dividend の i番目ビットを取り込む // i / 64 番目ブロック, i % 64 番目ビット uint64_t blockIndex = i / 64; uint64_t bitIndex = i % 64; uint64_t bitValue = (dividend.data[blockIndex] >> bitIndex) & 1ULL; remainder.data[0] |= bitValue; // 比較して大きければ引く if (remainder >= divisor) { remainder = remainder - divisor; // quotient の i番目ビットを立てる // => quotient += (1 << i) // iが大きい時はビットシフトだけでは処理できないのでブロックレベルで操作 quotient.data[blockIndex] |= (1ULL << bitIndex); } } } // 除算 (this / other) (オーバーフローは基本的には起こらない想定) uint4096_t operator/(const uint4096_t &other) const { uint4096_t q, r; div_mod(*this, other, q, r); return q; } // 剰余 (this % other) uint4096_t operator%(const uint4096_t &other) const { uint4096_t q, r; div_mod(*this, other, q, r); return r; } // 複合代入 (除算、剰余) uint4096_t &operator/=(const uint4096_t &other) { *this = *this / other; return *this; } uint4096_t &operator%=(const uint4096_t &other) { *this = *this % other; return *this; } // 比較 int compare(const uint4096_t &o) const { for (int i = SIZE - 1; i >= 0; i--) { if (data[i] < o.data[i]) return -1; if (data[i] > o.data[i]) return 1; } return 0; } bool operator==(const uint4096_t &o) const { return compare(o) == 0; } bool operator!=(const uint4096_t &o) const { return compare(o) != 0; } bool operator<(const uint4096_t &o) const { return compare(o) < 0; } bool operator<=(const uint4096_t &o) const { return compare(o) <= 0; } bool operator>(const uint4096_t &o) const { return compare(o) > 0; } bool operator>=(const uint4096_t &o) const { return compare(o) >= 0; } //-------------------------------------------- // 10進数で表示 (printDec) //-------------------------------------------- void printDec() const { // 2^4096 - 1 は 10^1233 くらいなので、1233 桁あれば十分 const int DECIMAL_DIGITS = 1233; std::vector digits(DECIMAL_DIGITS, 0); // 上位ブロック( data[63] )から順に走査し、2進→10進へ変換 for (int i = (int)SIZE - 1; i >= 0; --i) { for (int bit = 63; bit >= 0; --bit) { // 桁を 2倍してビットを足す uint32_t carry = (data[i] >> bit) & 1U; for (int k = DECIMAL_DIGITS - 1; k >= 0; --k) { uint64_t val = ((uint64_t)digits[k] << 1) + carry; digits[k] = (uint32_t)(val % 10); carry = (uint32_t)(val / 10); } } } // 先頭の0をスキップして出力 bool leading = true; for (size_t i = 0; i < digits.size(); ++i) { if (leading && digits[i] == 0) { continue; } leading = false; std::cout << (unsigned)digits[i]; } // 全部 0 だった場合 if (leading) { std::cout << "0"; } } void debugPrintBlocks() const { std::cout << "[ "; for (size_t i = 0; i < SIZE; i++) { std::cout << std::hex << data[i] << " "; } std::cout << "]" << std::dec << "\n"; } }; //========================================== // 4096ビットの符号付き整数型: int4096_t // - 内部的には (符号) + (絶対値) //========================================== class int4096_t { private: bool negative; // 負なら true, 正なら false uint4096_t magnitude; // 絶対値 public: // 1) デフォルト int4096_t() : negative(false), magnitude(0) { } // 2) 64ビット (符号付き) からの変換 (int64_t版のみ残す) int4096_t(int64_t val) { if (val < 0) { negative = true; magnitude = (uint64_t)(-val); } else { negative = false; magnitude = (uint64_t)val; } } // int4096_t から uint4096_t への明示的変換(絶対値を取る) explicit operator uint4096_t() const { return magnitude; } // ==, !=, <, <=, >, >= // 符号をまず比較し、符号が同じなら絶対値で比較 bool operator==(const int4096_t &other) const { return (negative == other.negative) && (magnitude == other.magnitude); } bool operator!=(const int4096_t &other) const { return !(*this == other); } bool operator<(const int4096_t &other) const { // 符号が異なるなら、負の方が小さい if (negative != other.negative) { return negative; // this が negative なら true } // 符号が同じなら、絶対値で比較 (負の場合は絶対値が大きい方が小さい) if (!negative) { // 共に非負 return (magnitude < other.magnitude); } else { // 共に負 // 例えば -10 < -5 は絶対値10 > 5 なので、逆 return (other.magnitude < magnitude); } } bool operator<=(const int4096_t &other) const { return !(*this > other); } bool operator>(const int4096_t &other) const { return other < *this; } bool operator>=(const int4096_t &other) const { return !(*this < other); } // 符号付き加算 int4096_t operator+(const int4096_t &other) const { // 符号が同じなら絶対値同士を加算し、符号はそのまま // 符号が異なるなら絶対値の大きい方から小さい方を引き、絶対値が大きい方の符号 int4096_t result; if (negative == other.negative) { // 同符号 result.negative = negative; // 加算 (オーバーフローは例外) try { result.magnitude = magnitude + other.magnitude; } catch (...) { throw std::runtime_error("Overflow in int4096_t addition!"); } } else { // 符号が異なる => 実質的には引き算 if (magnitude == other.magnitude) { // +X と -X の和は0 result.negative = false; result.magnitude = 0; } else if (magnitude > other.magnitude) { // this の絶対値が大きい => this の符号を継承 result.negative = negative; result.magnitude = magnitude - other.magnitude; // >=保証 } else { // other の絶対値が大きい => other の符号を継承 result.negative = other.negative; result.magnitude = other.magnitude - magnitude; } } return result; } int4096_t &operator+=(const int4096_t &other) { *this = *this + other; return *this; } // 符号付き減算 int4096_t operator-(const int4096_t &other) const { // A - B = A + (-B) int4096_t negB = other; negB.negative = !negB.negative; // 符号反転 return *this + negB; } int4096_t &operator-=(const int4096_t &other) { *this = *this - other; return *this; } // 符号付き乗算 int4096_t operator*(const int4096_t &other) const { int4096_t result; // 符号は XOR result.negative = (negative != other.negative); try { // ====== ここで "magnitude * other.magnitude" を呼ぶと ====== // 内部的には uint4096_t::operator*(...) が動作し、 // => mul_auto() => しきい値判定 => Karatsuba or ナイーブ // となるので、実質的に「大きい数は Karatsuba」「小さい数はナイーブ」に。 // ============================================================ result.magnitude = magnitude * other.magnitude; } catch (...) { throw std::runtime_error("Overflow in int4096_t multiplication!"); } // 0 なら符号は正 if (result.magnitude == uint4096_t(0)) { result.negative = false; } return result; } int4096_t &operator*=(const int4096_t &other) { *this = *this * other; return *this; } // 符号付き除算 int4096_t operator/(const int4096_t &other) const { if (other.magnitude == uint4096_t(0)) { throw std::runtime_error("Division by zero (int4096_t)!"); } int4096_t result; // 符号は this と other の xor result.negative = (negative != other.negative); // 絶対値同士で割り算 uint4096_t q = magnitude / other.magnitude; result.magnitude = q; return result; } int4096_t &operator/=(const int4096_t &other) { *this = *this / other; return *this; } // 符号付き剰余 int4096_t operator%(const int4096_t &other) const { if (other.magnitude == uint4096_t(0)) { throw std::runtime_error("Modulo by zero (int4096_t)!"); } // 商: q, 余り: r uint4096_t q, r; uint4096_t::div_mod(magnitude, other.magnitude, q, r); int4096_t result; // 符号は被除数(this) と同じ result.negative = negative; result.magnitude = r; return result; } int4096_t &operator%=(const int4096_t &other) { *this = *this % other; return *this; } // 10進数出力 void printDec() const { if (magnitude == uint4096_t(0)) { // 0 は符号に関係なく "0" 出力 std::cout << "0"; return; } if (negative) { std::cout << "-"; } magnitude.printDec(); } };