[板子/笔记] FFT/NTT/MTT

两个月之后来补blog真的是石乐志

FFT

在OI中有需要快速计算多项式乘法的问题,需要用到FFT(/NTT/MTT)这类的东西。

首先有两个 $n$ 项(即 $n - 1$ 次)多项式 $A(x) = \sum_{i = 0}^{n - 1} a_i x^i$,$B(x) = \sum_{i = 0}^{n - 1} b_i x^i$。
$C(x) = A(x) B(x) = \sum_{k = 0}^{2n - 2} \left( \sum_{i = 0}^{k} a_i b_{k - i} \right) x^k$(约定 $i \ge n$ 时 $a_i = b_i = 0$)。
直接乘$\Theta(n ^ 2)$不是很资瓷。
我们想到用点值表示:$n + 1$ 个横坐标互不相同的点 $(x_i, A(x_i))$ 可以唯一确定一个 $n$ 次多项式 $A(x)$。
若 $A(x)$、$B(x)$ 都是 $n$ 次多项式,则乘积是 $2n$ 次的。于是找 $2n + 1$ 个点代入 $A(x)$ 和 $B(x)$,把对应的点值相乘,再转化回系数表示就资瓷了。

随意找点吗?不行。我们要找特殊点优化方法。
用 $n$ 次单位根 $e^{\frac{2\pi \mathrm{i}}{n}}$ 的各次幂作为这些点代入,就是FFT的核心。
记 $\omega_n = e^{\frac{2\pi \mathrm{i}}{n}}$,于是 $\omega_n^{k} = e^{\frac{2\pi \mathrm{i} k}{n}}$(这里 $\mathrm{i}$ 是虚数单位,下文的 $i$、$k$ 是整数下标)。
单位根有两个性质:

$ {\omega}_{n}^{i} = {\omega}_{2n}^{2i} $
$ {\omega}_{n}^{i + {\frac{n}{2} } } = -{\omega}_{n}^{i} $

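然后按下标奇偶,我们可以把 $A(x)$ 拆成 $A(x) = A_0(x^2) + x A_1(x^2)$,其中 $A_0(x) = a_0 + a_2 x + a_4 x^2 + \cdots$,$A_1(x) = a_1 + a_3 x + a_5 x^2 + \cdots$。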

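而利用单位根的两个性质,对 $0 \le k < \frac{n}{2}$ 有:$A(\omega_n^{k}) = A_0(\omega_{n/2}^{k}) + \omega_n^{k} A_1(\omega_{n/2}^{k})$,$A(\omega_n^{k + n/2}) = A_0(\omega_{n/2}^{k}) - \omega_n^{k} A_1(\omega_{n/2}^{k})$。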

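于是只要递归求出 $A_0$、$A_1$ 在 $\frac{n}{2}$ 个点上的值,就能合并出 $A$ 在 $n$ 个点上的值。这样分治解决问题,复杂度 $T(n) = 2T(\frac{n}{2}) + O(n) = O(n \log n)$。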
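逆变换就是把单位根换成单位根的倒数(即用 $\omega_n^{-1}$),做一遍FFT再把每项除以 $n$。严格证明懒得写qwq,关键在于 $\sum_{j = 0}^{n - 1} \omega_n^{jk}$ 在 $n \mid k$ 时等于 $n$,否则等于 $0$。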

Code :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include <cstdio>
#include <cstring>
#include <vector>
#include <complex>
using namespace std;
// pi as a double, obtained from acos(-1).
const double pi = acos(-1.0);
// Operand polynomials; main() fills them with coefficients, transforms them
// in place, and finally reads the product back out of `a`.
vector<complex<double> > a;
vector<complex<double> > b;
// n, m: degrees read by main() (later reused as the FFT length and the
// product degree); x: scratch variable for each coefficient read.
int n, m, x;
// Recursive radix-2 FFT, computed in place on the first n entries of `a`.
//   a   : coefficients on entry, point values at the powers of the n-th root
//         of unity on exit; must hold at least n elements.
//   n   : transform length; the even/odd split assumes a power of two.
//   rev : +1 for the forward transform, -1 for the inverse (the caller divides
//         the result by n afterwards).
void FFT(vector< complex<double> >&a,int n,int rev)
{
    if (n == 1)
        return;
    const int half = n >> 1;

    // Split the input into even-indexed and odd-indexed coefficients.
    vector< complex<double> > even, odd;
    for (int i = 0; i < n; i += 2)
    {
        even.push_back(a[i]);
        odd.push_back(a[i + 1]);
    }
    FFT(even, half, rev);
    FFT(odd, half, rev);

    // Root of unity for this level; the sign of the imaginary part picks the direction.
    const complex<double> step(cos(2 * pi / n), sin(rev * 2 * pi / n));
    complex<double> w(1, 0);
    for (int i = 0; i < half; ++i)
    {
        // Butterfly: combine the two half-size transforms.
        const complex<double> t = w * odd[i];
        a[i] = even[i] + t;
        a[i + half] = even[i] - t;
        w *= step;
    }
}
// Reads the degrees n and m, then the n + 1 and m + 1 coefficients (lowest
// power first) of two polynomials, and prints the coefficients of their product.
int main()
{
    scanf("%d %d",&n,&m);
    for(int i = 0;i <= n;i++) {scanf("%d",&x);a.push_back(complex<double>(x));}
    for(int i = 0;i <= m;i++) {scanf("%d",&x);b.push_back(complex<double>(x));}
    // m becomes the degree of the product, n the smallest power of two greater
    // than it, so the product's m + 1 coefficients are recoverable from n values.
    m += n;n = 1;while(n <= m)n <<= 1;
    // Zero-pad both operands (FFT reads indices [0, n)).
    while(a.size() <= (size_t)n)a.push_back(complex<double>());
    while(b.size() <= (size_t)n)b.push_back(complex<double>());
    FFT(a,n,1);FFT(b,n,1);
    // Pointwise product of the point-value forms; only [0, n) belongs to the transform.
    for(int i = 0;i < n;i++) a[i] *= b[i];
    FFT(a,n,-1);
    for(int i = 0;i <= m;i++)
    {
        // The inverse transform leaves n times each coefficient. Round half away
        // from zero: (int)(v + 0.5) truncates toward zero and is wrong for v < 0.
        double v = a[i].real() / n;
        printf("%d ",(int)(v < 0 ? v - 0.5 : v + 0.5));
    }
    putchar(10);
    return 0;
}

迭代FFT

递归的太慢,我们需要把它改成迭代形式。
以长度为8的多项式为例观察FFT递归处理顺序:

000 001 010 011 100 101 110 111
0 1 2 3 4 5 6 7
0 2 4 6 1 3 5 7
0 4 2 6 1 5 3 7
000 100 010 110 001 101 011 111

即分治到边界的下标等于原下标的二进制翻转。

假设$A_{0}({\omega}_{\frac{n}{2} }^{i})$ 和 $A_{1}({\omega}_{\frac{n}{2} }^{i})$ 被存在 $\alpha_0$ 和 $\alpha_1$。
$A({\omega}_{n}^{i})$ 和 $A({\omega}_{n}^{i + \frac{n}{2}})$ 将被存放在 $\beta_0$ 和 $\beta_1$。
则有 $\beta_0 = \alpha_0 + \omega_n^{i} \alpha_1$,$\beta_1 = \alpha_0 - \omega_n^{i} \alpha_1$。

我们可以让它们在原地完成此操作。令临时变量$t = {\omega}_{n}^{i} \alpha_1$
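则有 $\alpha_1 \leftarrow \alpha_0 - t$,$\alpha_0 \leftarrow \alpha_0 + t$(两式都用旧的 $\alpha_0$,所以先更新 $\alpha_1$)。这就是所谓的蝴蝶操作。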

Code :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cctype>
// pi as a double, obtained from acos(-1).
const double pi = std::acos(-1.0);
// Fixed capacity of every coefficient and transform buffer in this program;
// the padded transform length must not exceed it.
const int N = 270010;
// Reads the next decimal integer from stdin, skipping any non-digit characters
// before it; a '-' among the skipped characters makes the result negative.
// Returns 0 if end of input is reached before any digit.
inline int getint()
{
    int r = 0,s = 1;
    // Keep getchar()'s result as int: a char cannot represent EOF distinctly
    // (the skip loop would spin forever at end of input), and isdigit() on a
    // negative non-EOF char value is undefined behavior.
    int c = getchar();
    for(;c != EOF && !isdigit(c);c = getchar()) if(c == '-') s = 0;
    for(;c != EOF && isdigit(c);c = getchar()) r = r * 10 + (c - '0');
    return s ? r : -r;
}
// Minimal complex number re + im * I used by the iterative FFT.
struct complex
{
    double re,im;
    // Zero, or the number re + im * I.
    complex() : re(0.0), im(0.0) {}
    complex(double r,double i) : re(r), im(i) {}
    inline friend complex operator + (const complex &a,const complex &b)
    {
        return complex(a.re + b.re,a.im + b.im);
    }
    inline friend complex operator - (const complex &a,const complex &b)
    {
        return complex(a.re - b.re,a.im - b.im);
    }
    inline friend complex operator * (const complex &a,const complex &b)
    {
        return complex(a.re * b.re - a.im * b.im,a.re * b.im + a.im * b.re);
    }
    // Copy assignment is the implicitly generated member-wise copy (a
    // user-declared void operator= would break chaining and suppress moves).
    // Compound assignments return *this so they chain like the built-ins.
    inline complex &operator += (const complex &a)
    {
        re += a.re;im += a.im;
        return *this;
    }
    inline complex &operator *= (const complex &a)
    {
        // Both old parts are needed to form each new part, so copy them first.
        double r = re,i = im;
        re = a.re * r - a.im * i;
        im = a.re * i + a.im * r;
        return *this;
    }
    inline complex &operator /= (const double &a)
    {
        re /= a;im /= a;
        return *this;
    }
};
struct FFT
{
int k,n;complex c1[N],c2[N],o[N],io[N];
inline void init()
{
k = 0;while((1 << k) < n) k++;
for(int i = 0;i < n;i++)
{
double a = cos(2.0 * pi / n * i),b = sin(2.0 * pi / n * i);
o[i] = (complex){a,b};io[i] = (complex){a,-b};
}
}
inline void trans(complex *a,complex *omega)
{
for(int i = 0;i < n;i++)
{
int t = 0;
for(int j = 0;j < k;j++) if(i & (1 << j)) t |= (1 << (k - j - 1));
if(i < t) std::swap(a[i],a[t]);
}
for(int l = 2;l <= n;l <<= 1)
{
int m = l >> 1;
for(complex *p = a;p != a + n;p += l)
{
for(int i = 0;i < m;i++)
{
complex t = omega[n / l * i] * p[m + i];
p[m + i] = p[i] - t;
p[i] += t;
}
}
}
}
inline void dft(complex *a)
{
trans(a,o);
}
inline void idft(complex *a)
{
trans(a,io);
for(int i = 0;i < n;i++) a[i] /= (1.0 * n);
}
void calc(const int *a,const int *b,const int n1,const int n2,int *ans)
{
n = 1;while(n < n1 + n2) n <<= 1;
for(int i = 0;i < n1;i++) c1[i].re = 1.0 * a[i];
for(int i = 0;i < n2;i++) c2[i].re = 1.0 * b[i];
init();
dft(c1);dft(c2);
for(int i = 0;i < n;i++) c1[i] *= c2[i];
idft(c1);
for(int i = 0;i < n1 + n2 - 1;i++) ans[i] = (int)floor(0.5 + c1[i].re);
}
}fft;
// Coefficient counts of the two inputs.
int n, m;
// Input coefficients (lowest power first) and the product's coefficients.
int a[N], b[N], ans[N];
// Reads the degrees of two polynomials followed by their coefficients (lowest
// power first) and prints the coefficients of their product on one line.
int main()
{
    n = getint() + 1;   // degree -> number of coefficients
    m = getint() + 1;
    for (int i = 0; i < n; ++i) a[i] = getint();
    for (int i = 0; i < m; ++i) b[i] = getint();
    fft.calc(a, b, n, m, ans);
    const int total = n + m - 1;   // coefficient count of the product
    for (int i = 0; i < total; ++i)
    {
        const char sep = (i + 1 == total) ? '\n' : ' ';
        printf("%d%c", ans[i], sep);
    }
    return 0;
}

貌似FFT讲完了

NTT

某些时候,在精度不足或是在模P域下我们需要用到NTT。用原根代替单位根。
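把 $g^{\frac{P - 1}{n}}$ 当作 $\omega_n$ 的等价物($g$ 是模 $P$ 的原根,对 $998244353$ 等常用模数一般取 $g = 3$;要求 $n \mid P - 1$)。它同样满足FFT用到的两个性质;逆变换用 $g^{-\frac{P - 1}{n}}$,最后乘上 $n$ 的逆元。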

貌似NTT也讲完了

Code :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cctype>
// 64-bit type for modular products.
typedef long long LL;
// Primitive root G of the prime modulus P = 119 * 2^23 + 1.
const int G = 3;
const int P = 998244353;
// Fixed capacity of every coefficient and transform buffer.
const int N = 270010;
// Reads the next decimal integer from stdin, skipping any non-digit characters
// before it; a '-' among the skipped characters makes the result negative.
// Returns 0 if end of input is reached before any digit.
inline int getint()
{
    int r = 0,s = 1;
    // Keep getchar()'s result as int: a char cannot represent EOF distinctly
    // (the skip loop would spin forever at end of input), and isdigit() on a
    // negative non-EOF char value is undefined behavior.
    int c = getchar();
    for(;c != EOF && !isdigit(c);c = getchar()) if(c == '-') s = 0;
    for(;c != EOF && isdigit(c);c = getchar()) r = r * 10 + (c - '0');
    return s ? r : -r;
}
// Returns a^b mod P by binary exponentiation; b must be non-negative.
// The exponent is deliberately not reduced mod P: that is not a valid identity
// (by Fermat's little theorem an exponent may only be reduced mod P - 1, and
// only for a base not divisible by P).
inline int qpow(int a,int b)
{
    int ans = 1;
    a %= P;if(a < 0) a += P;   // bring the base into [0, P), also for negative a
    while(b)
    {
        if(b & 1) ans = 1LL * ans * a % P;
        b >>= 1;
        a = 1LL * a * a % P;
    }
    return ans;
}
// Number-theoretic transform modulo the prime P with primitive root G, used to
// multiply integer polynomials exactly modulo P. Every buffer has capacity N,
// so the padded transform length n must not exceed N.
struct NTT
{
    // k: log2(n); n: transform length (a power of two); c1, c2: operand buffers.
    // o and io are not referenced by this implementation.
    int k,n;int c1[N],c2[N],o[N],io[N];
    // Sets k to the smallest value with 2^k >= n.
    inline void init()
    {
        k = 0;while((1 << k) < n) k++;
    }
    // In-place transform of a[0..n): forward when f == 1, otherwise inverse
    // (the division by n is done separately in idft).
    inline void trans(int *a,int f)
    {
        // Bit-reversal permutation: move each element to the slot the recursive
        // even/odd split would send it to.
        for(int i = 0;i < n;i++)
        {
            int t = 0;
            for(int j = 0;j < k;j++) if(i & (1 << j)) t |= (1 << (k - j - 1));
            if(i < t) std::swap(a[i],a[t]);
        }
        for(int l = 2;l <= n;l <<= 1)
        {
            int m = l >> 1;
            // wn = G^((P-1)/l), or G^(-(P-1)/l) for the inverse: an l-th root of
            // unity mod P. This needs l to divide P - 1.
            LL wn = qpow(G,f == 1 ? (P - 1) / l : P - 1 - (P - 1) / l);
            for(int *p = a;p != a + n;p += l)
            {
                LL w = 1;
                for(int i = 0;i < m;i++)
                {
                    // Butterfly: (x, y) -> (x + w*y, x - w*y) mod P.
                    int t = 1LL * w * p[m + i] % P;
                    p[m + i] = (1LL * p[i] + P - t) % P;
                    p[i] = (1LL * p[i] + t) % P;
                    w = w * wn % P;
                }
            }
        }
    }
    // Forward transform: coefficients -> point values.
    inline void dft(int *a)
    {
        trans(a,1);
    }
    // Inverse transform: point values -> coefficients, scaled by n^{-1} mod P.
    inline void idft(int *a)
    {
        trans(a,-1);
        int inv = qpow(n,P - 2);
        for(int i = 0;i < n;i++) a[i] = 1LL * a[i] * inv % P;
    }
    // Multiplies a (n1 coefficients) by b (n2 coefficients) modulo P and writes
    // the n1 + n2 - 1 product coefficients, each in [0, P), to ans.
    void calc(const int *a,const int *b,const int n1,const int n2,int *ans)
    {
        n = 1;while(n < n1 + n2) n <<= 1;
        // Load the operands reduced into [0, P) and zero the padding up to n, so
        // the butterflies never see negative values and data left in c1 / c2 by
        // an earlier call cannot leak into this product.
        for(int i = 0;i < n;i++)
        {
            c1[i] = i < n1 ? (a[i] % P + P) % P : 0;
            c2[i] = i < n2 ? (b[i] % P + P) % P : 0;
        }
        init();
        dft(c1);dft(c2);
        for(int i = 0;i < n;i++) c1[i] = 1LL * c1[i] * c2[i] % P;
        idft(c1);
        for(int i = 0;i < n1 + n2 - 1;i++) ans[i] = c1[i];
    }
}ntt;
// Coefficient counts of the two inputs.
int n, m;
// Input coefficients (lowest power first) and the product's coefficients mod P.
int a[N], b[N], ans[N];
// Reads the degrees of two polynomials followed by their coefficients (lowest
// power first) and prints the coefficients of their product modulo P.
int main()
{
    n = getint() + 1;   // degree -> number of coefficients
    m = getint() + 1;
    for (int i = 0; i < n; ++i) a[i] = getint();
    for (int i = 0; i < m; ++i) b[i] = getint();
    ntt.calc(a, b, n, m, ans);
    const int total = n + m - 1;   // coefficient count of the product
    for (int i = 0; i < total; ++i)
    {
        const bool last = (i + 1 == total);
        printf("%d%c", ans[i], last ? '\n' : ' ');
    }
    return 0;
}

三模数NTT(MTT(orz myy))

机智的可达鸭发现事情并不对劲.jpg
NTT中的模数被限定为形如 $P = r \cdot 2^k + 1$、且 $2^k$ 不小于变换长度的质数(常称“NTT 模数”,并不是严格意义上的费马质数)。模数不是这种形式怎么办?
我们对3个不同模数做3次NTT,再掏出中国剩余定理就资瓷啦。

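问题转化成:已知真实系数 $S$ 满足 $S \equiv a_1 \pmod{p_1}$,$S \equiv a_2 \pmod{p_2}$,$S \equiv a_3 \pmod{p_3}$(要求 $0 \le S < p_1 p_2 p_3$)。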

求 $S \bmod p$。

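先处理前两个方程。令 $M = p_1 p_2$,得 $S \equiv A \pmod{M}$,其中 $A = \left( a_1 p_2 (p_2^{-1} \bmod p_1) + a_2 p_1 (p_1^{-1} \bmod p_2) \right) \bmod M$。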

问题转化成:$S \equiv A \pmod{M}$,$S \equiv a_3 \pmod{p_3}$。

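设 $S = k M + A = x p_3 + a_3$,两边模 $p_3$ 有 $k M \equiv a_3 - A \pmod{p_3}$,即 $k = (a_3 - A) M^{-1} \bmod p_3$(因为 $S < M p_3$,取 $0 \le k < p_3$ 即可)。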

带回去模 $p$ 就可以了:$S \bmod p = \left( k (M \bmod p) + A \right) \bmod p$。注意 $kM$ 会超出 64 位整数范围,所以要先把 $M$ 模 $p$。

Code :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <cctype>
// 64-bit type for modular products.
typedef long long LL;
// Fixed capacity of every coefficient and transform buffer.
const int N = 270010;
// Common primitive root and the three NTT moduli (each of the form c * 2^k + 1)
// whose residues are merged by CRT.
const int G = 3;
const int MOD[3] = {998244353, 1004535809, 469762049};
// Reads the next decimal integer from stdin, skipping any non-digit characters
// before it; a '-' among the skipped characters makes the result negative.
// Returns 0 if end of input is reached before any digit.
inline int getint()
{
    int r = 0,s = 1;
    // Keep getchar()'s result as int: a char cannot represent EOF distinctly
    // (the skip loop would spin forever at end of input), and isdigit() on a
    // negative non-EOF char value is undefined behavior.
    int c = getchar();
    for(;c != EOF && !isdigit(c);c = getchar()) if(c == '-') s = 0;
    for(;c != EOF && isdigit(c);c = getchar()) r = r * 10 + (c - '0');
    return s ? r : -r;
}
// Returns a^b mod m by binary exponentiation; requires b >= 0 and m >= 1.
// The exponent is deliberately not reduced mod m: that is not a valid identity
// (by Fermat's little theorem an exponent may only be reduced mod m - 1, for
// prime m and a base not divisible by m).
inline int qpow(int a,int b,int m)
{
    int ans = 1 % m;            // 1 % m so that qpow(x, 0, 1) == 0
    a %= m;if(a < 0) a += m;    // bring the base into [0, m), also for negative a
    while(b)
    {
        if(b & 1) ans = 1LL * ans * a % m;
        a = 1LL * a * a % m;
        b >>= 1;
    }
    return ans;
}
struct MTT
{
struct NTT
{
int k,n;int c1[N],c2[N],o[N],io[N];
int P;
inline void init()
{
k = 0;while((1 << k) < n) k++;
}
inline void trans(int *a,int f)
{
for(int i = 0;i < n;i++)
{
int t = 0;
for(int j = 0;j < k;j++) if(i & (1 << j)) t |= (1 << (k - j - 1));
if(i < t) std::swap(a[i],a[t]);
}
for(int l = 2;l <= n;l <<= 1)
{
int m = l >> 1;
LL wn = qpow(G,f == 1 ? (P - 1) / l : P - 1 - (P - 1) / l,P);
for(int *p = a;p != a + n;p += l)
{
LL w = 1;
for(int i = 0;i < m;i++)
{
int t = 1LL * w * p[m + i] % P;
p[m + i] = (1LL * p[i] + P - t) % P;
p[i] = (1LL * p[i] + t) % P;
w = w * wn % P;
}
}
}
}
inline void dft(int *a)
{
trans(a,1);
}
inline void idft(int *a)
{
trans(a,-1);
int inv = qpow(n,P - 2,P);
for(int i = 0;i < n;i++) a[i] = 1LL * a[i] * inv % P;
}
void calc(const int *a,const int *b,const int n1,const int n2,int *ans,const int p)
{
P = p;
n = 1;while(n < n1 + n2) n <<= 1;
std::memcpy(c1,a,sizeof(int) * n1);
std::memcpy(c2,b,sizeof(int) * n2);
init();
dft(c1);dft(c2);
for(int i = 0;i < n;i++) c1[i] = 1LL * c1[i] * c2[i] % P;
idft(c1);
for(int i = 0;i < n1 + n2 - 1;i++) ans[i] = c1[i];
}
}ntt[3];
struct CRT
{
int m;
const LL M = 1LL * MOD[0] * MOD[1];
const int inv[3] =
{
qpow(MOD[0] % MOD[1],MOD[1] - 2,MOD[1]),
qpow(MOD[1] % MOD[0],MOD[0] - 2,MOD[0]),
qpow(M % MOD[2],MOD[2] - 2,MOD[2])
};
inline LL mul(LL a,LL b,LL p)
{
a %= p;b %= p;
return ((a * b - (LL)((LL) ((long double)a / p * b + 1e-3) * p)) % p + p) % p;
}
int calc(int a,int b,int c)
{
LL A = (mul((1LL * a * MOD[1] % M),inv[1],M) + mul((1LL * b * MOD[0] % M),inv[0],M)) % M;
LL k = (1LL * c + MOD[2] - A % MOD[2]) * inv[2] % MOD[2];
return (k * (M % m) + A) % m;
}
}crt;
int ans[3][N];
void calc(const int *a,const int *b,const int n1,const int n2,int *ret,const int M)
{
for(int i = 0;i < 3;i++) ntt[i].calc(a,b,n1,n2,ans[i],MOD[i]);
crt.m = M;for(int i = 0;i < n1 + n2 - 1;i++) ret[i] = crt.calc(ans[0][i],ans[1][i],ans[2][i]);
}
}mtt;
// Coefficient counts of the two inputs.
int n, m;
// Input coefficients (lowest power first) and the product's coefficients mod p.
int a[N], b[N], ans[N];
// Reads the two degrees and the modulus p, then both coefficient lists (lowest
// power first), and prints the product's coefficients modulo p on one line.
int main()
{
    n = getint() + 1;   // degree -> number of coefficients
    m = getint() + 1;
    int p = getint();
    for (int i = 0; i < n; ++i) a[i] = getint();
    for (int i = 0; i < m; ++i) b[i] = getint();
    mtt.calc(a, b, n, m, ans, p);
    const int total = n + m - 1;   // coefficient count of the product
    for (int i = 0; i < total; ++i)
    {
        const char sep = (i + 1 == total) ? '\n' : ' ';
        printf("%d%c", ans[i], sep);
    }
    return 0;
}

(板子常数巨大谨慎食用)