diff options
Diffstat (limited to 'include/mcl/gmp_util.hpp')
-rw-r--r-- | include/mcl/gmp_util.hpp | 162 |
1 files changed, 120 insertions, 42 deletions
diff --git a/include/mcl/gmp_util.hpp b/include/mcl/gmp_util.hpp index 399041c..bb461af 100644 --- a/include/mcl/gmp_util.hpp +++ b/include/mcl/gmp_util.hpp @@ -64,12 +64,13 @@ typedef mpz_class ImplType; // z = [buf[n-1]:..:buf[1]:buf[0]] // eg. buf[] = {0x12345678, 0xaabbccdd}; => z = 0xaabbccdd12345678; template<class T> -void setArray(mpz_class& z, const T *buf, size_t n) +void setArray(bool *pb, mpz_class& z, const T *buf, size_t n) { #ifdef MCL_USE_VINT - z.setArray(buf, n); + z.setArray(pb, buf, n); #else mpz_import(z.get_mpz_t(), n, -1, sizeof(*buf), 0, 0, buf); + *pb = true; #endif } /* @@ -78,44 +79,43 @@ void setArray(mpz_class& z, const T *buf, size_t n) */ #ifndef MCL_USE_VINT template<class T> -void getArray(T *buf, size_t maxSize, const mpz_srcptr x) +bool getArray_(T *buf, size_t maxSize, const mpz_srcptr x) { const size_t bufByteSize = sizeof(T) * maxSize; const int xn = x->_mp_size; - if (xn < 0) throw cybozu::Exception("gmp:getArray:x is negative"); + if (xn < 0) return false; size_t xByteSize = sizeof(*x->_mp_d) * xn; - if (xByteSize > bufByteSize) throw cybozu::Exception("gmp:getArray:too small") << xn << maxSize; + if (xByteSize > bufByteSize) return false; memcpy(buf, x->_mp_d, xByteSize); memset((char*)buf + xByteSize, 0, bufByteSize - xByteSize); + return true; } #endif template<class T> -void getArray(T *buf, size_t maxSize, const mpz_class& x) +void getArray(bool *pb, T *buf, size_t maxSize, const mpz_class& x) { #ifdef MCL_USE_VINT - x.getArray(buf, maxSize); + x.getArray(pb, buf, maxSize); #else - getArray(buf, maxSize, x.get_mpz_t()); + *pb = getArray_(buf, maxSize, x.get_mpz_t()); #endif } inline void set(mpz_class& z, uint64_t x) { - setArray(z, &x, 1); + bool b; + setArray(&b, z, &x, 1); + assert(b); + (void)b; } -inline void setStr(bool *pb, mpz_class& z, const char *str, size_t strSize, int base = 0) +inline void setStr(bool *pb, mpz_class& z, const char *str, int base = 0) { #ifdef MCL_USE_VINT - z.setStr(pb, str, strSize, base); + z.setStr(pb, str, base); #else - *pb = z.set_str(std::string(str, strSize), base) == 0; + *pb = z.set_str(str, base) == 0; #endif } -inline void setStr(mpz_class& z, const std::string& str, int base = 0) -{ - bool b; - setStr(&b, z, str.c_str(), str.size(), base); - if (!b) throw cybozu::Exception("gmp:setStr"); -} + /* set buf with string terminated by '\0' return strlen(buf) if success else 0 @@ -125,18 +125,19 @@ inline size_t getStr(char *buf, size_t bufSize, const mpz_class& z, int base = 1 #ifdef MCL_USE_VINT return z.getStr(buf, bufSize, base); #else - std::string str = z.get_str(base); - if (str.size() < bufSize) { - memcpy(buf, str.c_str(), str.size() + 1); - return str.size(); - } - return 0; + __gmp_alloc_cstring tmp(mpz_get_str(0, base, z.get_mpz_t())); + size_t n = strlen(tmp.str); + if (n + 1 > bufSize) return 0; + memcpy(buf, tmp.str, n + 1); + return n; #endif } + +#ifndef CYBOZU_DONT_USE_STRING inline void getStr(std::string& str, const mpz_class& z, int base = 10) { #ifdef MCL_USE_VINT - str = z.getStr(base); + z.getStr(str, base); #else str = z.get_str(base); #endif @@ -144,9 +145,11 @@ inline void getStr(std::string& str, const mpz_class& z, int base = 10) inline std::string getStr(const mpz_class& z, int base = 10) { std::string s; - getStr(s, z, base); + gmp::getStr(s, z, base); return s; } +#endif + inline void add(mpz_class& z, const mpz_class& x, const mpz_class& y) { #ifdef MCL_USE_VINT @@ -365,11 +368,12 @@ inline int legendre(const mpz_class& a, const mpz_class& p) return mpz_legendre(a.get_mpz_t(), p.get_mpz_t()); #endif } -inline bool isPrime(const mpz_class& x) +inline bool isPrime(bool *pb, const mpz_class& x) { #ifdef MCL_USE_VINT - return x.isPrime(32); + return x.isPrime(pb, 32); #else + *pb = true; return mpz_probab_prime_p(x.get_mpz_t(), 32) != 0; #endif } @@ -438,7 +442,7 @@ inline mpz_class abs(const mpz_class& x) #endif } -inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()) +inline void getRand(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()) { if (rg.isZero()) rg = fp::RandGen::get(); assert(bitSize > 1); @@ -447,7 +451,7 @@ inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen() uint32_t buf[128]; assert(n <= CYBOZU_NUM_OF_ARRAY(buf)); if (n > CYBOZU_NUM_OF_ARRAY(buf)) { - z = 0; + *pb = false; return; } rg.read(buf, n * sizeof(buf[0])); @@ -459,22 +463,26 @@ inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen() v |= 1U << (rem - 1); } buf[n - 1] = v; - setArray(z, buf, n); + setArray(pb, z, buf, n); } -inline void getRandPrime(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false) +inline void getRandPrime(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false) { if (rg.isZero()) rg = fp::RandGen::get(); assert(bitSize > 2); - do { - getRand(z, bitSize, rg); + for (;;) { + getRand(pb, z, bitSize, rg); + if (!*pb) return; if (setSecondBit) { z |= mpz_class(1) << (bitSize - 2); } if (mustBe3mod4) { z |= 3; } - } while (!(isPrime(z))); + bool ret = isPrime(pb, z); + if (!*pb) return; + if (ret) return; + } } inline mpz_class getQuadraticNonResidue(const mpz_class& p) { @@ -566,6 +574,49 @@ bool getNAF(Vec& v, const mpz_class& x) } } +#ifndef CYBOZU_DONT_USE_EXCEPTION +inline void setStr(mpz_class& z, const std::string& str, int base = 0) +{ + bool b; + setStr(&b, z, str.c_str(), base); + if (!b) throw cybozu::Exception("gmp:setStr"); +} +template<class T> +void setArray(mpz_class& z, const T *buf, size_t n) +{ + bool b; + setArray(&b, z, buf, n); + if (!b) throw cybozu::Exception("gmp:setArray"); +} +template<class T> +void getArray(T *buf, size_t maxSize, const mpz_class& x) +{ + bool b; + getArray(&b, buf, maxSize, x); + if (!b) throw cybozu::Exception("gmp:getArray"); +} +inline bool isPrime(const mpz_class& x) +{ + bool b; + bool ret = isPrime(&b, x); + if (!b) throw cybozu::Exception("gmp:isPrime"); + return ret; +} +inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()) +{ + bool b; + getRand(&b, z, bitSize, rg); + if (!b) throw cybozu::Exception("gmp:getRand"); +} +inline void getRandPrime(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false) +{ + bool b; + getRandPrime(&b, z, bitSize, rg, setSecondBit, mustBe3mod4); + if (!b) throw cybozu::Exception("gmp:getRandPrime"); +} +#endif + + } // mcl::gmp /* @@ -591,12 +642,19 @@ public: s = 0; q_add_1_div_2 = 0; } - void set(const mpz_class& _p) + void set(bool *pb, const mpz_class& _p) { p = _p; - if (p <= 2) throw cybozu::Exception("SquareRoot:bad p") << p; - isPrime = gmp::isPrime(p); - if (!isPrime) return; // don't throw until get() is called + if (p <= 2) { + *pb = false; + return; + } + isPrime = gmp::isPrime(pb, p); + if (!*pb) return; + if (!isPrime) { + *pb = false; + return; + } g = gmp::getQuadraticNonResidue(p); // p - 1 = 2^r q, q is odd r = 0; @@ -607,13 +665,18 @@ public: } gmp::powMod(s, g, q, p); q_add_1_div_2 = (q + 1) / 2; + *pb = true; } /* solve x^2 = a mod p */ - bool get(mpz_class& x, const mpz_class& a) const + bool get(bool *pb, mpz_class& x, const mpz_class& a) const { - if (!isPrime) throw cybozu::Exception("SquareRoot:get:not prime") << p; + if (!isPrime) { + *pb = false; + return false; + } + *pb = true; if (a == 0) { x = 0; return true; @@ -653,7 +716,7 @@ public: template<class Fp> bool get(Fp& x, const Fp& a) const { - if (Fp::getOp().mp != p) throw cybozu::Exception("bad Fp") << Fp::getOp().mp << p; + assert(Fp::getOp().mp == p); if (a == 0) { x = 0; return true; @@ -691,6 +754,21 @@ public: } return true; } +#ifndef CYBOZU_DONT_USE_EXCEPTION + void set(const mpz_class& _p) + { + bool b; + set(&b, _p); + if (!b) throw cybozu::Exception("gmp:SquareRoot:set"); + } + bool get(mpz_class& x, const mpz_class& a) const + { + bool b; + bool ret = get(&b, x, a); + if (!b) throw cybozu::Exception("gmp:SquareRoot:get:not prime"); + return ret; + } +#endif }; } // mcl |