[模板] 快速数论变换 (NTT)

简介

在之前我们已经介绍了快速傅里叶变换,利用单位复根,我们可以在 O(n log n)的时间内进行 DFTIDFT 变换。

但是由于不可避免的精度问题,单位复根存在一定的局限性,因此对于正整数的卷积运算,我们通常使用快速数论变换来避免精度问题。

原根

原根的定义

对于一个正整数 p,若存在一个数 g 满足 (p,g)=1,并且 δp(g) = φ(p),其中 δp(g) 为使 gd1 (mod p) 的最小正整数 d,称为为 g 是模 p 的原根。

由欧拉定理可知,δp(g) 一定小于等于 φ(p)

原根的计算

我们可以通过枚举一个数 g,并检验 g 是不是 p 的原根

NTT 中,我们的模数 p 需要是一个质数,因此 φ(p)=p1,根据定义,δp(g)=φ(p)=p1,因此对于原根 g

i[1,p2]gi1

这样的复杂度是 O(p) 的,我们考虑如何优化

对于一个数 gδp(g) 一定是 p1 的约数

假设存在最小的 d 不是 p1 的约数,那么可以找到一个 x 满足 xd>p1>(x1)d

gdxgdxgp11(mod p)

那么就存在一个更小的 dx(p1)d ,与假设相反。

那么我们只需要枚举 p1 中除 p1 的所有约数 q

gq1(mod p)

实际上,我们还可以优化,将 p1 分解质因数

p1=i=1rpiki

我们只需要判断

i[1,r]gp1pi1(mod p)

因为对于一个更小的约数 q,如果它已经同余 1 了,因为一定存在至少一个 pi 满足 qp1pi,那么它一定能通过一个幂运算使得 gp1pi 也同余 1

inline int root(int x){
    for(register int i=2; i<=x; i++){
        int tmp = x - 1;

        bool flag = true;

        for(register int k=2; k * k <= (x - 1); k++){
            if(tmp % k == 0){
                if(powx(i, (x - 1) / k, x) == 1){
                    flag = false;
                    break;
                }

                while(tmp % k == 0)
                    tmp /= k;
            }
        }

        if(flag && (tmp == 1 || powx(i, (x - 1) / tmp, x) != 1)){
            return i;
        }
    }
}
C++

原根的性质

我们还需要原根拥有和单位复根一样的性质来进行 DFTIDFT 变换。

p=qn+1 其中 n2 的幂

性质一

ωn=gq,那么 1,gq,g2q,,g(n1)q互不相同,满足单位复根的性质一

性质二

ωn=pq,那么 ω2n=pq2(p=q2×2n+1),所以 ω2n2k=ωnk,满足单位复根的性质二

性质三

因为

ωnngp11(mod p)

所以

ωnn2±1(mod p)

又因为

ωnn2ωn0(mod p)

所以

ωnn21(mod p)


ωnk+n2ωnk(mod p)

满足单位复根的性质三

性质四

S(ωnk)=1+ωnk+(ωnk)2++(ωnk)n1=1(ωnk)n1ωnk=(ωnk)n1ωnk1

由性质三

(ωnk)n10(mod p)

那么 S(ωnk)=0,满足单位复根的性质四

快速数论变换

接下来的操作和快速傅里叶变换一样,只是每一步的操作都需要取模,在最后输出结果时,由除以 n 变为乘以 n1

NTT 常用的模数为 9982443531004535809 ,它们的原根为 3

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MOD = 998244353;
const int MAXN = 4194304 + 10;

inline long long powx(long long a, long long b, long long mod){
    long long ans = 1;

    for(; b; b >>= 1){
        if(b & 1) ans *= a, ans %= mod;
        a *= a, a %= mod;
    }

    return ans;
}

inline long long exgcd(long long a, long long b, long long &x, long long &y){
    if(!b){
        x = 1, y = 0;
        return a;
    }

    long long d = exgcd(b, a%b, y, x);
    y -= x * (a / b);

    return d;
}

inline long long inv(long long a, long long mod){
    long long x, y;
    exgcd(a, mod, x, y);
    return (x + mod) % mod;
}

inline int root(int x){
    for(register int i=2; i<=x; i++){
        int tmp = x - 1;

        bool flag = true;

        for(register int k=2; k * k <= (x - 1); k++){
            if(tmp % k == 0){
                if(powx(i, (x - 1) / k, x) == 1){
                    flag = false;
                    break;
                }

                while(tmp % k == 0)
                    tmp /= k;
            }
        }

        if(flag && (tmp == 1 || powx(i, (x - 1) / tmp, x) != 1)){
            return i;
        }
    }
    throw;
}

namespace NTT{
    int place[MAXN];
    long long omega[MAXN];
    long long omegaInverse[MAXN];

    inline void init(int n){
        int k = 0;

        while((1 << k) < n)
            k++;

        for(register int i=0; i<n; i++){
            for(register int j=0; j<k; j++){
                if(i & (1 << j)){
                    place[i] |= 1 << (k - j - 1);
                }
            }
        }

        long long g = root(MOD);
        long long tmp = powx(g, (MOD - 1) / n, MOD);

        for(register int i=0; i<n; i++){
            omega[i] = (i == 0) ? 1 : omega[i - 1] * tmp % MOD;
            omegaInverse[i] = inv(omega[i], MOD);
        }
    }

    inline void transform(long long *a, int n, long long *omega){
        for(register int i=0; i<n; i++){
            if(i < place[i]){
                swap(a[i], a[place[i]]);
            }
        }

        for(register int range=2; range <= n; range <<= 1){
            int mid = range >> 1;

            for(register long long *p = a; p != a + n; p += range){
                int k = n / range;

                for(register int i=0; i<mid; i++){
                    int tmp = i + mid;

                    long long t = omega[k * i] * p[i + mid] % MOD;

                    p[tmp] = (p[i] - t + MOD) % MOD;
                    p[i] = (p[i] + t) % MOD;
                }
            }
        }
    }

    inline void dft(long long *a, int n){
        transform(a, n, omega);
    }

    inline void idft(long long *a, int n){
        transform(a, n, omegaInverse);

        long long tmp = inv(n, MOD);

        for(register int i=0; i<n; i++){
            a[i] = a[i] * tmp % MOD;
        }
    }
}

inline int read(){
    int x = 0;
    int p = 1;
    char ch = getchar();

    while(ch < '0' || ch > '9'){
        if(ch == '-')
            p = 0;
        ch = getchar();
    }

    while('0' <= ch && ch <= '9'){
        x = x*10 + ch - '0';
        ch = getchar();
    }

    return p ? x : (-x);
}

int n, m, tot;

long long a[MAXN], b[MAXN];

int main(){
    n = read();
    m = read();

    n++;
    m++;

    for(register int i=0; i<n; i++)
        a[i] = read();

    for(register int i=0; i<m; i++){
        b[i] = read();
    }

    tot = 1;
    while(tot < n + m){
        tot <<= 1;
    }

    NTT::init(tot);

    NTT::dft(a, tot);
    NTT::dft(b, tot);

    for(register int i=0; i<tot; i++){
        a[i] = a[i] * b[i] % MOD;
    }

    NTT::idft(a, tot);

    for(register int i=0; i<n+m-1; i++){
        printf("%lld ", a[i]);
    }
    return 0;
}
C++

参考资料

评论 在线讨论
      加载更多