aboutsummaryrefslogtreecommitdiffstats
path: root/include/mcl/gmp_util.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/mcl/gmp_util.hpp')
-rw-r--r--include/mcl/gmp_util.hpp162
1 files changed, 120 insertions, 42 deletions
diff --git a/include/mcl/gmp_util.hpp b/include/mcl/gmp_util.hpp
index 399041c..bb461af 100644
--- a/include/mcl/gmp_util.hpp
+++ b/include/mcl/gmp_util.hpp
@@ -64,12 +64,13 @@ typedef mpz_class ImplType;
// z = [buf[n-1]:..:buf[1]:buf[0]]
// eg. buf[] = {0x12345678, 0xaabbccdd}; => z = 0xaabbccdd12345678;
template<class T>
-void setArray(mpz_class& z, const T *buf, size_t n)
+void setArray(bool *pb, mpz_class& z, const T *buf, size_t n)
{
#ifdef MCL_USE_VINT
- z.setArray(buf, n);
+ z.setArray(pb, buf, n);
#else
mpz_import(z.get_mpz_t(), n, -1, sizeof(*buf), 0, 0, buf);
+ *pb = true;
#endif
}
/*
@@ -78,44 +79,43 @@ void setArray(mpz_class& z, const T *buf, size_t n)
*/
#ifndef MCL_USE_VINT
template<class T>
-void getArray(T *buf, size_t maxSize, const mpz_srcptr x)
+bool getArray_(T *buf, size_t maxSize, const mpz_srcptr x)
{
const size_t bufByteSize = sizeof(T) * maxSize;
const int xn = x->_mp_size;
- if (xn < 0) throw cybozu::Exception("gmp:getArray:x is negative");
+ if (xn < 0) return false;
size_t xByteSize = sizeof(*x->_mp_d) * xn;
- if (xByteSize > bufByteSize) throw cybozu::Exception("gmp:getArray:too small") << xn << maxSize;
+ if (xByteSize > bufByteSize) return false;
memcpy(buf, x->_mp_d, xByteSize);
memset((char*)buf + xByteSize, 0, bufByteSize - xByteSize);
+ return true;
}
#endif
template<class T>
-void getArray(T *buf, size_t maxSize, const mpz_class& x)
+void getArray(bool *pb, T *buf, size_t maxSize, const mpz_class& x)
{
#ifdef MCL_USE_VINT
- x.getArray(buf, maxSize);
+ x.getArray(pb, buf, maxSize);
#else
- getArray(buf, maxSize, x.get_mpz_t());
+ *pb = getArray_(buf, maxSize, x.get_mpz_t());
#endif
}
inline void set(mpz_class& z, uint64_t x)
{
- setArray(z, &x, 1);
+ bool b;
+ setArray(&b, z, &x, 1);
+ assert(b);
+ (void)b;
}
-inline void setStr(bool *pb, mpz_class& z, const char *str, size_t strSize, int base = 0)
+inline void setStr(bool *pb, mpz_class& z, const char *str, int base = 0)
{
#ifdef MCL_USE_VINT
- z.setStr(pb, str, strSize, base);
+ z.setStr(pb, str, base);
#else
- *pb = z.set_str(std::string(str, strSize), base) == 0;
+ *pb = z.set_str(str, base) == 0;
#endif
}
-inline void setStr(mpz_class& z, const std::string& str, int base = 0)
-{
- bool b;
- setStr(&b, z, str.c_str(), str.size(), base);
- if (!b) throw cybozu::Exception("gmp:setStr");
-}
+
/*
set buf with string terminated by '\0'
return strlen(buf) if success else 0
@@ -125,18 +125,19 @@ inline size_t getStr(char *buf, size_t bufSize, const mpz_class& z, int base = 1
#ifdef MCL_USE_VINT
return z.getStr(buf, bufSize, base);
#else
- std::string str = z.get_str(base);
- if (str.size() < bufSize) {
- memcpy(buf, str.c_str(), str.size() + 1);
- return str.size();
- }
- return 0;
+ __gmp_alloc_cstring tmp(mpz_get_str(0, base, z.get_mpz_t()));
+ size_t n = strlen(tmp.str);
+ if (n + 1 > bufSize) return 0;
+ memcpy(buf, tmp.str, n + 1);
+ return n;
#endif
}
+
+#ifndef CYBOZU_DONT_USE_STRING
inline void getStr(std::string& str, const mpz_class& z, int base = 10)
{
#ifdef MCL_USE_VINT
- str = z.getStr(base);
+ z.getStr(str, base);
#else
str = z.get_str(base);
#endif
@@ -144,9 +145,11 @@ inline void getStr(std::string& str, const mpz_class& z, int base = 10)
inline std::string getStr(const mpz_class& z, int base = 10)
{
std::string s;
- getStr(s, z, base);
+ gmp::getStr(s, z, base);
return s;
}
+#endif
+
inline void add(mpz_class& z, const mpz_class& x, const mpz_class& y)
{
#ifdef MCL_USE_VINT
@@ -365,11 +368,12 @@ inline int legendre(const mpz_class& a, const mpz_class& p)
return mpz_legendre(a.get_mpz_t(), p.get_mpz_t());
#endif
}
-inline bool isPrime(const mpz_class& x)
+inline bool isPrime(bool *pb, const mpz_class& x)
{
#ifdef MCL_USE_VINT
- return x.isPrime(32);
+ return x.isPrime(pb, 32);
#else
+ *pb = true;
return mpz_probab_prime_p(x.get_mpz_t(), 32) != 0;
#endif
}
@@ -438,7 +442,7 @@ inline mpz_class abs(const mpz_class& x)
#endif
}
-inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen())
+inline void getRand(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen())
{
if (rg.isZero()) rg = fp::RandGen::get();
assert(bitSize > 1);
@@ -447,7 +451,7 @@ inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()
uint32_t buf[128];
assert(n <= CYBOZU_NUM_OF_ARRAY(buf));
if (n > CYBOZU_NUM_OF_ARRAY(buf)) {
- z = 0;
+ *pb = false;
return;
}
rg.read(buf, n * sizeof(buf[0]));
@@ -459,22 +463,26 @@ inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()
v |= 1U << (rem - 1);
}
buf[n - 1] = v;
- setArray(z, buf, n);
+ setArray(pb, z, buf, n);
}
-inline void getRandPrime(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false)
+inline void getRandPrime(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false)
{
if (rg.isZero()) rg = fp::RandGen::get();
assert(bitSize > 2);
- do {
- getRand(z, bitSize, rg);
+ for (;;) {
+ getRand(pb, z, bitSize, rg);
+ if (!*pb) return;
if (setSecondBit) {
z |= mpz_class(1) << (bitSize - 2);
}
if (mustBe3mod4) {
z |= 3;
}
- } while (!(isPrime(z)));
+ bool ret = isPrime(pb, z);
+ if (!*pb) return;
+ if (ret) return;
+ }
}
inline mpz_class getQuadraticNonResidue(const mpz_class& p)
{
@@ -566,6 +574,49 @@ bool getNAF(Vec& v, const mpz_class& x)
}
}
+#ifndef CYBOZU_DONT_USE_EXCEPTION
+inline void setStr(mpz_class& z, const std::string& str, int base = 0)
+{
+ bool b;
+ setStr(&b, z, str.c_str(), base);
+ if (!b) throw cybozu::Exception("gmp:setStr");
+}
+template<class T>
+void setArray(mpz_class& z, const T *buf, size_t n)
+{
+ bool b;
+ setArray(&b, z, buf, n);
+ if (!b) throw cybozu::Exception("gmp:setArray");
+}
+template<class T>
+void getArray(T *buf, size_t maxSize, const mpz_class& x)
+{
+ bool b;
+ getArray(&b, buf, maxSize, x);
+ if (!b) throw cybozu::Exception("gmp:getArray");
+}
+inline bool isPrime(const mpz_class& x)
+{
+ bool b;
+ bool ret = isPrime(&b, x);
+ if (!b) throw cybozu::Exception("gmp:isPrime");
+ return ret;
+}
+inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen())
+{
+ bool b;
+ getRand(&b, z, bitSize, rg);
+ if (!b) throw cybozu::Exception("gmp:getRand");
+}
+inline void getRandPrime(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false)
+{
+ bool b;
+ getRandPrime(&b, z, bitSize, rg, setSecondBit, mustBe3mod4);
+ if (!b) throw cybozu::Exception("gmp:getRandPrime");
+}
+#endif
+
+
} // mcl::gmp
/*
@@ -591,12 +642,19 @@ public:
s = 0;
q_add_1_div_2 = 0;
}
- void set(const mpz_class& _p)
+ void set(bool *pb, const mpz_class& _p)
{
p = _p;
- if (p <= 2) throw cybozu::Exception("SquareRoot:bad p") << p;
- isPrime = gmp::isPrime(p);
- if (!isPrime) return; // don't throw until get() is called
+ if (p <= 2) {
+ *pb = false;
+ return;
+ }
+ isPrime = gmp::isPrime(pb, p);
+ if (!*pb) return;
+ if (!isPrime) {
+ *pb = false;
+ return;
+ }
g = gmp::getQuadraticNonResidue(p);
// p - 1 = 2^r q, q is odd
r = 0;
@@ -607,13 +665,18 @@ public:
}
gmp::powMod(s, g, q, p);
q_add_1_div_2 = (q + 1) / 2;
+ *pb = true;
}
/*
solve x^2 = a mod p
*/
- bool get(mpz_class& x, const mpz_class& a) const
+ bool get(bool *pb, mpz_class& x, const mpz_class& a) const
{
- if (!isPrime) throw cybozu::Exception("SquareRoot:get:not prime") << p;
+ if (!isPrime) {
+ *pb = false;
+ return false;
+ }
+ *pb = true;
if (a == 0) {
x = 0;
return true;
@@ -653,7 +716,7 @@ public:
template<class Fp>
bool get(Fp& x, const Fp& a) const
{
- if (Fp::getOp().mp != p) throw cybozu::Exception("bad Fp") << Fp::getOp().mp << p;
+ assert(Fp::getOp().mp == p);
if (a == 0) {
x = 0;
return true;
@@ -691,6 +754,21 @@ public:
}
return true;
}
+#ifndef CYBOZU_DONT_USE_EXCEPTION
+ void set(const mpz_class& _p)
+ {
+ bool b;
+ set(&b, _p);
+ if (!b) throw cybozu::Exception("gmp:SquareRoot:set");
+ }
+ bool get(mpz_class& x, const mpz_class& a) const
+ {
+ bool b;
+ bool ret = get(&b, x, a);
+ if (!b) throw cybozu::Exception("gmp:SquareRoot:get:not prime");
+ return ret;
+ }
+#endif
};
} // mcl