十六岁没能送你玫瑰,二十六岁请你喝一次勇闯天涯叭。
[toc]
目前想完成的任务是:多项式之快速傅里叶变换(FFT)/数论变换(NTT)
前置知识
对复数和复平面有一定的了解
了解逆元,原根,中国剩余定理
对多项式有一定认识,能写$O(n^2)$的高精乘
对分治很了解UPD:本文之前只是介绍$FFT$,后来慢慢扩展得有点多,似乎前置知识有点多。
多项式
定义:
定义多项式为形如下式的代数表达式。
其中$a_0,a_1,a_2……a_n$称为多项式的系数。
最高项的指数n叫做多项式的度,$Degree ,n = deg P$,也可以说是多项式的系数。本文没特殊说明情况下默认$n = deg\ F$多项式的卷积形式
设有两个多项式$g(x), f(x)$,设他们的度数分别为$n , m$,则卷积具有如下形式:
基础的,我们可以通过$O(n*m)$来获得卷积结果。
Karatsuba乘法
值得一提的是,$Karatsuba$ 算法是第一个比小学二次乘法算法渐进快速的算法。
对于上面的卷积,$Karatsuba$ 提出如下方法:
对于多项式 $F$ ,不妨设$n = deg \ F+1$,此时有:
令$F(x) = F_0(x)+x^{\frac{n}{2}}F_1(x),G(x)=G_0(X)+x^{\frac{n}{2}}G_1(x)$
其中有:$deg\ F_0=deg\ F_1=deg\ G_0=deg\ G_1=\frac n2$
那么得到:
不让令$M(x)=((F_0+F_1)\times(G_0+G_1))(x)$
不难发现:
至此,我们只需要三个多项式的卷积 $M(x),(F_0\times G_0)(x)-(F_1\times G_1)(x)$ 即可。
采用分治的做法。时间复杂度为:
$python$ 的代码实现1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19# Python 2 and 3
def karatsuba(num1, num2):
num1Str, num2Str = str(num1), str(num2)
if num1Str[0] == '-': return -karatsuba(-num1, num2)
if num2Str[0] == '-': return -karatsuba(num1, -num2)
if num1 < 10 or num2 < 10: return num1 * num2
maxLength = max(len(num1Str), len(num2Str))
num1Str = ''.join(list('0' * maxLength)[:-len(num1Str)] + list(num1Str))
num2Str = ''.join(list('0' * maxLength)[:-len(num2Str)] + list(num2Str))
splitPosition = maxLength // 2
high1, low1 = int(num1Str[:-splitPosition]), int(num1Str[-splitPosition:])
high2, low2 = int(num2Str[:-splitPosition]), int(num2Str[-splitPosition:])
z0, z2 = karatsuba(low1, low2), karatsuba(high1, high2)
z1 = karatsuba((low1 + high1), (low2 + high2))
return z2 * 10 ** (2 * splitPosition) + (z1 - z2 - z0) * 10 ** (splitPosition) + z0
此外,Toom–Cook multiplication是此算法更快速的泛型。对于$n$足够大时还有Schönhage–Strassen algorithm算法是更快的,它的时间复杂度为$O(n\ log \ n\ log\ log n)$.
多项式的系数表示与点值表示
系数表示
对于多项式
我们将 $\mathbf{a}$ 数组看作$n+1$维向量 $\vec{a} = (a_0,a_1,\cdots,a_n)$,其系数表示就是向量$\vec a$。
点值表示
由小学知识可知,$n$ 个点$(x_i,y_i)$便可以唯一确定一个多项式$y=F(x)$
现在取任意$n+1$个点
拉格朗日插值法(Lagrange)
在知乎学数学.jpg 有点想吐槽一下最近知乎越来越…了,钓鱼问题乱飞。。
一般方法
重心拉格朗日插值法
复数与单位根
FFT(快速傅里叶变换)
为什么引入FFT
两个多项式相乘朴素算法的复杂度是$O(n^2)$而使用FFT优化之后可以把复杂度降为$O(nlogn)$
NTT(数论变换)
分治FFT
拆系数FFT与三模数FFT
一些题目及其思路
1. 【模板】多项式乘法
$FFT$:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
typedef double LD;
using namespace std;
const int MAXN = 5e6 + 10;
const LD PI = acos(-1);
struct C {
LD r, i;
C(LD r = 0, LD i = 0) : r(r), i(i) {}
} A[MAXN], B[MAXN];
C operator+(const C& a, const C& b) { return C(a.r + b.r, a.i + b.i); }
C operator-(const C& a, const C& b) { return C(a.r - b.r, a.i - b.i); }
C operator*(const C& a, const C& b) { return C(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r); }
void FFT(C x[], int n, int p) {
for (int i = 0, t = 0; i < n; ++i) {
if (i > t)
swap(x[i], x[t]);
for (int j = n >> 1; (t ^= j) < j; j >>= 1)
;
}
for (int h = 2; h <= n; h <<= 1) {
C wn(cos(p * 2 * PI / h), sin(p * 2 * PI / h));
for (int i = 0; i < n; i += h) {
C w(1, 0), u;
for (int j = i, k = h >> 1; j < i + k; ++j) {
u = x[j + k] * w;
x[j + k] = x[j] - u;
x[j] = x[j] + u;
w = w * wn;
}
}
}
if (p == -1)
FOR(i, 0, n)
x[i].r /= n;
}
void conv(C a[], C b[], int n) {
FFT(a, n, 1);
FFT(b, n, 1);
FOR(i, 0, n)
a[i] = a[i] * b[i];
FFT(a, n, -1);
}
int a, b, n;
int main() {
std::ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
cin >> a >> b;
for (int i = 0; i <= a; i++) {
cin >> A[i].r;
}
for (int i = 0; i <= b; i++) {
cin >> B[i].r;
}
n = 1;
while (a + b >= n) {
n *= 2;
}
conv(A, B, n);
// cout << "n ==== " << n << endl;
for (int i = 0; i <= a + b; i++) {
cout << (int)(A[i].r + 0.5) << ' ';
}
cout << endl;
}
$NTT$1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
using LL = long long ;
using namespace std;
const int MOD = 998244353;
const int mod = 998244353;
int G = 3;
const int maxn = 5e6+10;
const int N = 1e6+10;
LL wn[N << 2], rev[N << 2];
LL bin(LL x, LL n, LL MOD) {
LL ret = MOD != 1;
for (x %= MOD; n; n >>= 1, x = x * x % MOD)
if (n & 1) ret = ret * x % MOD;
return ret;
}
inline LL get_inv(LL x, LL p) { return bin(x, p - 2, p); }
int NTT_init(int n_) {
int step = 0; int n = 1;
for ( ; n <= n_; n <<= 1) ++step;
FOR (i, 1, n)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (step - 1));
int g = bin(G, (MOD - 1) / n, MOD);
wn[0] = 1;
for (int i = 1; i <= n; ++i)
wn[i] = wn[i - 1] * g % MOD;
return n;
}
void NTT(LL a[], int n, int f) {
FOR (i, 0, n) if (i < rev[i])
std::swap(a[i], a[rev[i]]);
for (int k = 1; k < n; k <<= 1) {
for (int i = 0; i < n; i += (k << 1)) {
int t = n / (k << 1);
FOR (j, 0, k) {
LL w = f == 1 ? wn[t * j] : wn[n - t * j];
LL x = a[i + j];
LL y = a[i + j + k] * w % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x - y + MOD) % MOD;
}
}
}
if (f == -1) {
LL ninv = get_inv(n, MOD);
FOR (i, 0, n)
a[i] = a[i] * ninv % MOD;
}
}
int n , m, a[maxn], b[maxn],limit = 1, L;
int32_t main(){
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
cin >> n >> m;
for(int i = 0; i <= n; i++)cin >> a[i];
for(int j = 0; j <= m; j++)cin >> b[j];
while(limit <= n+m)limit++,L++;
limit = NTT_init(m+n);
//cout << "limit === "<<limit << endl;
NTT(a,limit , 1);NTT(b , limit, 1);
for(int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % MOD;
NTT(a,limit , -1);
for(int i = 0; i <= n+m; i++){
cout << (a[i] + mod) % mod <<" ";
}//cout << '\n';
}
2.【模板】A*B Problem升级版(FFT快速傅里叶)
给定两个大整数$A,B$,求$A\times B$
把$A$看作$A=\sum_{i=0}^{n-1}a_i*10^i$1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91// fft不同位数的乘法。A*B 看作A = 累加a*10^Pi
typedef double LD;
for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
using namespace std;
const int MAXN = 1e6 + 10;
const int maxn = 1e6 + 10;
const LD PI = acos(-1);
struct C {
LD r, i;
C(LD r = 0, LD i = 0) : r(r), i(i) {}
};
C operator+(const C& a, const C& b) { return C(a.r + b.r, a.i + b.i); }
C operator-(const C& a, const C& b) { return C(a.r - b.r, a.i - b.i); }
C operator*(const C& a, const C& b) {
return C(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r);
}
void FFT(C x[], int n, int p) {
for (int i = 0, t = 0; i < n; ++i) {
if (i > t) swap(x[i], x[t]);
for (int j = n >> 1; (t ^= j) < j; j >>= 1)
;
}
for (int h = 2; h <= n; h <<= 1) {
C wn(cos(p * 2 * PI / h), sin(p * 2 * PI / h));
for (int i = 0; i < n; i += h) {
C w(1, 0), u;
for (int j = i, k = h >> 1; j < i + k; ++j) {
u = x[j + k] * w;
x[j + k] = x[j] - u;
x[j] = x[j] + u;
w = w * wn;
}
}
}
if (p == -1) FOR(i, 0, n)
x[i].r /= n;
}
void conv(C a[], C b[], int n) {
FFT(a, n, 1);
FFT(b, n, 1);
FOR(i, 0, n)
a[i] = a[i] * b[i];
FFT(a, n, -1);
}
int limit = 1, bit = 0; // limit为最终扩充的长度 limit = 1<<bit
int wz[maxn << 2];
int re[maxn << 2]; //存储结果
C a[maxn << 2], b[maxn << 2];
char s1[maxn], s2[maxn]; //存储两个整数
int main() {
scanf("%s%s", s1, s2);
int len1 = strlen(s1), len2 = strlen(s2);
while (limit <= len1 + len2) {
limit <<= 1;
bit++;
}
// cout << "limit === " << limit << endl << "bit ==== " << bit << endl;
for (int i = len1 - 1, j = 0; i >= 0; i--, j++) {
a[j].r = s1[i] - 48;
a[j].i = 0;
}
for (int i = len2 - 1, j = 0; i >= 0; i--, j++) {
b[j].r = s2[i] - 48;
b[j].i = 0;
}
// for(int i=0;i<limit;i++)
// wz[i]=(wz[i>>1]>>1)|((i&1)<<(bit-1));
conv(a, b, limit);
memset(re, 0, sizeof(re));
for (int i = 0; i <= limit; i++) {
re[i] += (int)(a[i].r + 0.5);
if (re[i] >= 10) //进位
{
re[i + 1] += re[i] / 10;
re[i] %= 10;
if (i == limit) ++limit;
}
}
while (!re[limit] && limit >= 1) //去除高位的0
limit--;
while (limit >= 0) printf("%d", re[limit--]);
printf("\n");
return 0;
}