diff options
author | MITSUNARI Shigeo <herumi@nifty.com> | 2018-11-13 13:59:20 +0800 |
---|---|---|
committer | MITSUNARI Shigeo <herumi@nifty.com> | 2018-11-13 13:59:20 +0800 |
commit | ad282ed284a34529694b0a9cb24535e56a673a40 (patch) | |
tree | 4b12fa71b13aa43271f8e8de25ea839e4c725ac8 | |
parent | b087e6f1f4b8018e9b0f05a21fc261a5cf2f0f58 (diff) | |
download | tangerine-mcl-ad282ed284a34529694b0a9cb24535e56a673a40.tar.gz tangerine-mcl-ad282ed284a34529694b0a9cb24535e56a673a40.tar.zst tangerine-mcl-ad282ed284a34529694b0a9cb24535e56a673a40.zip |
refactor fp_generator and change the argument of Fp2::init()
-rw-r--r-- | include/mcl/bn.hpp | 8 | ||||
-rw-r--r-- | include/mcl/fp.hpp | 16 | ||||
-rw-r--r-- | include/mcl/fp_tower.hpp | 16 | ||||
-rw-r--r-- | src/fp_generator.hpp | 89 |
4 files changed, 50 insertions, 79 deletions
diff --git a/include/mcl/bn.hpp b/include/mcl/bn.hpp index 7f7a689..7862010 100644 --- a/include/mcl/bn.hpp +++ b/include/mcl/bn.hpp @@ -1043,12 +1043,12 @@ struct Param { assert((p % 6) == 1); r = local::evalPoly(z, rCoff); } - Fp::init(pb, p, mode); - if (!*pb) return; Fr::init(pb, r, mode); if (!*pb) return; - Fp2::init(cp.xi_a); - Fp2 xi(cp.xi_a, 1); + Fp::init(pb, cp.xi_a, p, mode); + if (!*pb) return; + Fp2::init(); + const Fp2 xi(cp.xi_a, 1); g2 = Fp2::get_gTbl()[0]; g3 = Fp2::get_gTbl()[3]; if (cp.isMtype) { diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp index 9fd6d74..a0a7e85 100644 --- a/include/mcl/fp.hpp +++ b/include/mcl/fp.hpp @@ -120,10 +120,14 @@ public: } printf("\n"); } - static inline void init(bool *pb, const mpz_class& _p, fp::Mode mode = fp::FP_AUTO, int xi_a = 0) + /* + xi_a is used for Fp2::mul_xi(), where xi = xi_a + i and i^2 = -1 + if xi_a = 0 then asm functions for Fp2 are not generated. + */ + static inline void init(bool *pb, int xi_a, const mpz_class& p, fp::Mode mode = fp::FP_AUTO) { assert(maxBitSize <= MCL_MAX_BIT_SIZE); - *pb = op_.init(_p, maxBitSize, xi_a, mode); + *pb = op_.init(p, maxBitSize, xi_a, mode); if (!*pb) return; { // set oneRep FpT& one = *reinterpret_cast<FpT*>(op_.oneRep); @@ -151,12 +155,16 @@ public: #endif *pb = true; } - static inline void init(bool *pb, const char *mstr, fp::Mode mode = fp::FP_AUTO, int xi_a = 0) + static inline void init(bool *pb, const mpz_class& p, fp::Mode mode = fp::FP_AUTO) + { + init(pb, 0, p, mode); + } + static inline void init(bool *pb, const char *mstr, fp::Mode mode = fp::FP_AUTO) { mpz_class p; gmp::setStr(pb, p, mstr); if (!*pb) return; - init(pb, p, mode, xi_a); + init(pb, p, mode); } static inline size_t getModulo(char *buf, size_t bufSize) { diff --git a/include/mcl/fp_tower.hpp b/include/mcl/fp_tower.hpp index 8267caa..9a83442 100644 --- a/include/mcl/fp_tower.hpp +++ b/include/mcl/fp_tower.hpp @@ -202,7 +202,6 @@ class Fp2T : public fp::Serializable<Fp2T<_Fp>, typedef 
fp::Unit Unit; typedef FpDblT<Fp> FpDbl; typedef Fp2DblT<Fp> Fp2Dbl; - static uint32_t xi_a_; static const size_t gN = 5; /* g = xi^((p - 1) / 6) @@ -373,12 +372,12 @@ public: } } - static uint32_t get_xi_a() { return xi_a_; } - static void init(uint32_t xi_a) + static uint32_t get_xi_a() { return Fp::getOp().xi_a; } + static void init() { // assert(Fp::maxSize <= 256); - xi_a_ = xi_a; mcl::fp::Op& op = Fp::op_; + assert(op.xi_a); add = (void (*)(Fp2T& z, const Fp2T& x, const Fp2T& y))op.fp2_addA_; if (add == 0) add = fp2_addC; sub = (void (*)(Fp2T& z, const Fp2T& x, const Fp2T& y))op.fp2_subA_; @@ -402,7 +401,7 @@ public: sqr = (void (*)(Fp2T& y, const Fp2T& x))op.fp2_sqrA_; if (sqr == 0) sqr = fp2_sqrC; op.fp2_inv = fp2_invW; - if (xi_a == 1) { + if (op.xi_a == 1) { /* current fp_generator.hpp generates mul_xi for xi_a = 1 */ @@ -417,7 +416,7 @@ public: FpDblT<Fp>::init(); Fp2DblT<Fp>::init(); // call init before Fp2::pow because FpDbl is used in Fp2T - const Fp2T xi(xi_a, 1); + const Fp2T xi(op.xi_a, 1); const mpz_class& p = Fp::getOp().mp; Fp2T::pow(g[0], xi, (p - 1) / 6); // g = xi^((p-1)/6) for (size_t i = 1; i < gN; i++) { @@ -579,9 +578,9 @@ private: const Fp& a = x.a; const Fp& b = x.b; Fp t; - Fp::mulUnit(t, a, xi_a_); + Fp::mulUnit(t, a, Fp::getOp().xi_a); t -= b; - Fp::mulUnit(y.b, b, xi_a_); + Fp::mulUnit(y.b, b, Fp::getOp().xi_a); y.b += a; y.a = t; } @@ -765,7 +764,6 @@ struct Fp2DblT { template<class Fp> void (*Fp2DblT<Fp>::mulPre)(Fp2DblT&, const Fp2T<Fp>&, const Fp2T<Fp>&); template<class Fp> void (*Fp2DblT<Fp>::sqrPre)(Fp2DblT&, const Fp2T<Fp>&); -template<class Fp> uint32_t Fp2T<Fp>::xi_a_; template<class Fp> Fp2T<Fp> Fp2T<Fp>::g[Fp2T<Fp>::gN]; template<class Fp> Fp2T<Fp> Fp2T<Fp>::g2[Fp2T<Fp>::gN]; template<class Fp> Fp2T<Fp> Fp2T<Fp>::g3[Fp2T<Fp>::gN]; diff --git a/src/fp_generator.hpp b/src/fp_generator.hpp index 58140e9..dfb73d6 100644 --- a/src/fp_generator.hpp +++ b/src/fp_generator.hpp @@ -203,22 +203,6 @@ struct FpGenerator : 
Xbyak::CodeGenerator { int pn_; int FpByte_; bool isFullBit_; - // add/sub without carry. return true if overflow - typedef bool (*bool3op)(uint64_t*, const uint64_t*, const uint64_t*); - - // add/sub with mod -// typedef void (*void3op)(uint64_t*, const uint64_t*, const uint64_t*); - - // mul without carry. return top of z - typedef uint64_t (*uint3opI)(uint64_t*, const uint64_t*, uint64_t); - - // neg - typedef void (*void2op)(uint64_t*, const uint64_t*); - - // preInv - typedef int (*int2op)(uint64_t*, const uint64_t*); - void4u mul_; -// uint3opI mulUnit_; /* @param op [in] ; use op.p, op.N, op.isFullBit @@ -253,8 +237,6 @@ struct FpGenerator : Xbyak::CodeGenerator { , rp_(0) , pn_(0) , FpByte_(0) - , mul_(0) -// , mulUnit_(0) { useMulx_ = cpu_.has(Xbyak::util::Cpu::tBMI2); useAdx_ = cpu_.has(Xbyak::util::Cpu::tADX); @@ -264,6 +246,7 @@ struct FpGenerator : Xbyak::CodeGenerator { reset(); // reset jit code for reuse setProtectModeRW(); // read/write memory init_inner(op); + printf("code size=%d\n", (int)getSize()); setProtectModeRE(); // set read/exec memory } private: @@ -271,10 +254,6 @@ private: { op_ = &op; if (!cpu_.has(Xbyak::util::Cpu::tAVX)) return; - /* - first 4096-byte is data area - remain is code area - */ L(pL_); p_ = reinterpret_cast<const uint64_t*>(getCurr()); for (size_t i = 0; i < op.N; i++) { @@ -285,9 +264,7 @@ private: FpByte_ = int(op.maxN * sizeof(uint64_t)); isFullBit_ = op.isFullBit; // printf("p=%p, pn_=%d, isFullBit_=%d\n", p_, pn_, isFullBit_); - // code from here - setSize(4096); - assert((getCurr<size_t>() & 4095) == 0); + op.fp_addPre = gen_addSubPre(true, pn_); op.fp_subPre = gen_addSubPre(false, pn_); op.fp_subA_ = gen_fp_sub(); @@ -297,42 +274,30 @@ private: op.fp_negA_ = gen_fp_neg(); - void* func = 0; - // setup fp_tower - op.fp2_mulNF = 0; - func = gen_fpDbl_add(); - if (func) op.fpDbl_addA_ = reinterpret_cast<void3u>(func); - func = gen_fpDbl_sub(); - if (func) op.fpDbl_subA_ = reinterpret_cast<void3u>(func); + 
op.fpDbl_addA_ = gen_fpDbl_add(); + op.fpDbl_subA_ = gen_fpDbl_sub(); op.fpDbl_addPre = gen_addSubPre(true, pn_ * 2); op.fpDbl_subPre = gen_addSubPre(false, pn_ * 2); - func = gen_fpDbl_mulPre(); - if (func) op.fpDbl_mulPreA_ = reinterpret_cast<void3u>(func); - - func = gen_fpDbl_mod(op); - if (func) op.fpDbl_modA_ = reinterpret_cast<void2u>(func); - - func = gen_fpDbl_sqrPre(op); - if (func) op.fpDbl_sqrPreA_ = reinterpret_cast<void2u>(func); + op.fpDbl_mulPreA_ = gen_fpDbl_mulPre(); + op.fpDbl_sqrPreA_ = gen_fpDbl_sqrPre(); + op.fpDbl_modA_ = gen_fpDbl_mod(op); - func = gen_mul(); - if (func) { - op.fp_mul = reinterpret_cast<void4u>(func); // used in toMont/fromMont - op.fp_mulA_ = reinterpret_cast<void3u>(func); - } - func = gen_sqr(); - if (func) { - op.fp_sqrA_ = reinterpret_cast<void2u>(func); + op.fp_mulA_ = gen_mul(); + if (op.fp_mulA_) { + op.fp_mul = reinterpret_cast<void4u>(op.fp_mulA_); // used in toMont/fromMont } + op.fp_sqrA_ = gen_sqr(); if (op.primeMode != PM_NIST_P192 && op.N <= 4) { // support general op.N but not fast for op.N > 4 align(16); op.fp_preInv = getCurr<int2u>(); gen_preInv(); } + if (op.xi_a == 0) return; // Fp2 is not used op.fp2_addA_ = gen_fp2_add(); op.fp2_subA_ = gen_fp2_sub(); op.fp2_negA_ = gen_fp2_neg(); + op.fp2_mulNF = 0; op.fp2Dbl_mulPreA_ = gen_fp2Dbl_mulPre(); op.fp2Dbl_sqrPreA_ = gen_fp2Dbl_sqrPre(); op.fp2_mulA_ = gen_fp2_mul(); @@ -668,10 +633,10 @@ private: outLocalLabel(); return func; } - void* gen_fpDbl_add() + void3u gen_fpDbl_add() { align(16); - void* func = getCurr<void*>(); + void3u func = getCurr<void3u>(); if (pn_ <= 4) { int tn = pn_ * 2 + (isFullBit_ ? 
1 : 0); StackFrame sf(this, 3, tn); @@ -696,10 +661,10 @@ private: } return 0; } - void* gen_fpDbl_sub() + void3u gen_fpDbl_sub() { align(16); - void* func = getCurr<void*>(); + void3u func = getCurr<void3u>(); if (pn_ <= 4) { int tn = pn_ * 2; StackFrame sf(this, 3, tn); @@ -800,10 +765,10 @@ private: mov(ptr [pz + (pn_ - 1) * 8], *t0); return func; } - void* gen_mul() + void3u gen_mul() { align(16); - void* func = getCurr<void*>(); + void3u func = getCurr<void3u>(); if (op_->primeMode == PM_NIST_P192) { StackFrame sf(this, 3, 10 | UseRDX, 8 * 6); mulPre3(rsp, sf.p[1], sf.p[2], sf.t); @@ -1113,10 +1078,10 @@ private: vmovq(z, xm0); store_mr(z, Pack(t10, t9, t8, t4)); } - void* gen_fpDbl_mod(const fp::Op& op) + void2u gen_fpDbl_mod(const fp::Op& op) { align(16); - void* func = getCurr<void*>(); + void2u func = getCurr<void2u>(); if (op.primeMode == PM_NIST_P192) { StackFrame sf(this, 2, 6 | UseRDX); fpDbl_mod_NIST_P192(sf.p[0], sf.p[1], sf.t); @@ -1159,10 +1124,10 @@ private: } return 0; } - void* gen_sqr() + void2u gen_sqr() { align(16); - void* func = getCurr<void*>(); + void2u func = getCurr<void2u>(); if (op_->primeMode == PM_NIST_P192) { StackFrame sf(this, 3, 10 | UseRDX, 6 * 8); Pack t = sf.t; @@ -2267,10 +2232,10 @@ private: vmovq(z, xm0); store_mr(z, zp); } - void* gen_fpDbl_sqrPre(const fp::Op&/* op */) + void2u gen_fpDbl_sqrPre() { align(16); - void* func = getCurr<void*>(); + void2u func = getCurr<void2u>(); if (pn_ == 2 && useMulx_) { StackFrame sf(this, 2, 7 | UseRDX); sqrPre2(sf.p[0], sf.p[1], sf.t); @@ -2308,10 +2273,10 @@ private: return func; #endif } - void* gen_fpDbl_mulPre() + void3u gen_fpDbl_mulPre() { align(16); - void* func = getCurr<void*>(); + void3u func = getCurr<void3u>(); if (pn_ == 2 && useMulx_) { StackFrame sf(this, 3, 5 | UseRDX); mulPre2(sf.p[0], sf.p[1], sf.p[2], sf.t); |