[模板] 高精度乘法

Luogu – 1919 高精度乘法

题目描述

给出两个 $n$ 位 $10$ 进制整数 $x$ 和 $y$,你需要计算 $x \times y $

输入格式

第一行一个整数 $n$。第二行描述一个位数为 $n$ 的正整数 $x$,第三行描述一个位数为 $n$的正整数 $y$

输出格式

输出一行,即 $x \times y$ 的结果。(注意判断前导 $0$)

样例输入1

1
3
4

样例输出1

12

数据范围及提示

$ n \leq 60000 $

解题思路

本题是一道 $FFT$ 模板题,一个 $n$ 位的十进制整数,可以看做一个 $n-1$ 次的多项式 $F(x)$

$$F(x) = a_0 + a_1 \times 10 + a_2 \times 10^2 + \cdots + a_{n-1} \times 10^{n-1} $$

然后就可以直接使用 $FFT$ 进行高精度运算了。

不过需要注意,次数低的应该放在数组的前面,因此需要倒序读入,在输出答案的时候要注意进位,并且去除前导 $0$

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 131072 + 10;
const int SIZE = 60000 + 10;
const double PI = acos(-1.0);

struct complex{
    double x, y;
    complex(double _x = 0, double _y = 0){x = _x, y = _y;}
};

complex operator + (complex a, complex b){ return complex(a.x+b.x, a.y+b.y); }
complex operator - (complex a, complex b){ return complex(a.x-b.x, a.y-b.y); }
complex operator * (complex a, complex b){ return complex(a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x); }

char s1[SIZE], s2[SIZE];

complex a[MAXN], b[MAXN];
complex omega[MAXN], omegaInverse[MAXN];

int n;
int place[MAXN];
int N = 1, k;

inline void init(int N){
    for(register int i=0; i<N; i++)
        for(register int j=0; j<k; j++)
            if(i & (1 << j))
                place[i] |= 1 << (k - j - 1);


    const double single = 2 * PI / N;

    for(register int i=0; i<N; i++){
        omega[i] = complex(cos(single * i), sin(single * i));
        omegaInverse[i] = complex(omega[i].x, -omega[i].y);
    }
}

inline void transform(complex *a, complex *omega){
    for(register int i=0; i<N; i++){
        if(i < place[i])
            swap(a[i], a[place[i]]);
    }

    for(register int range=2; range <= N; range <<= 1){
        int mid = range >> 1;
        int k = N / range;

        for(register complex *p = a; p != a + N; p += range){
            for(register int i=0; i<mid; i++){
                int tmp = i + mid;

                complex t = omega[k*i] * p[tmp];
                p[tmp] = p[i] - t;
                p[i] = p[i] + t; 
            }
        }
    }
}

int ans[MAXN];

int main(){

    scanf("%d", &n);
    scanf("%s%s", s1, s2);

    int a_cnt = 0;
    int b_cnt = 0;

    for(register int i=n-1; i>=0; i--)
        a[i].x = s1[a_cnt++] - '0';

    for(register int i=n-1; i>=0; i--)
        b[i].x = s2[b_cnt++] - '0';



    while(N < n + n){
        N <<= 1;
        k++;
    }

    init(N);

    transform(a, omega);
    transform(b, omega);

    for(register int i=0; i<N; i++)
        a[i] = a[i] * b[i];

    transform(a, omegaInverse);

    for(register int i=0; i<N; i++){
        ans[i] += floor(a[i].x/N + 0.5);

        if(ans[i] >= 10){
            ans[i+1] += ans[i]/10;
            ans[i] %= 10;

            if(i==N-1){
                N++;
            }
        }
    }

    while(!ans[N] && N > 1){
        N--;
    }

    for(register int i=N; i>=0; i--)
        printf("%d", ans[i]);

    return 0;
}