By Long Luo
之前的文章 快速傅里叶变换(FFT)算法 和 快速傅里叶变换(FFT)算法的实现及优化 详细介绍了 FFT 的具体实现及其实现。
FFT 优点很多,但缺点也很明显。例如单位复根的实部和虚部分别是一个正弦及余弦函数,有大量浮点数计算,计算量很大,而且浮点数运算产生的误差会比较大。
如果我们操作的对象都是整数的话,其实数学家已经发现了一个更好的方法:快速数论变换 (Number Theoretic Transform) 。
快速数论变换(NTT)
FFT 的本质是什么?
是什么让 FFT 做到了 O(nlogn) 的复杂度?
那有没有什么其他的东西也拥有单位根的这些性质呢?
答案当然是有的,原根就具有和单位根一样的性质。
所以快速数论变换 NTT 就是以数论为基础的具有循环卷积性质的,用有限域上的单位根来取代复平面上的单位根的 FFT。
原根
仿照单位复数根的形式,也将原根的取值看成一个圆,不过这个圆上只有有限个点,每个点表达的是模数的剩余系中的值。
在 FFT 中,我们总共用到了单位复根的这些性质:
- n 个单位复根互不相同;
- ωnk=ω2n2k;
- ωnk=−ωnk+n/2;
- ωna×ωnb=ωna+b。
我们发现原根具有和单位复根一样的性质,简单证明:
令 n 为大于 1 的 2 的幂,p 为素数且 n∣(p−1),g 为 p 的一个原根。
我们设 gn=gnp−1:
-
gnn=gn⋅np−1=gp−1
-
gn2n=g2p−1
-
ganak=ganak(p−1)=gnk(p−1)=gnk
显然
-
gnn≡1(modp)
-
gn2n≡−1(modp)
-
(gnk+2n)2=gn2k+n≡gn2k(modp)
证毕。
所以将 gnk 和 gnk+2n 带入本质上和将 ωnk 和 ωnk+2n 带入的操作无异。
利用Vandermonde矩阵性质,类似 NTT 那样,我们可以从 NTT 变换得到逆变换 INTT 变换,设 x(n) 为整数序列,则有:
NTT : X(m)=i=0∑Nx(n)amn(modM)
INTT : X(m)=N−1i=0∑Nx(n)a−mn(modM)
这里 N−1,a−mn(modM) 为模意义下的乘法逆元。
当然, NTT 也是有自己的缺点的:比如不能够处理小数的情况,以及不能够处理没有模数的情况。对于模数的选取也有一定的要求,首先是必须要有原根,其次是必须要是 2 的较高幂次的倍数。
NTT 实现
通过上面的分析,开始写代码吧:-)
NTT 也有递归版(Recursion)和迭代版(Iteration) 2 种实现:
递归版(Recursion)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
| const long long G = 3; const long long G_INV = 332748118; const long long MOD = 998244353;
vector<int> rev;
long long quickPower(long long a, long long b) { long long res = 1;
while (b > 0) { if (b & 1) { res = (res * a) % MOD; }
a = (a * a) % MOD; b >>= 1; }
return res % MOD; }
void ntt(vector<long long> &a, bool invert) { int n = a.size();
if (n == 1) { return; }
vector<long long> Pe(n / 2), Po(n / 2);
for (int i = 0; 2 * i < n; i++) { Pe[i] = a[2 * i]; Po[i] = a[2 * i + 1]; }
ntt(Pe, invert); ntt(Po, invert);
long long wn = quickPower(invert ? G_INV : G, (MOD - 1) / n); long long w = 1;
for (int i = 0; i < n / 2; i++) { a[i] = Pe[i] + w * Po[i] % MOD; a[i] = (a[i] % MOD + MOD) % MOD; a[i + n / 2] = Pe[i] - w * Po[i] % MOD; a[i + n / 2] = (a[i + n / 2] % MOD + MOD) % MOD; w = w * wn % MOD; } }
|
迭代版(Iteration)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
| public: static const long long MOD = 998244353; static const long long G = 3; static const int G_INV = 332748118; vector<int> rev;
long long quickPower(long long a, long long b) { long long res = 1;
while (b > 0) { if (b & 1) { res = (res * a) % MOD; }
a = (a * a) % MOD; b >>= 1; }
return res % MOD; }
void ntt(vector<long long> &a, bool invert = false) { int n = a.size();
for (int i = 0; i < n; i++) { if (i < rev[i]) { swap(a[i], a[rev[i]]); } }
for (int len = 2; len <= n; len <<= 1) { long long wlen = quickPower(invert ? G_INV : G, (MOD - 1) / len);
for (int i = 0; i < n; i += len) { long long w = 1; for (int j = 0; j < len / 2; j++) { long long u = a[i + j]; long long v = (w * a[i + j + len / 2]) % MOD; a[i + j] = (u + v) % MOD; a[i + j + len / 2] = (MOD + u - v) % MOD; w = (w * wlen) % MOD; } } }
if (invert) { long long inver = quickPower(n, MOD - 2); for (int i = 0; i < n; i++) { a[i] = (long long) a[i] * inver % MOD; } } }
|
复杂度分析
- 时间复杂度:O((m+n)log(m+n))。
- 空间复杂度:O(m+n)。
参考资料