#include #include #include "conversion.hpp" #ifdef MCL_USE_XBYAK #include "fp_generator.hpp" #endif #include "fp_proto.hpp" #include "low_gmp.hpp" #ifdef _MSC_VER #pragma warning(disable : 4127) #endif namespace mcl { namespace fp { #ifdef MCL_USE_XBYAK FpGenerator *Op::createFpGenerator() { return new FpGenerator(); } void Op::destroyFpGenerator(FpGenerator *fg) { delete fg; } #else FpGenerator *Op::createFpGenerator() { return 0; } void Op::destroyFpGenerator(FpGenerator *) { } #endif /* use prefix if base conflicts with prefix */ inline const char *verifyStr(bool *isMinus, int *base, const std::string& str) { const char *p = str.c_str(); if (*p == '-') { *isMinus = true; p++; } else { *isMinus = false; } if (p[0] == '0') { if (p[1] == 'x') { *base = 16; p += 2; } else if (p[1] == 'b') { *base = 2; p += 2; } } if (*base == 0) *base = 10; if (*p == '\0') throw cybozu::Exception("fp:verifyStr:str is empty"); return p; } bool strToMpzArray(size_t *pBitSize, Unit *y, size_t maxBitSize, mpz_class& x, const std::string& str, int base) { bool isMinus; const char *p = verifyStr(&isMinus, &base, str); if (!gmp::setStr(x, p, base)) { throw cybozu::Exception("fp:strToMpzArray:bad format") << str; } const size_t bitSize = gmp::getBitSize(x); if (bitSize > maxBitSize) throw cybozu::Exception("fp:strToMpzArray:too large str") << str << bitSize << maxBitSize; if (pBitSize) *pBitSize = bitSize; gmp::getArray(y, (maxBitSize + UnitBitSize - 1) / UnitBitSize, x); return isMinus; } const char *ModeToStr(Mode mode) { switch (mode) { case FP_AUTO: return "auto"; case FP_GMP: return "gmp"; case FP_GMP_MONT: return "gmp_mont"; case FP_LLVM: return "llvm"; case FP_LLVM_MONT: return "llvm_mont"; case FP_XBYAK: return "xbyak"; default: throw cybozu::Exception("ModeToStr") << mode; } } Mode StrToMode(const std::string& s) { static const struct { const char *s; Mode mode; } tbl[] = { { "auto", FP_AUTO }, { "gmp", FP_GMP }, { "gmp_mont", FP_GMP_MONT }, { "llvm", FP_LLVM }, { "llvm_mont", FP_LLVM_MONT }, { "xbyak", FP_XBYAK }, }; for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl); i++) { if (s == tbl[i].s) return tbl[i].mode; } throw cybozu::Exception("StrToMode") << s; } template struct OpeFunc { static const size_t N = (bitSize + UnitBitSize - 1) / UnitBitSize; static inline void set_mpz_t(mpz_t& z, const Unit* p, int n = (int)N) { int s = n; while (s > 0) { if (p[s - 1]) break; s--; } z->_mp_alloc = n; z->_mp_size = s; z->_mp_d = (mp_limb_t*)const_cast(p); } static inline void set_zero(mpz_t& z, Unit *p, size_t n) { z->_mp_alloc = (int)n; z->_mp_size = 0; z->_mp_d = (mp_limb_t*)p; } static inline void fp_clearC(Unit *x) { clearArray(x, 0, N); } static inline void fp_copyC(Unit *y, const Unit *x) { copyArray(y, x, N); } static inline void fp_addPC(Unit *z, const Unit *x, const Unit *y, const Unit *p) { if (low_add(z, x, y)) { low_sub(z, z, p); return; } Unit tmp[N]; if (low_sub(tmp, z, p) == 0) { memcpy(z, tmp, sizeof(tmp)); } } static inline void fp_subPC(Unit *z, const Unit *x, const Unit *y, const Unit *p) { if (low_sub(z, x, y)) { low_add(z, z, p); } } /* z[N * 2] <- x[N * 2] + y[N * 2] mod p[N] << (N * UnitBitSize) */ static inline void fpDbl_addPC(Unit *z, const Unit *x, const Unit *y, const Unit *p) { if (low_add(z, x, y)) { low_sub(z + N, z + N, p); return; } Unit tmp[N]; if (low_sub(tmp, z + N, p) == 0) { memcpy(z + N, tmp, sizeof(tmp)); } } static inline void fpDbl_subPC(Unit *z, const Unit *x, const Unit *y, const Unit *p) { if (low_sub(z, x, y)) { low_add(z + N, z + N, p); } } // z[N] <- x[N] + y[N] without carry static inline void fp_addNCC(Unit *z, const Unit *x, const Unit *y) { low_add(z, x, y); } static inline void fp_subNCC(Unit *z, const Unit *x, const Unit *y) { low_sub(z, x, y); } // z[N + 1] <- x[N] * y static inline void fp_mul_UnitPreC(Unit *z, const Unit *x, Unit y) { low_mul_Unit(z, x, y); } // z[N * 2] <- x[N] * y[N] static inline void fpDbl_mulPreC(Unit *z, const Unit *x, const Unit *y) { low_mul(z, x, y); } // y[N * 2] <- x[N]^2 static inline void fpDbl_sqrPreC(Unit *y, const Unit *x) { low_sqr(y, x); } // y[N] <- x[N + 1] mod p[N] static inline void fpN1_modPC(Unit *y, const Unit *x, const Unit *p) { low_N1_mod(y, x, p); } // y[N] <- x[N * 2] mod p[N] static inline void fpDbl_modPC(Unit *y, const Unit *x, const Unit *p) { low_mod(y, x, p); } // z[N] <- mont(x[N], y[N]) static inline void fp_montPUC(Unit *z, const Unit *x, const Unit *y, const Unit *p, Unit rp) { Unit buf[N * 2 + 2]; Unit *c = buf; low_mul_Unit(c, x, y[0]); // x * y[0] Unit q = c[0] * rp; Unit t[N + 2]; low_mul_Unit(t, p, q); // p * q t[N + 1] = 0; // always zero c[N + 1] = low_add(c, c, t); c++; for (size_t i = 1; i < N; i++) { low_mul_Unit(t, x, y[i]); c[N + 1] = low_add(c, c, t); q = c[0] * rp; low_mul_Unit(t, p, q); low_add(c, c, t); c++; } if (c[N]) { low_sub(z, c, p); } else { if (low_sub(z, c, p)) { memcpy(z, c, N * sizeof(Unit)); } } } // z[N] <- montRed(xy[N * 2]) static inline void fp_montRedPUC(Unit *z, const Unit *xy, const Unit *p, Unit rp) { Unit t[N * 2]; Unit buf[N * 2 + 1]; clearArray(t, N + 1, N * 2); Unit *c = buf; Unit q = xy[0] * rp; low_mul_Unit(t, p, q); buf[N * 2] = low_add(buf, xy, t); c++; for (size_t i = 1; i < N; i++) { q = c[0] * rp; low_mul_Unit(t, p, q); // QQQ mpn_add_n((mp_limb_t*)c, (const mp_limb_t*)c, (const mp_limb_t*)t, N * 2 + 1 - i); c++; } if (c[N]) { low_sub(z, c, p); } else { if (low_sub(z, c, p)) { memcpy(z, c, N * sizeof(Unit)); } } } static inline void fp_invOpC(Unit *y, const Unit *x, const Op& op) { mpz_class my; mpz_t mx, mp; set_mpz_t(mx, x); set_mpz_t(mp, op.p); mpz_invert(my.get_mpz_t(), mx, mp); gmp::getArray(y, N, my); } /* inv(xR) = (1/x)R^-1 -toMont-> 1/x -toMont-> (1/x)R */ static void fp_invMontOpC(Unit *y, const Unit *x, const Op& op) { fp_invOpC(y, x, op); op.fp_mul(y, y, op.R3); } static inline bool fp_isZeroC(const Unit *x) { return isZeroArray(x, N); } static inline void fp_negC(Unit *y, const Unit *x, const Unit *p) { if (fp_isZeroC(x)) { if (x != y) fp_clearC(y); return; } fp_subPC(y, p, x, p); } }; #ifdef MCL_USE_LLVM #define SET_OP_LLVM(n) \ if (mode == FP_LLVM || mode == FP_LLVM_MONT) { \ fp_add = mcl_fp_add ## n; \ fp_sub = mcl_fp_sub ## n; \ if (!isFullBit) { \ fp_addNC = mcl_fp_addNC ## n; \ fp_subNC = mcl_fp_subNC ## n; \ } \ fpDbl_mulPre = mcl_fpDbl_mulPre ## n; \ fp_mul_UnitPre = mcl_fp_mul_UnitPre ## n; \ fpDbl_sqrPre = mcl_fpDbl_sqrPre ## n; \ montPU = mcl_fp_mont ## n; \ montRedPU = mcl_fp_montRed ## n; \ } #define SET_OP_DBL_LLVM(n, n2) \ if (mode == FP_LLVM || mode == FP_LLVM_MONT) { \ fpDbl_addP = mcl_fpDbl_add ## n; \ fpDbl_subP = mcl_fpDbl_sub ## n; \ if (!isFullBit) { \ fpDbl_addNC = mcl_fp_addNC ## n2; \ fpDbl_subNC = mcl_fp_subNC ## n2; \ } \ } #else #define SET_OP_LLVM(n) #define SET_OP_DBL_LLVM(n, n2) #endif #define SET_OP(n) \ N = n / UnitBitSize; \ fp_isZero = OpeFunc::fp_isZeroC; \ fp_clear = OpeFunc::fp_clearC; \ fp_copy = OpeFunc::fp_copyC; \ fp_neg = OpeFunc::fp_negC; \ if (isMont) { \ fp_invOp = OpeFunc::fp_invMontOpC; \ } else { \ fp_invOp = OpeFunc::fp_invOpC; \ } \ fp_add = OpeFunc::fp_addPC; \ fp_sub = OpeFunc::fp_subPC; \ fpDbl_addP = OpeFunc::fpDbl_addPC; \ fpDbl_subP = OpeFunc::fpDbl_subPC; \ if (isFullBit) { \ fp_addNC = 0; \ fp_subNC = 0; \ fpDbl_addNC = 0; \ fpDbl_subNC = 0; \ } else { \ fp_addNC = OpeFunc::fp_addNCC; \ fp_subNC = OpeFunc::fp_subNCC; \ fpDbl_addNC = OpeFunc::fp_addNCC; \ fpDbl_subNC = OpeFunc::fp_subNCC; \ } \ fp_mul_UnitPre = OpeFunc::fp_mul_UnitPreC; \ fpN1_modP = OpeFunc::fpN1_modPC; \ fpDbl_mulPre = OpeFunc::fpDbl_mulPreC; \ fpDbl_sqrPre = OpeFunc::fpDbl_sqrPreC; \ fpDbl_modP = OpeFunc::fpDbl_modPC; \ montPU = OpeFunc::fp_montPUC; \ montRedPU = OpeFunc::fp_montRedPUC; \ SET_OP_LLVM(n) #ifdef MCL_USE_XBYAK inline void invOpForMontC(Unit *y, const Unit *x, const Op& op) { Unit r[maxOpUnitSize]; int k = op.fp_preInv(r, x); /* S = UnitBitSize xr = 2^k R = 2^(N * S) get r2^(-k)R^2 = r 2^(N * S * 2 - k) */ op.fp_mul(y, r, op.invTbl.data() + k * op.N); } static void initInvTbl(Op& op) { const size_t N = op.N; const Unit *p = op.p; const size_t invTblN = N * sizeof(Unit) * 8 * 2; op.invTbl.resize(invTblN * N); Unit *tbl = op.invTbl.data() + (invTblN - 1) * N; Unit t[maxOpUnitSize] = {}; t[0] = 2; op.toMont(tbl, t); for (size_t i = 0; i < invTblN - 1; i++) { op.fp_add(tbl - N, tbl, tbl, p); tbl -= N; } } #endif static void initForMont(Op& op, const Unit *p, Mode mode) { const size_t N = op.N; { mpz_class t = 1, R; gmp::getArray(op.one, N, t); R = (t << (N * UnitBitSize)) % op.mp; t = (R * R) % op.mp; gmp::getArray(op.R2, N, t); t = (R * R * R) % op.mp; gmp::getArray(op.R3, N, t); } op.rp = getMontgomeryCoeff(p[0]); if (mode != FP_XBYAK) return; #ifdef MCL_USE_XBYAK FpGenerator *fg = op.fg; if (fg == 0) return; fg->init(op); if (op.isMont && N <= 4) { op.fp_invOp = &invOpForMontC; initInvTbl(op); } #endif } void Op::init(const std::string& mstr, int base, size_t maxBitSize, Mode mode) { #if 0 fprintf(stderr, "mode=%s, isMont=%d, maxBitSize=%d" #ifdef MCL_USE_XBYAK " MCL_USE_XBYAK" #endif #ifdef MCL_USE_LLVM " MCL_USE_LLVM" #endif "\n", ModeToStr(mode), isMont, (int)maxBitSize); #endif if (maxBitSize > MCL_MAX_OP_BIT_SIZE) { throw cybozu::Exception("Op:init:too large maxBitSize") << maxBitSize << MCL_MAX_OP_BIT_SIZE; } cybozu::disable_warning_unused_variable(mode); bool isMinus = fp::strToMpzArray(&bitSize, p, maxBitSize, mp, mstr, base); if (isMinus) throw cybozu::Exception("Op:init:mstr is minus") << mstr; if (mp == 0) throw cybozu::Exception("Op:init:mstr is zero") << mstr; isFullBit = (bitSize % UnitBitSize) == 0; const size_t roundBit = (bitSize + UnitBitSize - 1) & ~(UnitBitSize - 1); primeMode = PM_GENERIC; #if defined(MCL_USE_LLVM) || defined(MCL_USE_XBYAK) if ((mode == FP_AUTO || mode == FP_LLVM || mode == FP_XBYAK) && mp == mpz_class("0xfffffffffffffffffffffffffffffffeffffffffffffffff")) { primeMode = PM_NICT_P192; isMont = false; isFastMod = true; } if ((mode == FP_AUTO || mode == FP_LLVM || mode == FP_XBYAK) && mp == mpz_class("0x1ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")) { primeMode = PM_NICT_P521; isMont = false; isFastMod = true; } #endif switch (roundBit) { case 64: SET_OP(64); SET_OP_DBL_LLVM(64, 128); break; case 128: SET_OP(128); SET_OP_DBL_LLVM(128, 256); break; case 192: SET_OP(192); SET_OP_DBL_LLVM(192, 384); break; case 256: SET_OP(256); SET_OP_DBL_LLVM(256, 512); break; case 320: SET_OP(320); break; case 384: SET_OP(384); break; case 448: SET_OP(448); break; case 512: SET_OP(512); // QQQ : need refactor for large prime #if MCL_MAX_OP_BIT_SIZE == 768 SET_OP_DBL_LLVM(512, 1024); #endif break; #if CYBOZU_OS_BIT == 64 case 576: SET_OP(576); #if MCL_MAX_OP_BIT_SIZE == 768 SET_OP_DBL_LLVM(576, 1152); #endif break; #if MCL_MAX_OP_BIT_SIZE == 768 case 640: SET_OP(640); SET_OP_DBL_LLVM(640, 1280); break; case 704: SET_OP(704); SET_OP_DBL_LLVM(704, 1408); break; case 768: SET_OP(768); SET_OP_DBL_LLVM(768, 1536); break; #endif #else case 32: SET_OP(32); SET_OP_DBL_LLVM(32, 64); break; case 96: SET_OP(96); SET_OP_DBL_LLVM(96, 192); break; case 160: SET_OP(160); SET_OP_DBL_LLVM(160, 320); break; case 224: SET_OP(224); SET_OP_DBL_LLVM(224, 448); break; case 288: SET_OP(288); break; case 352: SET_OP(352); break; case 416: SET_OP(416); break; case 480: SET_OP(480); break; case 544: SET_OP(544); break; #endif default: throw cybozu::Exception("Op::init:not:support") << mstr; } #ifdef MCL_USE_LLVM if (primeMode == PM_NICT_P192) { fp_mul = &mcl_fp_mul_NIST_P192; fp_sqr = &mcl_fp_sqr_NIST_P192; fpDbl_mod = &mcl_fpDbl_mod_NIST_P192; } if (primeMode == PM_NICT_P521) { fpDbl_mod = &mcl_fpDbl_mod_NIST_P521; } #endif fp::initForMont(*this, p, mode); sq.set(mp); } void arrayToStr(std::string& str, const Unit *x, size_t n, int base, bool withPrefix) { switch (base) { case 0: case 10: { mpz_class t; gmp::setArray(t, x, n); gmp::getStr(str, t, 10); } return; case 16: mcl::fp::toStr16(str, x, n, withPrefix); return; case 2: mcl::fp::toStr2(str, x, n, withPrefix); return; default: throw cybozu::Exception("fp:arrayToStr:bad base") << base; } } void copyAndMask(Unit *y, const void *x, size_t xByteSize, const Op& op, bool doMask) { const size_t fpByteSize = sizeof(Unit) * op.N; if (xByteSize > fpByteSize) { if (!doMask) throw cybozu::Exception("fp:copyAndMask:bad size") << xByteSize << fpByteSize; xByteSize = fpByteSize; } memcpy(y, x, xByteSize); memset((char *)y + xByteSize, 0, fpByteSize - xByteSize); if (!doMask) { if (isGreaterOrEqualArray(y, op.p, op.N)) throw cybozu::Exception("fp:copyAndMask:large x"); return; } maskArray(y, op.N, op.bitSize - 1); assert(isLessArray(y, op.p, op.N)); } static bool isInUint64(uint64_t *pv, const fp::Block& b) { assert(fp::UnitBitSize == 32 || fp::UnitBitSize == 64); const size_t start = 64 / fp::UnitBitSize; for (size_t i = start; i < b.n; i++) { if (b.p[i]) return false; } #if CYBOZU_OS_BIT == 32 *pv = b.p[0] | (uint64_t(b.p[1]) << 32); #else *pv = b.p[0]; #endif return true; } uint64_t getUint64(bool *pb, const fp::Block& b) { uint64_t v; if (isInUint64(&v, b)) { if (pb) *pb = true; return v; } if (!pb) { std::string str; arrayToStr(str, b.p, b.n, 10, false); throw cybozu::Exception("fp::getUint64:large value") << str; } *pb = false; return 0; } #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable : 4146) #endif int64_t getInt64(bool *pb, fp::Block& b, const fp::Op& op) { bool isNegative = false; if (fp::isGreaterArray(b.p, op.half, op.N)) { op.fp_neg(b.v_, b.p, op.p); b.p = b.v_; isNegative = true; } uint64_t v; if (fp::isInUint64(&v, b)) { const uint64_t c = uint64_t(1) << 63; if (isNegative) { if (v <= c) { // include c if (pb) *pb = true; // -1 << 63 if (v == c) return int64_t(-9223372036854775807ll - 1); return int64_t(-v); } } else { if (v < c) { // not include c if (pb) *pb = true; return int64_t(v); } } } if (!pb) { std::string str; arrayToStr(str, b.p, b.n, 10, false); throw cybozu::Exception("fp::getInt64:large value") << str << isNegative; } *pb = false; return 0; } #ifdef _WIN32 #pragma warning(pop) #endif void Op::initFp2(int _xi_a) { this->xi_a = _xi_a; // if (N * UnitBitSize != 256) throw cybozu::Exception("Op2:init:not support size") << N; } } } // mcl::fp