diff options
author | MITSUNARI Shigeo <herumi@nifty.com> | 2017-07-18 14:47:15 +0800 |
---|---|---|
committer | MITSUNARI Shigeo <herumi@nifty.com> | 2017-07-18 14:47:15 +0800 |
commit | cd7c49476c5580681895d1fcedebf4053b26fe54 (patch) | |
tree | 5b5c4f7fa4f8cb3f7ad7b118c04b6f0f4d4fd49a | |
parent | 7a61c16f191f067e6f2d2a6cb04a5dd001b214dd (diff) | |
download | dexon-mcl-cd7c49476c5580681895d1fcedebf4053b26fe54.tar.gz dexon-mcl-cd7c49476c5580681895d1fcedebf4053b26fe54.tar.zst dexon-mcl-cd7c49476c5580681895d1fcedebf4053b26fe54.zip |
add Vint::powMod
-rw-r--r-- | include/mcl/util.hpp | 4 | ||||
-rw-r--r-- | include/mcl/vint.hpp | 133 | ||||
-rw-r--r-- | test/vint_test.cpp | 35 |
3 files changed, 92 insertions, 80 deletions
diff --git a/include/mcl/util.hpp b/include/mcl/util.hpp index eb16bb4..f1c2cfd 100644 --- a/include/mcl/util.hpp +++ b/include/mcl/util.hpp @@ -194,8 +194,8 @@ void getRandVal(T *out, RG& rg, const T *in, size_t bitSize) @param limitBit [in] const time version if the value is positive @note &out != x and out = the unit element of G */ -template<class G, class T> -void powGeneric(G& out, const G& x, const T *y, size_t n, void mul(G&, const G&, const G&) , void sqr(G&, const G&), void normalize(G&, const G&), size_t limitBit = 0) +template<class G, class Mul, class Sqr, class T> +void powGeneric(G& out, const G& x, const T *y, size_t n, const Mul& mul, const Sqr& sqr, void normalize(G&, const G&), size_t limitBit = 0) { assert(&out != &x); G tbl[4]; // tbl = { discard, x, x^2, x^3 } diff --git a/include/mcl/vint.hpp b/include/mcl/vint.hpp index 3e1ed4b..9cfa6d4 100644 --- a/include/mcl/vint.hpp +++ b/include/mcl/vint.hpp @@ -11,6 +11,7 @@ #include <assert.h> #include <cmath> #include <iostream> +#include <mcl/util.hpp> #ifndef MCL_VINT_UNIT_BYTE_SIZE #define MCL_VINT_UNIT_BYTE_SIZE 4 @@ -862,6 +863,22 @@ private: } r.trim(xn); } + struct MulMod { + const VintT *pm; + void operator()(VintT& z, const VintT& x, const VintT& y) const + { + VintT::mul(z, x, y); + z %= *pm; + } + }; + struct SqrMod { + const VintT *pm; + void operator()(VintT& y, const VintT& x) const + { + VintT::sqr(y, x); + y %= *pm; + } + }; public: VintT(int x = 0) : size_(0) @@ -882,6 +899,15 @@ public: size_ = 1; return *this; } + void swap(VintT& rhs) +#if CYBOZU_CPP_VERSION >= CYBOZU_CPP_VERSION_CPP11 + noexcept +#endif + { + std::swap(buf_, rhs.buf_); + std::swap(size_, rhs.size_); + std::swap(isNeg_, rhs.isNeg_); + } /* set positive value @note assume little endian system @@ -1064,6 +1090,10 @@ public: z.isNeg_ = x.isNeg_ ^ y.isNeg_; z.trim(zn); } + static void sqr(VintT& y, const VintT& x) + { + mul(y, x, x); + } static void add1(VintT& z, const VintT& x, int y) { if (y == invalidVar) throw cybozu::Exception("VintT:add1:bad y"); @@ -1208,6 +1238,36 @@ public: if (&y != &x) { y = x; } y.isNeg_ = false; } + static void pow(VintT& z, const VintT& x, const VintT& y) + { + if (y.isNeg_) throw cybozu::Exception("Vint::pow:negative y") << y; + const VintT xx = x; + z = 1; + mcl::fp::powGeneric(z, x, &y.buf_[0], y.size(), mul, sqr, (void (*)(VintT&, const VintT&))0); + } + static void pow(VintT& z, const VintT& x, int y) + { + if (y < 0) throw cybozu::Exception("Vint::pow:negative y") << y; + const VintT xx = x; + Unit absY = std::abs(y); + z = 1; + mcl::fp::powGeneric(z, x, &absY, 1, mul, sqr, (void (*)(VintT&, const VintT&))0); + } + /* + z = x ^ y mod m + */ + static void powMod(VintT& z, const VintT& x, const VintT& y, const VintT& m) + { + if (y.isNeg_) throw cybozu::Exception("Vint::pow:negative y") << y; + VintT zz = 1; + MulMod mulMod; + SqrMod sqrMod; + mulMod.pm = &m; + sqrMod.pm = &m; + zz = 1; + mcl::fp::powGeneric(zz, x, &y.buf_[0], y.size(), mulMod, sqrMod, (void (*)(VintT&, const VintT&))0); + z.swap(zz); + } VintT& operator++() { add(*this, *this, 1); return *this; } VintT& operator--() { sub(*this, *this, 1); return *this; } VintT operator++(int) { VintT c = *this; add(*this, *this, 1); return c; } @@ -1241,79 +1301,6 @@ public: VintT operator>>(size_t n) const { VintT c = *this; c >>= n; return c; } }; -namespace util { -/* - dispatch Uint, int, size_t, and so on -*/ -template<class T> -struct IntTag { - typedef typename T::Unit Unit; - static Unit getBlock(const T& x, size_t i) - { - return x.getUnit()[i]; - } - static size_t getBlockSize(const T& x) - { - return x.size(); - } -}; - -template<> -struct IntTag<int> { - typedef int Unit; - static Unit getBlock(const int& x, size_t) - { - return x; - } - static size_t getBlockSize(const int&) - { - return 1; - } -}; -template<> -struct IntTag<size_t> { - typedef size_t Unit; - static Unit getBlock(const size_t& x, size_t) - { - return x; - } - static size_t getBlockSize(const size_t&) - { - return 1; - } -}; - -} // util - -/** - return pow(x, y) -*/ -template<class T, class S> -T power(const T& x, const S& y) -{ - typedef typename mcl::util::IntTag<S> Tag; - typedef typename Tag::Unit Unit; - T t(x); - T out = 1; - for (size_t i = 0, n = Tag::getBlockSize(y); i < n; i++) { - Unit v = Tag::getBlock(y, i); - int m = (int)sizeof(Unit) * 8; - if (i == n - 1) { - // avoid unused multiplication - while (m > 0 && (v & (Unit(1) << (m - 1))) == 0) { - m--; - } - } - for (int j = 0; j < m; j++) { - if (v & (Unit(1) << j)) { - out *= t; - } - t *= t; - } - } - return out; -} - //typedef VintT<local::VariableBuffer<mcl::local::Unit> > Vint; //typedef VintT<local::FixedBuffer<mcl::local::Unit, 10> > Vint; typedef VintT<local::Buffer<mcl::local::Unit> > Vint; diff --git a/test/vint_test.cpp b/test/vint_test.cpp index 5099bcc..f8b8579 100644 --- a/test/vint_test.cpp +++ b/test/vint_test.cpp @@ -623,6 +623,7 @@ CYBOZU_TEST_AUTO(shift) Vint y, z; const size_t unitBitSize = Vint::unitBitSize; + Vint s; // shl for (size_t i = 1; i < 31; i++) { Vint::shl(y, x, i); @@ -636,7 +637,7 @@ CYBOZU_TEST_AUTO(shift) } for (int i = 0; i < 4; i++) { Vint::shl(y, x, i * unitBitSize); - Vint s = power(Vint(2), Vint(i * unitBitSize)); + Vint::pow(s, Vint(2), Vint(i * unitBitSize)); z = x * s; CYBOZU_TEST_EQUAL(y, z); y = x << (i * unitBitSize); @@ -647,7 +648,7 @@ CYBOZU_TEST_AUTO(shift) } for (int i = 0; i < 100; i++) { y = x << i; - Vint s = power(Vint(2), Vint(i)); + Vint::pow(s, Vint(2), i); z = x * s; CYBOZU_TEST_EQUAL(y, z); y = x; @@ -668,7 +669,7 @@ CYBOZU_TEST_AUTO(shift) } for (int i = 0; i < 3; i++) { Vint::shr(y, x, i * unitBitSize); - Vint s = power(Vint(2), Vint(i * unitBitSize)); + Vint::pow(s, Vint(2), i * unitBitSize); z = x / s; CYBOZU_TEST_EQUAL(y, z); y = x >> (i * unitBitSize); @@ -679,7 +680,7 @@ CYBOZU_TEST_AUTO(shift) } for (int i = 0; i < 100; i++) { y = x >> i; - Vint s = power(Vint(2), Vint(i)); + Vint::pow(s, Vint(2), i); z = x / s; CYBOZU_TEST_EQUAL(y, z); y = x; @@ -775,7 +776,7 @@ CYBOZU_TEST_AUTO(sample) x = 2; y = 250; - x = power(x, y); + Vint::pow(x, x, y); Vint r, q; r = x % y; q = x / y; @@ -995,3 +996,27 @@ CYBOZU_TEST_AUTO(T) x /= -1; CYBOZU_TEST_EQUAL(x, 3); } + +CYBOZU_TEST_AUTO(pow) +{ + Vint x = 2; + Vint y; + Vint::pow(y, x, 3); + CYBOZU_TEST_EQUAL(y, 8); + x = -2; + Vint::pow(y, x, 3); + CYBOZU_TEST_EQUAL(y, -8); + CYBOZU_TEST_EXCEPTION(Vint::pow(y, x, -2), std::exception); +} + +CYBOZU_TEST_AUTO(powMod) +{ + Vint x = 7; + Vint m = 65537; + Vint y; + Vint::powMod(y, x, 20, m); + CYBOZU_TEST_EQUAL(y, 55277); + Vint::powMod(y, x, m - 1, m); + CYBOZU_TEST_EQUAL(y, 1); + +} |