隐藏
「bsoj4125」FFT快速傅立叶 - FFT/NTT模板题 | Bill Yang's Blog

路终会有尽头,但视野总能看到更远的地方。

0%

「bsoj4125」FFT快速傅立叶 - FFT/NTT模板题

题目大意

    给出两个$n$位$10$进制整数$x$和$y$,你需要计算$x\times y$。


题目分析

FFT/NTT模板题。
注意是计算高精度数,需要排除高位$0$。

学习笔记见下面两个链接:


代码

FFT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include<algorithm>
#include<iostream>
#include<iomanip>
#include<cstring>
#include<complex>
#include<cstdlib>
#include<climits>
#include<complex>
#include<vector>
#include<cstdio>
#include<cmath>
#include<queue>
using namespace std;

#define cp complex<double>

inline const int Get_Int() {
int num=0,bj=1;
char x=getchar();
while(x<'0'||x>'9') {
if(x=='-')bj=-1;
x=getchar();
}
while(x>='0'&&x<='9') {
num=num*10+x-'0';
x=getchar();
}
return num*bj;
}

const int maxn=131072+5;
const double pi=acos(-1);

struct FastFourierTransform {
int n;
cp omega[maxn],iomega[maxn];

void init(int n) {
this->n=n;
for(int i=0; i<n; i++) {
omega[i]=cp(cos(2*pi/n*i),sin(2*pi/n*i));
iomega[i]=conj(omega[i]); //共轭复数
}
}

void transform(cp* a,cp* omega) {
int k=log2(n);
for(int i=0; i<n; i++) { //二进制位翻转
int t=0;
for(int j=0; j<k; j++)if(i&(1<<j))t|=(1<<(k-j-1));
if(i<t)swap(a[i],a[t]); //避免多次翻转
}
for(int len=2; len<=n; len*=2) {
int mid=len>>1;
for(cp* p=a; p!=a+n; p+=len)
for(int i=0; i<mid; i++) {
cp t=omega[n/len*i]*p[mid+i];
p[mid+i]=p[i]-t;
p[i]+=t;
}
}
}

void dft(cp* a) {
transform(a,omega);
}

void idft(cp* a) {
transform(a,iomega);
for(int i=0; i<n; i++)a[i]/=n;
}
} fft;

void Multiply(const int* a1,const int n1,const int* a2,const int n2,int* ans) {
int n=1;
while(n<n1+n2)n<<=1; //补成整数
static cp c1[maxn],c2[maxn];
for(int i=0; i<n1; i++)c1[i].real(a1[i]);
for(int i=0; i<n2; i++)c2[i].real(a2[i]);
fft.init(n);
fft.dft(c1);
fft.dft(c2);
for(int i=0; i<n; i++)c1[i]*=c2[i];
fft.idft(c1);
for(int i=0; i<n1+n2-1; i++)ans[i]=round(c1[i].real());
}

int n,a1[maxn],a2[maxn],ans[maxn];
char s1[maxn],s2[maxn];

int main() {
n=Get_Int();
scanf("%s",s1);
scanf("%s",s2);
for(int i=0; i<n; i++)a1[n-i-1]=s1[i]-'0';
for(int i=0; i<n; i++)a2[n-i-1]=s2[i]-'0';
Multiply(a1,n,a2,n,ans);
for(int i=0; i<2*n-1; i++) {
ans[i+1]+=ans[i]/10;
ans[i]%=10;
}
int len=2*n-1;
while(!ans[len])len--;
for(int i=len; i>=0; i--)printf("%d",ans[i]);
return 0;
}

NTT

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include<algorithm>
#include<iostream>
#include<iomanip>
#include<cstring>
#include<complex>
#include<cstdlib>
#include<climits>
#include<vector>
#include<cstdio>
#include<cmath>
#include<queue>
using namespace std;

typedef long long LL;

inline const int Get_Int() {
int num=0,bj=1;
char x=getchar();
while(x<'0'||x>'9') {
if(x=='-')bj=-1;
x=getchar();
}
while(x>='0'&&x<='9') {
num=num*10+x-'0';
x=getchar();
}
return num*bj;
}

const int maxn=131072+5;
const LL mod=998244353;
LL g=3;

LL Quick_Pow(LL a,LL b) {
LL sum=1;
for(; b; b>>=1,a=a*a%mod)if(b&1)sum=sum*a%mod;
return sum;
}

LL inv(LL x) {
return Quick_Pow(x,mod-2);
}

vector<int> ppt;

void find_root(LL x) {
LL tmp=x-1;
for(int i=2; i<=sqrt(tmp); i++)
if(tmp%i==0) {
ppt.push_back(i);
while(tmp%i==0)tmp/=i;
}
for(g=2; g<x; g++) {
bool bj=1;
for(int t:ppt)
if(Quick_Pow(g,(x-1)/t)==1) {
bj=0;
break;
}
if(bj)return;
}
}

struct NumberTheoreticTransform {
int n;
LL omega[maxn],iomega[maxn];

void init(int n) {
this->n=n;
int x=Quick_Pow(g,(mod-1)/n);
omega[0]=iomega[0]=1;
for(int i=1; i<n; i++) {
omega[i]=omega[i-1]*x%mod;
iomega[i]=inv(omega[i]);
}
}

void transform(LL* a,LL* omega) {
int k=log2(n);
for(int i=0; i<n; i++) { //二进制位翻转
int t=0;
for(int j=0; j<k; j++)if(i&(1<<j))t|=(1<<(k-j-1));
if(i<t)swap(a[i],a[t]); //避免多次翻转
}
for(int len=2; len<=n; len*=2) {
int mid=len>>1;
for(LL* p=a; p!=a+n; p+=len)
for(int i=0; i<mid; i++) {
LL t=omega[n/len*i]*p[mid+i]%mod;
p[mid+i]=(p[i]-t+mod)%mod;
p[i]=(p[i]+t)%mod;
}
}
}

void dft(LL* a) {
transform(a,omega);
}

void idft(LL* a) {
transform(a,iomega);
LL x=inv(n);
for(int i=0; i<n; i++)a[i]=a[i]*x%mod;
}
} ntt;

void Multiply(LL* a1,const int n1,LL* a2,const int n2,LL* ans) {
int n=1;
while(n<n1+n2)n<<=1; //补成整数
ntt.init(n);
ntt.dft(a1);
ntt.dft(a2);
for(int i=0; i<n; i++)a1[i]=a1[i]*a2[i]%mod;
ntt.idft(a1);
for(int i=0; i<n1+n2-1; i++)ans[i]=a1[i];
}

int n;
LL a1[maxn],a2[maxn],ans[maxn];
char s1[maxn],s2[maxn];

int main() {
// find_root(mod);
n=Get_Int();
scanf("%s",s1);
scanf("%s",s2);
for(int i=0; i<n; i++)a1[n-i-1]=s1[i]-'0';
for(int i=0; i<n; i++)a2[n-i-1]=s2[i]-'0';
Multiply(a1,n,a2,n,ans);
for(int i=0; i<2*n-1; i++) {
ans[i+1]+=ans[i]/10;
ans[i]%=10;
}
int len=2*n-1;
while(!ans[len])len--;
for(int i=len; i>=0; i--)printf("%lld",ans[i]);
return 0;
}

姥爷们赏瓶冰阔落吧~