aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMITSUNARI Shigeo <herumi@nifty.com>2018-11-13 13:59:20 +0800
committerMITSUNARI Shigeo <herumi@nifty.com>2018-11-13 13:59:20 +0800
commitad282ed284a34529694b0a9cb24535e56a673a40 (patch)
tree4b12fa71b13aa43271f8e8de25ea839e4c725ac8
parentb087e6f1f4b8018e9b0f05a21fc261a5cf2f0f58 (diff)
downloadtangerine-mcl-ad282ed284a34529694b0a9cb24535e56a673a40.tar.gz
tangerine-mcl-ad282ed284a34529694b0a9cb24535e56a673a40.tar.zst
tangerine-mcl-ad282ed284a34529694b0a9cb24535e56a673a40.zip
refactor fp_generator and the argument of Fp2::init() is changed
-rw-r--r--include/mcl/bn.hpp8
-rw-r--r--include/mcl/fp.hpp16
-rw-r--r--include/mcl/fp_tower.hpp16
-rw-r--r--src/fp_generator.hpp89
4 files changed, 50 insertions, 79 deletions
diff --git a/include/mcl/bn.hpp b/include/mcl/bn.hpp
index 7f7a689..7862010 100644
--- a/include/mcl/bn.hpp
+++ b/include/mcl/bn.hpp
@@ -1043,12 +1043,12 @@ struct Param {
assert((p % 6) == 1);
r = local::evalPoly(z, rCoff);
}
- Fp::init(pb, p, mode);
- if (!*pb) return;
Fr::init(pb, r, mode);
if (!*pb) return;
- Fp2::init(cp.xi_a);
- Fp2 xi(cp.xi_a, 1);
+ Fp::init(pb, cp.xi_a, p, mode);
+ if (!*pb) return;
+ Fp2::init();
+ const Fp2 xi(cp.xi_a, 1);
g2 = Fp2::get_gTbl()[0];
g3 = Fp2::get_gTbl()[3];
if (cp.isMtype) {
diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp
index 9fd6d74..a0a7e85 100644
--- a/include/mcl/fp.hpp
+++ b/include/mcl/fp.hpp
@@ -120,10 +120,14 @@ public:
}
printf("\n");
}
- static inline void init(bool *pb, const mpz_class& _p, fp::Mode mode = fp::FP_AUTO, int xi_a = 0)
+ /*
+ xi_a is used for Fp2::mul_xi(), where xi = xi_a + i and i^2 = -1
+ if xi_a = 0 then asm functions for Fp2 are not generated.
+ */
+ static inline void init(bool *pb, int xi_a, const mpz_class& p, fp::Mode mode = fp::FP_AUTO)
{
assert(maxBitSize <= MCL_MAX_BIT_SIZE);
- *pb = op_.init(_p, maxBitSize, xi_a, mode);
+ *pb = op_.init(p, maxBitSize, xi_a, mode);
if (!*pb) return;
{ // set oneRep
FpT& one = *reinterpret_cast<FpT*>(op_.oneRep);
@@ -151,12 +155,16 @@ public:
#endif
*pb = true;
}
- static inline void init(bool *pb, const char *mstr, fp::Mode mode = fp::FP_AUTO, int xi_a = 0)
+ static inline void init(bool *pb, const mpz_class& p, fp::Mode mode = fp::FP_AUTO)
+ {
+ init(pb, 0, p, mode);
+ }
+ static inline void init(bool *pb, const char *mstr, fp::Mode mode = fp::FP_AUTO)
{
mpz_class p;
gmp::setStr(pb, p, mstr);
if (!*pb) return;
- init(pb, p, mode, xi_a);
+ init(pb, p, mode);
}
static inline size_t getModulo(char *buf, size_t bufSize)
{
diff --git a/include/mcl/fp_tower.hpp b/include/mcl/fp_tower.hpp
index 8267caa..9a83442 100644
--- a/include/mcl/fp_tower.hpp
+++ b/include/mcl/fp_tower.hpp
@@ -202,7 +202,6 @@ class Fp2T : public fp::Serializable<Fp2T<_Fp>,
typedef fp::Unit Unit;
typedef FpDblT<Fp> FpDbl;
typedef Fp2DblT<Fp> Fp2Dbl;
- static uint32_t xi_a_;
static const size_t gN = 5;
/*
g = xi^((p - 1) / 6)
@@ -373,12 +372,12 @@ public:
}
}
- static uint32_t get_xi_a() { return xi_a_; }
- static void init(uint32_t xi_a)
+ static uint32_t get_xi_a() { return Fp::getOp().xi_a; }
+ static void init()
{
// assert(Fp::maxSize <= 256);
- xi_a_ = xi_a;
mcl::fp::Op& op = Fp::op_;
+ assert(op.xi_a);
add = (void (*)(Fp2T& z, const Fp2T& x, const Fp2T& y))op.fp2_addA_;
if (add == 0) add = fp2_addC;
sub = (void (*)(Fp2T& z, const Fp2T& x, const Fp2T& y))op.fp2_subA_;
@@ -402,7 +401,7 @@ public:
sqr = (void (*)(Fp2T& y, const Fp2T& x))op.fp2_sqrA_;
if (sqr == 0) sqr = fp2_sqrC;
op.fp2_inv = fp2_invW;
- if (xi_a == 1) {
+ if (op.xi_a == 1) {
/*
current fp_generator.hpp generates mul_xi for xi_a = 1
*/
@@ -417,7 +416,7 @@ public:
FpDblT<Fp>::init();
Fp2DblT<Fp>::init();
// call init before Fp2::pow because FpDbl is used in Fp2T
- const Fp2T xi(xi_a, 1);
+ const Fp2T xi(op.xi_a, 1);
const mpz_class& p = Fp::getOp().mp;
Fp2T::pow(g[0], xi, (p - 1) / 6); // g = xi^((p-1)/6)
for (size_t i = 1; i < gN; i++) {
@@ -579,9 +578,9 @@ private:
const Fp& a = x.a;
const Fp& b = x.b;
Fp t;
- Fp::mulUnit(t, a, xi_a_);
+ Fp::mulUnit(t, a, Fp::getOp().xi_a);
t -= b;
- Fp::mulUnit(y.b, b, xi_a_);
+ Fp::mulUnit(y.b, b, Fp::getOp().xi_a);
y.b += a;
y.a = t;
}
@@ -765,7 +764,6 @@ struct Fp2DblT {
template<class Fp> void (*Fp2DblT<Fp>::mulPre)(Fp2DblT&, const Fp2T<Fp>&, const Fp2T<Fp>&);
template<class Fp> void (*Fp2DblT<Fp>::sqrPre)(Fp2DblT&, const Fp2T<Fp>&);
-template<class Fp> uint32_t Fp2T<Fp>::xi_a_;
template<class Fp> Fp2T<Fp> Fp2T<Fp>::g[Fp2T<Fp>::gN];
template<class Fp> Fp2T<Fp> Fp2T<Fp>::g2[Fp2T<Fp>::gN];
template<class Fp> Fp2T<Fp> Fp2T<Fp>::g3[Fp2T<Fp>::gN];
diff --git a/src/fp_generator.hpp b/src/fp_generator.hpp
index 58140e9..dfb73d6 100644
--- a/src/fp_generator.hpp
+++ b/src/fp_generator.hpp
@@ -203,22 +203,6 @@ struct FpGenerator : Xbyak::CodeGenerator {
int pn_;
int FpByte_;
bool isFullBit_;
- // add/sub without carry. return true if overflow
- typedef bool (*bool3op)(uint64_t*, const uint64_t*, const uint64_t*);
-
- // add/sub with mod
-// typedef void (*void3op)(uint64_t*, const uint64_t*, const uint64_t*);
-
- // mul without carry. return top of z
- typedef uint64_t (*uint3opI)(uint64_t*, const uint64_t*, uint64_t);
-
- // neg
- typedef void (*void2op)(uint64_t*, const uint64_t*);
-
- // preInv
- typedef int (*int2op)(uint64_t*, const uint64_t*);
- void4u mul_;
-// uint3opI mulUnit_;
/*
@param op [in] ; use op.p, op.N, op.isFullBit
@@ -253,8 +237,6 @@ struct FpGenerator : Xbyak::CodeGenerator {
, rp_(0)
, pn_(0)
, FpByte_(0)
- , mul_(0)
-// , mulUnit_(0)
{
useMulx_ = cpu_.has(Xbyak::util::Cpu::tBMI2);
useAdx_ = cpu_.has(Xbyak::util::Cpu::tADX);
@@ -264,6 +246,7 @@ struct FpGenerator : Xbyak::CodeGenerator {
reset(); // reset jit code for reuse
setProtectModeRW(); // read/write memory
init_inner(op);
+ printf("code size=%d\n", (int)getSize());
setProtectModeRE(); // set read/exec memory
}
private:
@@ -271,10 +254,6 @@ private:
{
op_ = &op;
if (!cpu_.has(Xbyak::util::Cpu::tAVX)) return;
- /*
- first 4096-byte is data area
- remain is code area
- */
L(pL_);
p_ = reinterpret_cast<const uint64_t*>(getCurr());
for (size_t i = 0; i < op.N; i++) {
@@ -285,9 +264,7 @@ private:
FpByte_ = int(op.maxN * sizeof(uint64_t));
isFullBit_ = op.isFullBit;
// printf("p=%p, pn_=%d, isFullBit_=%d\n", p_, pn_, isFullBit_);
- // code from here
- setSize(4096);
- assert((getCurr<size_t>() & 4095) == 0);
+
op.fp_addPre = gen_addSubPre(true, pn_);
op.fp_subPre = gen_addSubPre(false, pn_);
op.fp_subA_ = gen_fp_sub();
@@ -297,42 +274,30 @@ private:
op.fp_negA_ = gen_fp_neg();
- void* func = 0;
- // setup fp_tower
- op.fp2_mulNF = 0;
- func = gen_fpDbl_add();
- if (func) op.fpDbl_addA_ = reinterpret_cast<void3u>(func);
- func = gen_fpDbl_sub();
- if (func) op.fpDbl_subA_ = reinterpret_cast<void3u>(func);
+ op.fpDbl_addA_ = gen_fpDbl_add();
+ op.fpDbl_subA_ = gen_fpDbl_sub();
op.fpDbl_addPre = gen_addSubPre(true, pn_ * 2);
op.fpDbl_subPre = gen_addSubPre(false, pn_ * 2);
- func = gen_fpDbl_mulPre();
- if (func) op.fpDbl_mulPreA_ = reinterpret_cast<void3u>(func);
-
- func = gen_fpDbl_mod(op);
- if (func) op.fpDbl_modA_ = reinterpret_cast<void2u>(func);
-
- func = gen_fpDbl_sqrPre(op);
- if (func) op.fpDbl_sqrPreA_ = reinterpret_cast<void2u>(func);
+ op.fpDbl_mulPreA_ = gen_fpDbl_mulPre();
+ op.fpDbl_sqrPreA_ = gen_fpDbl_sqrPre();
+ op.fpDbl_modA_ = gen_fpDbl_mod(op);
- func = gen_mul();
- if (func) {
- op.fp_mul = reinterpret_cast<void4u>(func); // used in toMont/fromMont
- op.fp_mulA_ = reinterpret_cast<void3u>(func);
- }
- func = gen_sqr();
- if (func) {
- op.fp_sqrA_ = reinterpret_cast<void2u>(func);
+ op.fp_mulA_ = gen_mul();
+ if (op.fp_mulA_) {
+ op.fp_mul = reinterpret_cast<void4u>(op.fp_mulA_); // used in toMont/fromMont
}
+ op.fp_sqrA_ = gen_sqr();
if (op.primeMode != PM_NIST_P192 && op.N <= 4) { // support general op.N but not fast for op.N > 4
align(16);
op.fp_preInv = getCurr<int2u>();
gen_preInv();
}
+ if (op.xi_a == 0) return; // Fp2 is not used
op.fp2_addA_ = gen_fp2_add();
op.fp2_subA_ = gen_fp2_sub();
op.fp2_negA_ = gen_fp2_neg();
+ op.fp2_mulNF = 0;
op.fp2Dbl_mulPreA_ = gen_fp2Dbl_mulPre();
op.fp2Dbl_sqrPreA_ = gen_fp2Dbl_sqrPre();
op.fp2_mulA_ = gen_fp2_mul();
@@ -668,10 +633,10 @@ private:
outLocalLabel();
return func;
}
- void* gen_fpDbl_add()
+ void3u gen_fpDbl_add()
{
align(16);
- void* func = getCurr<void*>();
+ void3u func = getCurr<void3u>();
if (pn_ <= 4) {
int tn = pn_ * 2 + (isFullBit_ ? 1 : 0);
StackFrame sf(this, 3, tn);
@@ -696,10 +661,10 @@ private:
}
return 0;
}
- void* gen_fpDbl_sub()
+ void3u gen_fpDbl_sub()
{
align(16);
- void* func = getCurr<void*>();
+ void3u func = getCurr<void3u>();
if (pn_ <= 4) {
int tn = pn_ * 2;
StackFrame sf(this, 3, tn);
@@ -800,10 +765,10 @@ private:
mov(ptr [pz + (pn_ - 1) * 8], *t0);
return func;
}
- void* gen_mul()
+ void3u gen_mul()
{
align(16);
- void* func = getCurr<void*>();
+ void3u func = getCurr<void3u>();
if (op_->primeMode == PM_NIST_P192) {
StackFrame sf(this, 3, 10 | UseRDX, 8 * 6);
mulPre3(rsp, sf.p[1], sf.p[2], sf.t);
@@ -1113,10 +1078,10 @@ private:
vmovq(z, xm0);
store_mr(z, Pack(t10, t9, t8, t4));
}
- void* gen_fpDbl_mod(const fp::Op& op)
+ void2u gen_fpDbl_mod(const fp::Op& op)
{
align(16);
- void* func = getCurr<void*>();
+ void2u func = getCurr<void2u>();
if (op.primeMode == PM_NIST_P192) {
StackFrame sf(this, 2, 6 | UseRDX);
fpDbl_mod_NIST_P192(sf.p[0], sf.p[1], sf.t);
@@ -1159,10 +1124,10 @@ private:
}
return 0;
}
- void* gen_sqr()
+ void2u gen_sqr()
{
align(16);
- void* func = getCurr<void*>();
+ void2u func = getCurr<void2u>();
if (op_->primeMode == PM_NIST_P192) {
StackFrame sf(this, 3, 10 | UseRDX, 6 * 8);
Pack t = sf.t;
@@ -2267,10 +2232,10 @@ private:
vmovq(z, xm0);
store_mr(z, zp);
}
- void* gen_fpDbl_sqrPre(const fp::Op&/* op */)
+ void2u gen_fpDbl_sqrPre()
{
align(16);
- void* func = getCurr<void*>();
+ void2u func = getCurr<void2u>();
if (pn_ == 2 && useMulx_) {
StackFrame sf(this, 2, 7 | UseRDX);
sqrPre2(sf.p[0], sf.p[1], sf.t);
@@ -2308,10 +2273,10 @@ private:
return func;
#endif
}
- void* gen_fpDbl_mulPre()
+ void3u gen_fpDbl_mulPre()
{
align(16);
- void* func = getCurr<void*>();
+ void3u func = getCurr<void3u>();
if (pn_ == 2 && useMulx_) {
StackFrame sf(this, 3, 5 | UseRDX);
mulPre2(sf.p[0], sf.p[1], sf.p[2], sf.t);