aboutsummaryrefslogtreecommitdiffstats
path: root/include/mcl/operator.hpp
blob: 3acd6c6727ac48185767d2c3cdbf82a9f76ee8f2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#pragma once
/**
    @file
    @brief operator class
    @author MITSUNARI Shigeo(@herumi)
    @license modified new BSD license
    http://opensource.org/licenses/BSD-3-Clause
*/
#include <mcl/op.hpp>
#include <mcl/util.hpp>
// MCL_FORCE_INLINE: compiler-specific "force inlining" marker used on the operator helpers below.
// Each branch is guarded by #ifndef, so a definition made before this header is included is kept.
#ifdef _MSC_VER
    #ifndef MCL_FORCE_INLINE
        #define MCL_FORCE_INLINE __forceinline
    #endif
    // Silence MSVC warning C4714 (function marked as __forceinline not inlined).
    // NOTE(review): no matching "#pragma warning(pop)" is visible in this view of the file;
    // confirm whether the suppression is intended to stay active for including code.
    #pragma warning(push)
    #pragma warning(disable : 4714)
#else
    #ifndef MCL_FORCE_INLINE
        // GCC/Clang attribute spelling; used on functions defined inside the class body below.
        #define MCL_FORCE_INLINE __attribute__((always_inline))
    #endif
#endif

namespace mcl { namespace fp {

// Stateless placeholder type; Operator uses Empty<T> as its default extra base class E.
template<typename T>
struct Empty {
};

/*
    Arithmetic operator mix-in (CRTP: every member casts *this to T,
    so T is expected to derive from Operator<T, E>).

    T must have add, sub, mul, inv, neg
    (the pow family additionally uses T::sqr and T::BaseFp::getBitSize()).
    E is an optional extra base class; it defaults to the empty Empty<T>.
*/
template<class T, class E = Empty<T> >
struct Operator : E {
    /*
        Compound and binary operators.
        Each one forwards to the matching static T::add / T::sub / T::mul,
        called as op(result, lhs, rhs). S is whatever right-hand type T's
        static function accepts.
    */
    template<class S> MCL_FORCE_INLINE T& operator+=(const S& rhs) { T::add(static_cast<T&>(*this), static_cast<const T&>(*this), rhs); return static_cast<T&>(*this); }
    template<class S> MCL_FORCE_INLINE T& operator-=(const S& rhs) { T::sub(static_cast<T&>(*this), static_cast<const T&>(*this), rhs); return static_cast<T&>(*this); }
    template<class S> friend MCL_FORCE_INLINE T operator+(const T& a, const S& b) { T c; T::add(c, a, b); return c; }
    template<class S> friend MCL_FORCE_INLINE T operator-(const T& a, const S& b) { T c; T::sub(c, a, b); return c; }
    template<class S> MCL_FORCE_INLINE T& operator*=(const S& rhs) { T::mul(static_cast<T&>(*this), static_cast<const T&>(*this), rhs); return static_cast<T&>(*this); }
    template<class S> friend MCL_FORCE_INLINE T operator*(const T& a, const S& b) { T c; T::mul(c, a, b); return c; }
    /*
        Division is implemented as multiplication by T::inv(divisor).
        The inverse is computed into a temporary first, so rhs may alias *this.
    */
    MCL_FORCE_INLINE T& operator/=(const T& rhs) { T c; T::inv(c, rhs); T::mul(static_cast<T&>(*this), static_cast<const T&>(*this), c); return static_cast<T&>(*this); }
    // c = a * inv(b)
    static MCL_FORCE_INLINE void div(T& c, const T& a, const T& b) { T t; T::inv(t, b); T::mul(c, a, t); }
    // returns inv(b) * a
    friend MCL_FORCE_INLINE T operator/(const T& a, const T& b) { T c; T::inv(c, b); c *= a; return c; }
    // unary minus via T::neg
    MCL_FORCE_INLINE T operator-() const { T c; T::neg(c, static_cast<const T&>(*this)); return c; }

    /*
        z = x^y where the exponent y is an FpT<tag2, maxBitSize2> value.
        The exponent is read through y.getBlock() and passed on as a
        non-negative Unit array (b.p, b.n). May be routed to the GLV hook
        (see powArray).
    */
    template<class tag2, size_t maxBitSize2, template<class _tag, size_t _maxBitSize> class FpT>
    static void pow(T& z, const T& x, const FpT<tag2, maxBitSize2>& y)
    {
        fp::Block b;
        y.getBlock(b);
        powArray(z, x, b.p, b.n, false, false);
    }
    /*
        Same as pow() for an FpT exponent, but always uses powArrayBase,
        i.e. never the GLV hook.
    */
    template<class tag2, size_t maxBitSize2, template<class _tag, size_t _maxBitSize> class FpT>
    static void powGeneric(T& z, const T& x, const FpT<tag2, maxBitSize2>& y)
    {
        fp::Block b;
        y.getBlock(b);
        powArrayBase(z, x, b.p, b.n, false, false);
    }
    /*
        Same as pow() for an FpT exponent, with the constTime flag set.
    */
    template<class tag2, size_t maxBitSize2, template<class _tag, size_t _maxBitSize> class FpT>
    static void powCT(T& z, const T& x, const FpT<tag2, maxBitSize2>& y)
    {
        fp::Block b;
        y.getBlock(b);
        powArray(z, x, b.p, b.n, false, true);
    }
    /*
        z = x^y for a signed 64-bit exponent.
        The magnitude |y| is passed as a Unit array and the sign as
        isNegative, so a negative y yields the inverse of x^|y|.
    */
    static void pow(T& z, const T& x, int64_t y)
    {
        // Compute |y| in unsigned arithmetic: std::abs(y) overflows (undefined
        // behaviour) for y == INT64_MIN, whereas 0 - uint64_t(y) is well defined
        // and gives 2^63 in that case.
        const uint64_t bits = uint64_t(y);
        const uint64_t u = y < 0 ? uint64_t(0) - bits : bits;
#if MCL_SIZEOF_UNIT == 8
        powArray(z, x, &u, 1, y < 0, false);
#else
        // 32-bit units: split into low/high words and drop a zero high word.
        uint32_t ua[2] = { uint32_t(u), uint32_t(u >> 32) };
        size_t un = ua[1] ? 2 : 1;
        powArray(z, x, ua, un, y < 0, false);
#endif
    }
    /*
        z = x^y for an mpz_class exponent; a negative y is forwarded as
        isNegative. May be routed to the GLV hook (see powArray).
    */
    static void pow(T& z, const T& x, const mpz_class& y)
    {
        powArray(z, x, gmp::getUnit(y), gmp::getUnitSize(y), y < 0, false);
    }
    /*
        mpz_class exponent, always via powArrayBase (never the GLV hook).
    */
    static void powGeneric(T& z, const T& x, const mpz_class& y)
    {
        powArrayBase(z, x, gmp::getUnit(y), gmp::getUnitSize(y), y < 0, false);
    }
    /*
        mpz_class exponent with the constTime flag set.
    */
    static void powCT(T& z, const T& x, const mpz_class& y)
    {
        powArray(z, x, gmp::getUnit(y), gmp::getUnitSize(y), y < 0, true);
    }
    /*
        Install the function that powArray() may dispatch to.
        f is stored as-is in the static powArrayGLV pointer.
    */
    static void setPowArrayGLV(void f(T& z, const T& x, const Unit *y, size_t yn, bool isNegative, bool constTime))
    {
        powArrayGLV = f;
    }
private:
    // Optional replacement for powArrayBase, set through setPowArrayGLV().
    static void (*powArrayGLV)(T& z, const T& x, const Unit *y, size_t yn, bool isNegative, bool constTime);
    /*
        Dispatcher: uses powArrayGLV when one is installed and either
        constTime is requested or the exponent spans more than one Unit;
        otherwise falls back to powArrayBase.
    */
    static void powArray(T& z, const T& x, const Unit *y, size_t yn, bool isNegative, bool constTime)
    {
        if (powArrayGLV && (constTime || yn > 1)) {
            powArrayGLV(z, x, y, yn, isNegative, constTime);
            return;
        }
        powArrayBase(z, x, y, yn, isNegative, constTime);
    }
    /*
        z = x^(y[0..yn)) through fp::powGeneric using T::mul and T::sqr,
        followed by T::inv when isNegative is set.
    */
    static void powArrayBase(T& z, const T& x, const Unit *y, size_t yn, bool isNegative, bool constTime)
    {
        // z is overwritten with 1 before x is read, so copy x when they alias.
        T tmp;
        const T *px = &x;
        if (&z == &x) {
            tmp = x;
            px = &tmp;
        }
        z = 1;
        // The sixth argument is a null function pointer; the last one is
        // T::BaseFp::getBitSize() when constTime is set and 0 otherwise.
        fp::powGeneric(z, *px, y, yn, T::mul, T::sqr, (void (*)(T&, const T&))0, constTime ? T::BaseFp::getBitSize() : 0);
        if (isNegative) {
            T::inv(z, z);
        }
    }
};

// Out-of-class definition of the static hook declared in Operator.
// No initializer is given, so it starts out null (static zero-initialization)
// until setPowArrayGLV() assigns it.
template<typename T, typename E>
void (*Operator<T, E>::powArrayGLV)(
    T& z,
    const T& x,
    const Unit *y,
    size_t yn,
    bool isNegative,
    bool constTime);

} } // mcl::fp