[模板] 多项式操作

引言

本文中介绍的 OI 常用多项式运算都是在模 $x^n$ 意义下进行的,并且系数对一个特定的数取模

我们约定本文中的未特殊说明的多项式次数为 $n-1$,即
$$
f(x) = \sum_{i=0}^{n-1} a_i x^i
$$

$$
a_{n-1} \not = 0
$$

多项式加法

两个多项式 $A(x)$ 和 $B(x)$ 的和即为对应的系数相加,复杂度为 $O(n)$

多项式乘法

两个多项式 $A(x)$ 和 $B(x)$ 的乘积可以在 $O(n \log n)$ 的时间内通过 $\mathrm{FFT}$ 或 $\mathrm{NTT}$ 求出

多项式求逆

对于多项式 $f(x)$,求出多项式 $g(x)$,使得

$$
f(x)g(x) \equiv 1 \pmod {x^n}
$$

我们通过倍增法求解,设已求得多项式 $g'(x)$,使得

$$
f(x)g'(x) \equiv 1 \pmod {x^{\lceil \frac{n}{2} \rceil}}
$$

那么

$$
\begin{aligned}
g(x) - g'(x) \equiv 0 & \pmod {x^{\lceil \frac{n}{2} \rceil}} \\
\left(g(x) - g'(x) \right)^2 \equiv 0 & \pmod {x^n} \\
g^2(x) + g'^2(x) - 2g(x)g'(x) \equiv 0 & \pmod {x^n}
\end{aligned}
$$

两边同时乘上 $f(x)$,并利用 $f(x)g(x) \equiv 1 \pmod {x^n}$ 化简

$$
g(x) + f(x)g'^2(x) - 2g'(x) \equiv 0 \pmod {x^n}
$$

移项,得
$$
\begin{aligned}
g(x) & \equiv 2g'(x) - f(x)g'^2(x) \pmod {x^n} \\
& \equiv g'(x)(2-f(x)g'(x)) \pmod {x^n}
\end{aligned}
$$

复杂度
$$
T(n) = T(\frac{n}{2}) + O(n \log n) = O(n \log n)
$$

// Computes res = a^{-1} mod x^n by Newton doubling: g = g' * (2 - a * g'),
// where g' is the inverse modulo x^{ceil(n/2)}.
// Preconditions: a[0] != 0 (mod MOD); res and tmp have room for p entries,
// where p is the smallest power of two >= 2n; NTT::init was called with a
// length >= p. tmp is scratch space.
// On return res[0..n) holds the inverse and res[n..p) is zero.
inline void getInverse(long long *a, long long *res, long long *tmp, int n){
    if(n <= 0) return;          // empty series: nothing to compute (and no endless recursion)
    if(n == 1){
        res[0] = inv(a[0]);
        return;
    }
    const int half = (n + 1) >> 1;
    getInverse(a, res, tmp, half);

    int p = 1, k = 0;
    while(p < (n << 1)) p <<= 1;
    while((1 << k) < p) k++;

    // Only res[0..half) is meaningful. The recursive call cleared at most up to
    // its own (possibly smaller) transform length, so clear the rest here; the
    // cyclic convolution below would otherwise pick up stale values.
    fill(res + half, res + p, 0);

    copy(a, a+n, tmp);
    fill(tmp+n, tmp+p, 0);

    NTT::getplace(p, k);
    NTT::dft(tmp, p);
    NTT::dft(res, p);

    // Point-wise g' * (2 - a * g'); the product has degree < 2n <= p, so no wrap-around.
    for(int i=0; i<p; i++){
        res[i] = res[i] * (2 - tmp[i] * res[i] % MOD) % MOD;
        if(res[i] < 0) res[i] += MOD;
    }

    NTT::idft(res, p);
    fill(res+n, res+p, 0);
}

多项式取模

给定一个 $n$ 次多项式 $F(x)$ 和一个 $m$ 次多项式 $G(x)$,求多项式 $Q(x)$,$R(x)$ 满足

  • $Q(x)$ 次数为 $n-m$,$R(x)$ 次数小于 $m$
  • $F(x) = Q(x) * G(x) + R(x)$

定义翻转操作 $R$:对于 $n$ 次多项式 $A(x)$,把它的系数顺序整个翻转,即

$$
A_R(x) = x^n A(\frac{1}{x})
$$

$$
\begin{aligned}
F(x) & = Q(x) * G(x) + R(x) \\
F(\frac{1}{x}) & = Q(\frac{1}{x}) * G(\frac{1}{x}) + R(\frac{1}{x}) \\
x^n F(\frac{1}{x}) & = x^{n-m}*Q(\frac{1}{x}) * x^m * G(\frac{1}{x}) + x^{n-m+1} * x^{m-1} * R(\frac{1}{x}) \\
F_R(x) & = Q_R(x) * G_R(x) + x^{n-m+1} R_R(x) \\
F_R(x) & \equiv Q_R(x) * G_R(x) \pmod {x^{n-m+1}} \\
Q_R(x) & \equiv F_R(x) * G_R^{-1}(x) \pmod {x^{n-m+1}}
\end{aligned}
$$

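这里 $R_R(x)$ 把 $R(x)$ 视为 $m-1$ 次多项式进行翻转。由于 $Q_R(x)$ 的次数为 $n-m$,它在模 $x^{n-m+1}$ 意义下被唯一确定,再翻转回来即得 $Q(x)$。最后 $R(x)$ 可以由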

$$R(x) = F(x) - Q(x) * G(x)$$

得到

// Polynomial division with remainder: finds Q, R with F = Q * G + R, deg R < deg G.
// n and m are coefficient COUNTS: a[0..n) holds F (degree n-1) and b[0..m)
// holds G (degree m-1; b[m-1] must be non-zero mod MOD).
// Outputs: resA[0..n-m+1) = Q, resB[0..m-1) = R. When n < m the quotient is
// empty, resA is left untouched and R = F.
// Every buffer except b needs room for p entries, where p is the smallest power
// of two >= 2n, and NTT::init must have been called with a length >= p.
// tmp .. tmp4 are scratch space; their contents are destroyed.
inline void getDiv(long long *a, long long *b, long long *resA, long long *resB, long long *tmp, long long *tmp2, long long *tmp3, long long *tmp4, int n, int m){
    // deg F < deg G: Q = 0, R = F. (Otherwise getInverse would be asked for
    // n - m + 1 <= 0 terms.)
    if(n < m){
        copy(a, a+n, resB);
        fill(resB+n, resB+(m-1), 0);
        return;
    }

    const int q = n - m + 1;    // number of quotient coefficients

    int p = 1, k = 0;
    while(p < (n << 1)) p <<= 1;
    while((1 << k) < p) k++;

    // tmp = F_R, tmp2 = G_R (coefficient-reversed F and G), zero padded to p.
    fill(tmp, tmp+p, 0);
    fill(tmp2, tmp2+p, 0);
    copy(a, a+n, tmp);
    copy(b, b+m, tmp2);
    reverse(tmp, tmp+n);
    reverse(tmp2, tmp2+m);

    // tmp3 = G_R^{-1} mod x^q. Clear it first: the transform below runs over all
    // p entries, so anything left from a previous use would corrupt Q_R.
    fill(tmp3, tmp3+p, 0);
    getInverse(tmp2, tmp3, tmp4, q);

    // Q_R = F_R * G_R^{-1} mod x^q; reversing its q coefficients gives Q.
    NTT::getplace(p, k);
    NTT::dft(tmp, p);
    NTT::dft(tmp3, p);
    for(int i=0; i<p; i++){
        resA[i] = tmp[i] * tmp3[i] % MOD;
    }
    NTT::idft(resA, p);
    reverse(resA, resA+q);
    fill(resA+q, resA+p, 0);

    // R = F - Q * G. tmp still holds the transformed F_R, so it must be cleared
    // completely (not only past n) before G is copied in.
    fill(tmp, tmp+p, 0);
    copy(b, b+m, tmp);
    fill(tmp2, tmp2+p, 0);
    copy(resA, resA+q, tmp2);

    NTT::getplace(p, k);
    NTT::dft(tmp, p);
    NTT::dft(tmp2, p);
    for(int i=0; i<p; i++)
        tmp[i] = tmp[i] * tmp2[i] % MOD;
    NTT::idft(tmp, p);

    // Only the low m-1 coefficients of F - Q*G can be non-zero.
    for(int i=0; i<m-1; i++){
        resB[i] = (a[i] - tmp[i] + MOD) % MOD;
    }
}

多项式求导与积分

$$
\frac{\mathrm{d}}{\mathrm{d}x} x^a = a x^{a-1}, \qquad \int x^a \,\mathrm{d}x = \frac{x^{a+1}}{a+1}
$$

// Differentiates the polynomial a[0..n) in place: coefficient i becomes
// (i + 1) * a[i + 1]. The top slot a[n-1] is set to 0, so the result still
// occupies n slots.
inline void getDerivation(long long *a, int n){
    if(n <= 0) return;   // empty polynomial: avoid writing a[-1]
    for(int i=0; i<n-1; i++){
        a[i] = a[i+1] * (i + 1) % MOD;
    }
    a[n-1] = 0;
}

// Integrates the polynomial a[0..n) in place modulo x^n: coefficient i becomes
// a[i-1] / i and the constant term becomes 0. The old top coefficient a[n-1]
// would land on x^n and is dropped.
inline void getQuadrature(long long *a, int n){
    if(n <= 0) return;   // empty polynomial: nothing to write
    // Walk downwards so each a[i-1] is read before it is overwritten.
    for(int i=n-1; i>=1; i--){
        a[i] = a[i-1] * inv(i) % MOD;
    }
    a[0] = 0;
}

多项式 ln

求多项式 $g(x)$ 满足
$$
g(x) = \ln f(x)
$$

多项式的自然对数定义为在一个多项式和麦克劳林级数(函数在自变量零点求得的泰勒级数)的复合,常数项必须为 $1$

$$
\ln(1-f(x)) = -\sum_{i=1}^{\infty} \frac{f(x)^i}{i}
$$

求导可得

$$
(\ln f(x))' = \frac{f'(x)}{f(x)}
$$

所以

$$
\ln f(x) = \int \frac{f'(x)}{f(x)} \mathrm{d} x
$$

复杂度 $O(n \log n)$

// Computes res = ln(a) mod x^n as the integral of a'(x) / a(x).
// Preconditions: a[0] == 1 (the logarithm is only defined for that constant
// term); res and tmp have room for p entries, where p is the smallest power of
// two >= 2n; NTT::init was called with a length >= p. tmp is scratch space.
// On return res[0..n) holds ln(a) (constant term 0) and res[n..p) is zero.
inline void getLn(long long *a, long long *res, long long *tmp, int n){
    if(n <= 0) return;

    int p = 1, k = 0;
    while(p < (n << 1)) p <<= 1;
    while((1 << k) < p) k++;

    // getInverse transforms res over p entries, so it must start out clean.
    fill(res, res+p, 0);
    getInverse(a, res, tmp, n);      // res = 1 / a  (mod x^n)

    copy(a, a+n, tmp);
    fill(tmp+n, tmp+p, 0);
    getDerivation(tmp, n);           // tmp = a'

    NTT::getplace(p, k);
    NTT::dft(tmp, p);
    NTT::dft(res, p);

    for(int i=0; i<p; i++){
        res[i] = tmp[i] * res[i] % MOD;
    }

    NTT::idft(res, p);               // res = a' / a

    getQuadrature(res, n);           // integrate; terms from x^n upward are discarded
    fill(res+n, res+p, 0);
}

多项式 exp

求多项式 $g(x)$ 满足
$$
g(x) = \mathrm{e}^{f(x)}
$$

因为

$$
\ln g(x) = f(x)
$$

所以

$$
\ln g(x) - f(x) \equiv 0 \pmod {x^n}
$$

使用牛顿迭代:设 $g_0(x)$ 为模 $x^{\lceil \frac{n}{2} \rceil}$ 意义下已求得的解,令 $H(g) = \ln g - f$,则 $H'(g) = \frac{1}{g}$,于是

$$
\begin{aligned}
g(x) & \equiv g_0(x) - \frac{\ln g_0(x) - f(x)}{\frac{1}{g_0(x)}} \\
& \equiv g_0(x) - g_0(x) \ln g_0(x) + f(x)g_0(x) \pmod {x^n}
\end{aligned}
$$

整理得

$$
g(x) \equiv g_0(x)\left(f(x) - \ln g_0(x) + 1\right) \pmod {x^n}
$$

复杂度

$$
T(n) = T(\frac{n}{2}) + O(n \log n) = O(n \log n)
$$

// Computes res = exp(a) mod x^n by Newton doubling:
//   g = g0 * (1 + a - ln g0), where g0 = exp(a) mod x^{ceil(n/2)}.
// Preconditions: a[0] == 0; res, tmp and tmp2 have room for p entries, where p
// is the smallest power of two >= 2n; NTT::init was called with a length >= p.
// tmp and tmp2 are scratch space.
// On return res[0..n) holds exp(a); for n > 1, res[n..p) is zero.
inline void getExp(long long *a, long long *res, long long *tmp, long long *tmp2, int n){
    if(n <= 0) return;
    if(n == 1){
        res[0] = 1;
        return;
    }

    const int half = (n + 1) >> 1;
    getExp(a, res, tmp, tmp2, half);

    int p = 1, k = 0;
    while(p < (n << 1)) p <<= 1;
    while((1 << k) < p) k++;

    // Only res[0..half) is meaningful. The base case writes res[0] alone and
    // deeper levels clear only up to their own transform length, so clear the
    // rest before res is read by getLn and transformed below.
    fill(res + half, res + p, 0);

    fill(tmp, tmp+p, 0);
    getLn(res, tmp, tmp2, n);        // tmp = ln g0

    // tmp2 = 1 + a - ln g0
    for(int i=0; i<n; i++){
        tmp2[i] = a[i] - tmp[i] % MOD;
        if(tmp2[i] < 0)
            tmp2[i] += MOD;
    }
    tmp2[0] = (tmp2[0] + 1) % MOD;   // keep the coefficient reduced

    fill(tmp2+n, tmp2+p, 0);

    NTT::getplace(p, k);
    NTT::dft(res, p);
    NTT::dft(tmp2, p);

    for(int i=0; i<p; i++){
        res[i] = res[i] * tmp2[i] % MOD;
    }

    NTT::idft(res, p);
    fill(res+n, res+p, 0);
}

多项式开根

求多项式 $g(x)$ 满足

$$
g(x)^2 = f(x)
$$

假设已经求得多项式 $g'(x)$ 满足

$$
g'(x)^2 \equiv f(x) \pmod {x^{\lceil \frac{n}{2} \rceil}}
$$

那么存在

$$
g(x)^2 - g'(x)^2 \equiv 0 \pmod {x^{\lceil \frac{n}{2} \rceil}}
$$

平方差公式展开

$$
(g(x) - g'(x)) (g(x) + g'(x)) \equiv 0 \pmod {x^{\lceil \frac{n}{2} \rceil}}
$$

考虑其中一种

$$
g(x) - g'(x) \equiv 0 \pmod {x^{\lceil \frac{n}{2} \rceil}}
$$

两边平方

$$
\begin{aligned}
g(x)^2 + g'(x)^2 - 2g(x)g'(x) \equiv 0 & \pmod {x^n} \\
f(x) + g'(x)^2 - 2g(x)g'(x) \equiv 0 & \pmod {x^n}
\end{aligned}
$$

那么

$$
g(x) \equiv \frac{g'(x)^2 + f(x)}{2g'(x)} \equiv \frac{1}{2}\left(g'(x) + \frac{f(x)}{g'(x)}\right) \pmod {x^n}
$$

复杂度
$$
T(n) = T(\frac{n}{2}) + O(n \log n) = O(n \log n)
$$

inline void getSqrt(long long *a, long long *res, long long *tmp, long long *tmp2, int n){
    if(n == 1){
        res[0] = sqrt(a[0]);
        return;
    } 

    getSqrt(a, res, tmp, tmp2, (n + 1) >> 1);

    int p = 1, k = 0;
    while(p < (n << 1)) p <<= 1;
    while((1 << k) < p) k++;

    fill(tmp, tmp+p, 0);
    getInverse(res, tmp, tmp2, n);

    copy(a, a+n, tmp2);
    fill(tmp2+n, tmp2+p, 0);

    NTT::getplace(p, k);
    NTT::dft(tmp, p);
    NTT::dft(tmp2, p);

    for(register int i=0; i<p; i++){
        tmp[i] = (tmp2[i] * tmp[i]) % MOD; 
    }

    NTT::idft(tmp, p);

    long long _inv = inv(2);

    for(register int i=0; i<n; i++){
        res[i] = (tmp[i] + res[i]) % MOD * _inv % MOD;
    }

    fill(res+n, res+p, 0);
}

多项式求幂

求多项式 $g(x)$ 满足
$$
g(x) = f(x)^k
$$

因为

$$
\ln g(x) = \ln f(x)^k = k \ln f(x)
$$

所以

$$
g(x) = \mathrm{e}^{k \ln f(x)}
$$

复杂度 $O(n \log n)$

// Computes res = a^k mod x^n as exp(k * ln a).
// Preconditions: a[0] == 1 (required by ln); res, tmp, tmp2 and tmp3 have room
// for p entries, where p is the smallest power of two >= 2n; NTT::init was
// called with a length >= p. tmp, tmp2 and tmp3 are scratch space.
// k may be any value; only k mod MOD affects the result.
inline void getPow(long long *a, long long *res, long long *tmp, long long *tmp2, long long *tmp3, int n, long long k){
    if(n <= 0) return;

    int p = 1;
    while(p < (n << 1)) p <<= 1;

    // Reduce the exponent first so tmp[i] * e neither overflows nor goes negative.
    const long long e = (k % MOD + MOD) % MOD;

    fill(tmp, tmp+p, 0);
    getLn(a, tmp, tmp2, n);          // tmp = ln a

    for(int i=0; i<n; i++){
        tmp3[i] = tmp[i] * e % MOD;
    }

    // getExp transforms res over p entries, so start from a clean buffer.
    fill(res, res+p, 0);
    getExp(tmp3, res, tmp, tmp2, n);
}

模板

// http://cogs.pro:8080/cogs/problem/problem.php?pid=2189
// COGS 2189 帕秋莉的超级多项式

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int MOD = 119 * (1 << 23) + 1;
// Capacity (2^20) of every coefficient / twiddle / permutation array.
const int MAXN = 1 << 20;

// Returns a^b mod MOD by square-and-multiply. b must be non-negative.
inline long long powx(long long a, long long b){
    long long result = 1;
    while(b != 0){
        if(b & 1)
            result = result * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return result;
}

// Extended Euclidean algorithm: returns g = gcd(a, b) and stores x, y with
// a*x + b*y == g. Written iteratively; it multiplies the same sequence of
// quotient matrices as the textbook recursion, so it yields the same x and y.
inline long long exgcd(long long a, long long b, long long &x, long long &y){
    long long r0 = a, r1 = b;   // remainder sequence
    long long s0 = 1, s1 = 0;   // coefficients of a
    long long t0 = 0, t1 = 1;   // coefficients of b
    while(r1 != 0){
        const long long q = r0 / r1;
        long long next = r0 - q * r1;
        r0 = r1; r1 = next;
        next = s0 - q * s1;
        s0 = s1; s1 = next;
        next = t0 - q * t1;
        t0 = t1; t1 = next;
    }
    x = s0;
    y = t0;
    return r0;
}

// Modular inverse of a modulo MOD via the extended Euclidean algorithm,
// normalised into [0, MOD). Meaningful only when gcd(a, MOD) == 1.
inline long long inv(long long a){
    long long coefA, coefMod;
    exgcd(a, MOD, coefA, coefMod);
    const long long r = coefA % MOD;
    return r < 0 ? r + MOD : r;
}

namespace NTT{
    // Twiddle tables for length `tot` (filled by init): omega[i] = w^i and
    // omegaInverse[i] = w^{-i}, with w = 3^((MOD-1)/tot).
    long long omega[MAXN], omegaInverse[MAXN];
    // place: bit-reversal permutation for the length last passed to getplace.
    // tot: the length init was called with; every transform length must divide it.
    int place[MAXN], tot;

    // Fills place[0..n) with the k-bit reversal of each index (n == 2^k, k >= 1).
    inline void getplace(int n, int k){
        for(int i=0; i<n; i++)
            place[i] = place[i >> 1] >> 1 | (i & 1) << (k - 1);
    }

    // Precomputes the twiddle tables for transforms of length n.
    // n must be a power of two dividing MOD - 1 and must not exceed MAXN.
    inline void init(int n){
        long long g = 3;                              // primitive root of 998244353
        long long w = powx(g, (MOD - 1) / n);         // primitive n-th root of unity
        long long wInv = inv(w);
        tot = n;
        for(int i=0; i<n; i++){
            omega[i] = (i == 0 ? 1 : omega[i-1] * w % MOD);
            omegaInverse[i] = (i == 0 ? 1 : omegaInverse[i-1] * wInv % MOD);
        }
    }

    // In-place iterative radix-2 transform of data[0..n) with the given twiddle
    // table. Requires getplace(n, log2 n) to have been called and n to divide tot.
    inline void transform(long long *data, int n, long long *omega){
        for(int i=0; i<n; i++){
            if(i < place[i])
                swap(data[i], data[place[i]]);
        }
        for(int range=2; range <= n; range <<= 1){
            int mid = range >> 1;
            for(long long *a = data; a != data + n; a += range){
                for(int i=0; i<mid; i++){
                    // omega[tot / range * i] is the i-th power of a primitive range-th root.
                    long long tmp = a[mid + i] * omega[tot / range * i] % MOD;
                    a[mid + i] = (a[i] - tmp + MOD) % MOD;
                    a[i] = (a[i] + tmp) % MOD;
                }
            }
        }
    }

    // Forward transform: coefficients -> values at the n-th roots of unity.
    inline void dft(long long *data, int n){
        transform(data, n, omega);
    }

    // Inverse transform: values -> coefficients, including the 1/n scaling.
    inline void idft(long long *data, int n){
        transform(data, n, omegaInverse);
        long long tmp = inv(n);
        for(int i=0; i<n; i++){
            data[i] = data[i] * tmp % MOD;
        }
    }
}

// Computes res = a^{-1} mod x^n by Newton doubling: g = g' * (2 - a * g'),
// where g' is the inverse modulo x^{ceil(n/2)}.
// Preconditions: a[0] != 0 (mod MOD); res and tmp have room for p entries,
// where p is the smallest power of two >= 2n; NTT::init was called with a
// length >= p. tmp is scratch space.
// On return res[0..n) holds the inverse and res[n..p) is zero.
inline void getInverse(long long *a, long long *res, long long *tmp, int n){
    if(n <= 0) return;          // empty series: nothing to compute (and no endless recursion)
    if(n == 1){
        res[0] = inv(a[0]);
        return;
    }
    const int half = (n + 1) >> 1;
    getInverse(a, res, tmp, half);

    int p = 1, k = 0;
    while(p < (n << 1)) p <<= 1;
    while((1 << k) < p) k++;

    // Only res[0..half) is meaningful. The recursive call cleared at most up to
    // its own (possibly smaller) transform length, so clear the rest here; the
    // cyclic convolution below would otherwise pick up stale values.
    fill(res + half, res + p, 0);

    copy(a, a+n, tmp);
    fill(tmp+n, tmp+p, 0);

    NTT::getplace(p, k);
    NTT::dft(tmp, p);
    NTT::dft(res, p);

    // Point-wise g' * (2 - a * g'); the product has degree < 2n <= p, so no wrap-around.
    for(int i=0; i<p; i++){
        res[i] = res[i] * (2 - tmp[i] * res[i] % MOD) % MOD;
        if(res[i] < 0) res[i] += MOD;
    }

    NTT::idft(res, p);
    fill(res+n, res+p, 0);
}

// Differentiates the polynomial a[0..n) in place: coefficient i becomes
// (i + 1) * a[i + 1]. The top slot a[n-1] is set to 0, so the result still
// occupies n slots.
inline void getDerivation(long long *a, int n){
    if(n <= 0) return;   // empty polynomial: avoid writing a[-1]
    for(int i=0; i<n-1; i++){
        a[i] = a[i+1] * (i + 1) % MOD;
    }
    a[n-1] = 0;
}

// Integrates the polynomial a[0..n) in place modulo x^n: coefficient i becomes
// a[i-1] / i and the constant term becomes 0. The old top coefficient a[n-1]
// would land on x^n and is dropped.
inline void getQuadrature(long long *a, int n){
    if(n <= 0) return;   // empty polynomial: nothing to write
    // Walk downwards so each a[i-1] is read before it is overwritten.
    for(int i=n-1; i>=1; i--){
        a[i] = a[i-1] * inv(i) % MOD;
    }
    a[0] = 0;
}

// Computes res = ln(a) mod x^n as the integral of a'(x) / a(x).
// Preconditions: a[0] == 1 (the logarithm is only defined for that constant
// term); res and tmp have room for p entries, where p is the smallest power of
// two >= 2n; NTT::init was called with a length >= p. tmp is scratch space.
// On return res[0..n) holds ln(a) (constant term 0) and res[n..p) is zero.
inline void getLn(long long *a, long long *res, long long *tmp, int n){
    if(n <= 0) return;

    int p = 1, k = 0;
    while(p < (n << 1)) p <<= 1;
    while((1 << k) < p) k++;

    // getInverse transforms res over p entries, so it must start out clean.
    fill(res, res+p, 0);
    getInverse(a, res, tmp, n);      // res = 1 / a  (mod x^n)

    copy(a, a+n, tmp);
    fill(tmp+n, tmp+p, 0);
    getDerivation(tmp, n);           // tmp = a'

    NTT::getplace(p, k);
    NTT::dft(tmp, p);
    NTT::dft(res, p);

    for(int i=0; i<p; i++){
        res[i] = tmp[i] * res[i] % MOD;
    }

    NTT::idft(res, p);               // res = a' / a

    getQuadrature(res, n);           // integrate; terms from x^n upward are discarded
    fill(res+n, res+p, 0);
}

// Computes res = exp(a) mod x^n by Newton doubling:
//   g = g0 * (1 + a - ln g0), where g0 = exp(a) mod x^{ceil(n/2)}.
// Preconditions: a[0] == 0; res, tmp and tmp2 have room for p entries, where p
// is the smallest power of two >= 2n; NTT::init was called with a length >= p.
// tmp and tmp2 are scratch space.
// On return res[0..n) holds exp(a); for n > 1, res[n..p) is zero.
inline void getExp(long long *a, long long *res, long long *tmp, long long *tmp2, int n){
    if(n <= 0) return;
    if(n == 1){
        res[0] = 1;
        return;
    }

    const int half = (n + 1) >> 1;
    getExp(a, res, tmp, tmp2, half);

    int p = 1, k = 0;
    while(p < (n << 1)) p <<= 1;
    while((1 << k) < p) k++;

    // Only res[0..half) is meaningful. The base case writes res[0] alone and
    // deeper levels clear only up to their own transform length, so clear the
    // rest before res is read by getLn and transformed below.
    fill(res + half, res + p, 0);

    fill(tmp, tmp+p, 0);
    getLn(res, tmp, tmp2, n);        // tmp = ln g0

    // tmp2 = 1 + a - ln g0
    for(int i=0; i<n; i++){
        tmp2[i] = a[i] - tmp[i] % MOD;
        if(tmp2[i] < 0)
            tmp2[i] += MOD;
    }
    tmp2[0] = (tmp2[0] + 1) % MOD;   // keep the coefficient reduced

    fill(tmp2+n, tmp2+p, 0);

    NTT::getplace(p, k);
    NTT::dft(res, p);
    NTT::dft(tmp2, p);

    for(int i=0; i<p; i++){
        res[i] = res[i] * tmp2[i] % MOD;
    }

    NTT::idft(res, p);
    fill(res+n, res+p, 0);
}

inline void getSqrt(long long *a, long long *res, long long *tmp, long long *tmp2, int n){
    if(n == 1){
        res[0] = sqrt(a[0]);
        return;
    } 

    getSqrt(a, res, tmp, tmp2, (n + 1) >> 1);

    int p = 1, k = 0;
    while(p < (n << 1)) p <<= 1;
    while((1 << k) < p) k++;

    fill(tmp, tmp+p, 0);
    getInverse(res, tmp, tmp2, n);

    copy(a, a+n, tmp2);
    fill(tmp2+n, tmp2+p, 0);

    NTT::getplace(p, k);
    NTT::dft(tmp, p);
    NTT::dft(tmp2, p);

    for(register int i=0; i<p; i++){
        tmp[i] = (tmp2[i] * tmp[i]) % MOD; 
    }

    NTT::idft(tmp, p);

    long long _inv = inv(2);

    for(register int i=0; i<n; i++){
        res[i] = (tmp[i] + res[i]) % MOD * _inv % MOD;
    }

    fill(res+n, res+p, 0);
}

// Computes res = a^k mod x^n as exp(k * ln a).
// Preconditions: a[0] == 1 (required by ln); res, tmp, tmp2 and tmp3 have room
// for p entries, where p is the smallest power of two >= 2n; NTT::init was
// called with a length >= p. tmp, tmp2 and tmp3 are scratch space.
// k may be any value; only k mod MOD affects the result.
inline void getPow(long long *a, long long *res, long long *tmp, long long *tmp2, long long *tmp3, int n, long long k){
    if(n <= 0) return;

    int p = 1;
    while(p < (n << 1)) p <<= 1;

    // Reduce the exponent first so tmp[i] * e neither overflows nor goes negative.
    const long long e = (k % MOD + MOD) % MOD;

    fill(tmp, tmp+p, 0);
    getLn(a, tmp, tmp2, n);          // tmp = ln a

    for(int i=0; i<n; i++){
        tmp3[i] = tmp[i] * e % MOD;
    }

    // getExp transforms res over p entries, so start from a clean buffer.
    fill(res, res+p, 0);
    getExp(tmp3, res, tmp, tmp2, n);
}

// Reads the next unsigned decimal integer from stdin. Any non-digit characters
// before it are skipped (a leading '-' is therefore ignored, not applied).
inline int read(){
    char c;
    do{
        c = getchar();
    }while(c < '0' || c > '9');

    int value = 0;
    for(; '0' <= c && c <= '9'; c = getchar())
        value = value*10 + c - '0';

    return value;
}

// Input size (number of coefficients) and the exponent used by getPow.
int n;
long long k;
// Working polynomials and scratch buffers, MAXN coefficients each; as globals
// they start zero-initialised.
long long a[MAXN], b[MAXN];
long long tmp[MAXN], tmp2[MAXN], tmp3[MAXN];

// Reads n, k and the n coefficients of A(x) from polynomial.in, applies the
// chain of operations below (all modulo x^n and MOD), and writes the n
// resulting coefficients, space separated, to polynomial.out.
int main(){
    freopen("polynomial.in","r",stdin);
    freopen("polynomial.out","w",stdout);

    n = read();
    k = read();

    for(int i=0; i<n; i++){
        a[i] = read();
    }

    // A single transform length (smallest power of two >= 2n) covers every step.
    int p = 1;

    while(p < (n << 1))
        p <<= 1;

    NTT::init(p);

    // b <- sqrt(a)
    getSqrt(a, b, tmp, tmp2, n);

    // b <- 1 / (previous result)
    copy(b, b+n, a);
    fill(b, b+p, 0);
    getInverse(a, b, tmp, n);

    // a <- integral of the previous result (constant term 0)
    copy(b, b+n, a);
    fill(b, b+p, 0);
    getQuadrature(a, n);

    // b <- exp(a)
    getExp(a, b, tmp, tmp2, n);

    // b <- 1 / (previous result), then + 1
    copy(b, b+n, a);
    fill(b, b+p, 0);
    getInverse(a, b, tmp, n);

    ++b[0];
    // b <- ln(previous result), then + 1.
    // NOTE(review): b[0] is 2 at this point while getLn assumes a constant term
    // of 1; it still returns the integral of f'/f — confirm against the problem.
    copy(b, b+n, a);
    fill(b, b+p, 0);
    getLn(a, b, tmp, n);

    ++b[0];
    // b <- (previous result)^k
    copy(b, b+n, a);
    fill(b, b+p, 0);
    getPow(a, b, tmp, tmp2, tmp3, n, k);

    // a <- derivative of the previous result
    copy(b, b+n, a);
    fill(b, b+p, 0);
    getDerivation(a, n);

    for(int i=0; i<n; i++){
        printf("%lld ", a[i]);
    }
    return 0;
}