[Leetcode][43. 字符串相乘] 大数乘法的快速傅里叶变换(FFT) 和 快速数论变换(NTT)解法

By Long Luo

LeetCode 43. 字符串相乘 其实就是大数乘法,常规的大数乘法可以参考《超大数字的四则运算是如何实现的呢?》,此外还可以使用快速傅里叶变换 (FFT) 和快速数论变换 (NTT) 实现。

快速傅里叶变换(FFT)

快速傅里叶变换 (FFT) 的详细解释可以参考这几篇文章:

快速傅里叶变换(FFT)算法
快速傅里叶变换(FFT)算法的实现及优化

下面分别给出递归版和迭代版的代码实现:

递归(Recursion)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
/**
 * LeetCode 43 "Multiply Strings" solved with a recursive radix-2 FFT.
 *
 * The decimal digits of each operand are treated as polynomial coefficients
 * (least significant digit first). The two polynomials are multiplied through
 * FFT -> point-wise product -> inverse FFT, and the resulting coefficients are
 * turned back into decimal digits by carry propagation.
 */
class Solution {
public:
    const double PI = std::acos(-1.0); // PI = arccos(-1)

    /** Minimal complex number type used by the transform. */
    struct Complex {
        double re, im;

        Complex(double _re = 0.0, double _im = 0.0) : re(_re), im(_im) {}

        inline void real(const double &re) {
            this->re = re;
        }

        inline double real() const {
            return re;
        }

        inline void imag(const double &im) {
            this->im = im;
        }

        inline double imag() const {
            return im;
        }

        inline Complex operator-(const Complex &other) const {
            return Complex(re - other.re, im - other.im);
        }

        inline Complex operator+(const Complex &other) const {
            return Complex(re + other.re, im + other.im);
        }

        inline Complex operator*(const Complex &other) const {
            return Complex(re * other.re - im * other.im, re * other.im + im * other.re);
        }

        // NOTE: despite the operator spelling, this divides *in place* and
        // returns nothing (kept as-is for compatibility with existing users).
        inline void operator/(const double &div) {
            re /= div;
            im /= div;
        }

        inline void operator*=(const Complex &other) {
            *this = Complex(re * other.re - im * other.im, re * other.im + im * other.re);
        }

        inline void operator+=(const Complex &other) {
            this->re += other.re;
            this->im += other.im;
        }

        inline Complex conjugate() const {
            return Complex(re, -im);
        }
    };

    /**
     * Recursive FFT.
     *
     * @param a      polynomial coefficients ordered from the lowest degree to
     *               the highest; the size must be a power of two (an empty
     *               vector is returned unchanged)
     * @param invert true selects the inverse transform (conjugate roots),
     *               false the forward transform; the inverse result is NOT
     *               divided by n here, the caller has to do that
     * @return the polynomial evaluated at the n-th roots of unity
     */
    std::vector<Complex> FFT(std::vector<Complex> &a, bool invert) {
        const int n = static_cast<int>(a.size());

        // A constant polynomial is its own transform. n == 0 is handled here
        // as well, otherwise the recursion below would never terminate.
        if (n <= 1) {
            return a;
        }

        const int half = n / 2;

        // Split into the even-indexed and odd-indexed coefficients.
        std::vector<Complex> Pe(half), Po(half);
        for (int i = 0; i < half; i++) {
            Pe[i] = a[2 * i];
            Po[i] = a[2 * i + 1];
        }

        // Divide: transform both halves recursively.
        std::vector<Complex> ye = FFT(Pe, invert);
        std::vector<Complex> yo = FFT(Po, invert);

        // Combine.
        std::vector<Complex> y(n);

        const double ang = 2 * PI / n * (invert ? -1 : 1);
        const Complex wn(std::cos(ang), std::sin(ang)); // primitive n-th root of unity
        Complex w(1, 0);                                // current power of wn, starts at wn^0

        for (int i = 0; i < half; i++) {
            const Complex t = w * yo[i];
            y[i] = ye[i] + t;
            // The (i + n/2)-th root is the negation of the i-th root.
            y[i + half] = ye[i] - t;
            w = w * wn;
        }

        return y;
    }

    /**
     * Multiplies two non-negative decimal integers given as strings.
     *
     * @param num1 first factor, decimal digits only
     * @param num2 second factor, decimal digits only
     * @return the decimal representation of num1 * num2 without leading zeros
     */
    std::string multiply(std::string num1, std::string num2) {
        if (num1 == "0" || num2 == "0") {
            return "0";
        }

        const int len1 = static_cast<int>(num1.size());
        const int len2 = static_cast<int>(num2.size());

        // The product has at most len1 + len2 digits; round up to a power of two.
        int n = 1;
        while (n < len1 + len2) {
            n <<= 1;
        }

        std::vector<Complex> a(n);
        std::vector<Complex> b(n);

        // Store the digits least significant first.
        for (int i = 0; i < len1; i++) {
            a[i] = Complex(num1[len1 - 1 - i] - '0', 0);
        }

        for (int i = 0; i < len2; i++) {
            b[i] = Complex(num2[len2 - 1 - i] - '0', 0);
        }

        a = FFT(a, false);
        b = FFT(b, false);

        for (int i = 0; i < n; i++) {
            a[i] = a[i] * b[i];
        }

        a = FFT(a, true);

        std::string ans;
        ans.reserve(n + 1);

        long long carry = 0;
        for (int i = 0; i < n; i++) {
            // Scale by 1/n first and round once: rounding before the division
            // would tolerate far less floating-point error.
            const long long sum = static_cast<long long>(std::round(a[i].re / n)) + carry;
            carry = sum / 10;
            ans += static_cast<char>('0' + sum % 10);
        }

        // Flush whatever carry is left, however many digits it has.
        while (carry > 0) {
            ans += static_cast<char>('0' + carry % 10);
            carry /= 10;
        }

        // Drop the high-order zeros (stored at the end) but keep one digit.
        std::string::size_type keep = ans.size();
        while (keep > 1 && ans[keep - 1] == '0') {
            --keep;
        }
        ans.resize(keep);

        std::reverse(ans.begin(), ans.end());
        return ans;
    }
};

迭代(Iteration)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
/**
 * LeetCode 43 "Multiply Strings" solved with an iterative (in-place) radix-2 FFT.
 *
 * The decimal digits of each operand are treated as polynomial coefficients
 * (least significant digit first). The two polynomials are multiplied through
 * FFT -> point-wise product -> inverse FFT, and the resulting coefficients are
 * turned back into decimal digits by carry propagation.
 */
class Solution {
public:
    const double PI = std::acos(-1.0); // PI = arccos(-1)

    /** Minimal complex number type used by the transform. */
    struct Complex {
        double re, im;

        Complex(double _re = 0.0, double _im = 0.0) : re(_re), im(_im) {}

        inline void real(const double &re) {
            this->re = re;
        }

        inline double real() const {
            return re;
        }

        inline void imag(const double &im) {
            this->im = im;
        }

        inline double imag() const {
            return im;
        }

        inline Complex operator-(const Complex &other) const {
            return Complex(re - other.re, im - other.im);
        }

        inline Complex operator+(const Complex &other) const {
            return Complex(re + other.re, im + other.im);
        }

        inline Complex operator*(const Complex &other) const {
            return Complex(re * other.re - im * other.im, re * other.im + im * other.re);
        }

        // NOTE: despite the operator spelling, this divides *in place* and
        // returns nothing (kept as-is for compatibility with existing users).
        inline void operator/(const double &div) {
            re /= div;
            im /= div;
        }

        inline void operator+=(const Complex &other) {
            this->re += other.re;
            this->im += other.im;
        }

        inline void operator-=(const Complex &other) {
            this->re -= other.re;
            this->im -= other.im;
        }

        inline void operator*=(const Complex &other) {
            *this = Complex(re * other.re - im * other.im, re * other.im + im * other.re);
        }

        inline Complex conjugate() const {
            return Complex(re, -im);
        }
    };

    // Lookup tables, sized by init(n). They used to be fixed arrays of 256
    // entries, which overflowed as soon as len1 + len2 exceeded 256.
    std::vector<Complex> omega;  // omega[k]  = e^(+2*pi*i*k/n), forward roots
    std::vector<Complex> invert; // invert[k] = conjugate of omega[k], inverse roots
    std::vector<int> rev;        // rev[k]    = bit-reversal of k within log2(n) bits

    /**
     * Builds the root-of-unity tables and the bit-reversal permutation for a
     * transform of size n.
     *
     * @param n transform size; expected to be a power of two
     */
    void init(int n) {
        if (n < 0) {
            n = 0;
        }

        omega.assign(n, Complex());
        invert.assign(n, Complex());
        rev.assign(n, 0);

        for (int i = 0; i < n; i++) {
            const double ang = 2 * PI * i / n;
            omega[i] = Complex(std::cos(ang), std::sin(ang));
            invert[i] = omega[i].conjugate();

            if (i > 0) {
                // Reuse the reversal of i / 2 and place the lowest bit of i on top.
                rev[i] = rev[i >> 1] >> 1;
                if (i & 1) {
                    rev[i] |= n >> 1;
                }
            }
        }
    }

    /**
     * Iterative in-place FFT.
     *
     * init(a.size()) must have been called beforehand so that the bit-reversal
     * table matches the size of a.
     *
     * @param a     polynomial coefficients ordered from the lowest degree to the
     *              highest; overwritten with the transform. The size must be a
     *              power of two.
     * @param omega table of a.size() roots of unity: pass the forward table for
     *              the FFT or the conjugate table for the inverse FFT. The
     *              inverse result is NOT divided by n here.
     */
    void FFT(std::vector<Complex> &a, Complex *omega) {
        const int n = static_cast<int>(a.size());

        // Nothing to do for an empty or constant polynomial.
        if (n <= 1) {
            return;
        }

        // Reorder the input into bit-reversed order; the i < rev[i] test
        // guarantees each pair is swapped only once.
        for (int i = 0; i < n; ++i) {
            if (i < rev[i]) {
                std::swap(a[i], a[rev[i]]);
            }
        }

        // Bottom-up butterflies on blocks of length 2, 4, ..., n.
        for (int len = 2; len <= n; len *= 2) {
            const int half = len / 2;
            const int step = n / len; // stride into the size-n root table
            for (int i = 0; i < n; i += len) {
                for (int j = 0; j < half; j++) {
                    const Complex u = a[i + j];
                    const Complex v = omega[j * step] * a[i + j + half];
                    a[i + j] = u + v;
                    a[i + j + half] = u - v;
                }
            }
        }
    }

    /**
     * Multiplies two non-negative decimal integers given as strings.
     *
     * @param num1 first factor, decimal digits only
     * @param num2 second factor, decimal digits only
     * @return the decimal representation of num1 * num2 without leading zeros
     */
    std::string multiply(std::string num1, std::string num2) {
        if (num1 == "0" || num2 == "0") {
            return "0";
        }

        const int len1 = static_cast<int>(num1.size());
        const int len2 = static_cast<int>(num2.size());

        // The product has at most len1 + len2 digits; round up to a power of two.
        int n = 1;
        while (n < len1 + len2) {
            n <<= 1;
        }

        std::vector<Complex> a(n);
        std::vector<Complex> b(n);

        // Store the digits least significant first.
        for (int i = 0; i < len1; i++) {
            a[i].real(num1[len1 - 1 - i] - '0');
        }

        for (int i = 0; i < len2; i++) {
            b[i].real(num2[len2 - 1 - i] - '0');
        }

        init(n);

        FFT(a, omega.data());
        FFT(b, omega.data());

        for (int i = 0; i < n; i++) {
            a[i] = a[i] * b[i];
        }

        FFT(a, invert.data());

        std::string ans;
        ans.reserve(n + 1);

        long long carry = 0;
        for (int i = 0; i < n; i++) {
            // Scale by 1/n first and round once: rounding before the division
            // would tolerate far less floating-point error.
            const long long sum = static_cast<long long>(std::round(a[i].real() / n)) + carry;
            carry = sum / 10;
            ans += static_cast<char>('0' + sum % 10);
        }

        // Flush whatever carry is left, however many digits it has.
        while (carry > 0) {
            ans += static_cast<char>('0' + carry % 10);
            carry /= 10;
        }

        // Drop the high-order zeros (stored at the end) but keep one digit.
        std::string::size_type keep = ans.size();
        while (keep > 1 && ans[keep - 1] == '0') {
            --keep;
        }
        ans.resize(keep);

        std::reverse(ans.begin(), ans.end());
        return ans;
    }
};

复杂度分析

  • 时间复杂度:$O((m+n)\log(m+n))$,其中 $m$ 和 $n$ 分别为 num1 和 num2 的长度。
  • 空间复杂度:$O(m+n)$。

快速数论变换(Number Theoretic Transform)解法

快速数论变换 (NTT) 的详细解释可以参考这篇文章:快速数论变换(Number Theoretic Transform)

下面分别给出递归版和迭代版的实现:

递归(Recursion)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/**
 * LeetCode 43 "Multiply Strings" solved with a recursive Number Theoretic
 * Transform (NTT) modulo 998244353.
 *
 * The decimal digits of each operand are treated as polynomial coefficients
 * (least significant digit first). The polynomials are multiplied through
 * NTT -> point-wise product -> inverse NTT, and the coefficients are turned
 * back into decimal digits by carry propagation.
 */
class Solution {

public:
    static constexpr long long G = 3;             // generator used to derive the roots of unity
    static constexpr long long G_INV = 332748118; // modular inverse of G (3 * 332748118 = MOD + 1)
    static constexpr long long MOD = 998244353;

    // Not used by this recursive implementation; retained so the class keeps
    // the same public members.
    std::vector<int> rev;

    /**
     * Modular exponentiation by repeated squaring.
     *
     * @param a base
     * @param b exponent; values <= 0 yield 1
     * @return a^b mod MOD
     */
    long long quickPower(long long a, long long b) {
        long long res = 1;

        // Reduce the base up front so that a * a cannot overflow.
        a %= MOD;

        while (b > 0) {
            if (b & 1) {
                res = (res * a) % MOD;
            }

            a = (a * a) % MOD;
            b >>= 1;
        }

        return res % MOD;
    }

    /**
     * Recursive in-place NTT.
     *
     * @param a      coefficients ordered from the lowest degree to the highest;
     *               overwritten with the transform. The size must be a power of
     *               two (an empty vector is left untouched).
     * @param invert true selects the inverse transform (uses G_INV), false the
     *               forward transform; the inverse result is NOT multiplied by
     *               n^-1 here, the caller has to do that
     */
    void ntt(std::vector<long long> &a, bool invert) {
        const int n = static_cast<int>(a.size());

        // n == 0 is handled here too, otherwise the recursion never terminates.
        if (n <= 1) {
            return;
        }

        const int half = n / 2;

        // Split into the even-indexed and odd-indexed coefficients.
        std::vector<long long> Pe(half), Po(half);
        for (int i = 0; i < half; i++) {
            Pe[i] = a[2 * i];
            Po[i] = a[2 * i + 1];
        }

        ntt(Pe, invert);
        ntt(Po, invert);

        long long root = G;
        if (invert) {
            root = G_INV;
        }

        const long long wn = quickPower(root, (MOD - 1) / n); // primitive n-th root of unity
        long long w = 1;                                      // current power of wn

        for (int i = 0; i < half; i++) {
            const long long t = w * Po[i] % MOD;
            // "+ MOD) % MOD" maps a possibly negative remainder into [0, MOD).
            a[i] = ((Pe[i] + t) % MOD + MOD) % MOD;
            a[i + half] = ((Pe[i] - t) % MOD + MOD) % MOD;
            w = w * wn % MOD;
        }
    }

    /**
     * Multiplies two non-negative decimal integers given as strings.
     *
     * @param num1 first factor, decimal digits only
     * @param num2 second factor, decimal digits only
     * @return the decimal representation of num1 * num2 without leading zeros;
     *         an all-zero product is returned as "0"
     */
    std::string multiply(std::string num1, std::string num2) {
        if (num1 == "0" || num2 == "0") {
            return "0";
        }

        const int len1 = static_cast<int>(num1.size());
        const int len2 = static_cast<int>(num2.size());

        // The product has at most len1 + len2 digits; round up to a power of two.
        int n = 1;
        while (n < (len1 + len2)) {
            n <<= 1;
        }

        std::vector<long long> a(n, 0), b(n, 0);

        // Store the digits least significant first.
        for (int i = 0; i < len1; ++i) {
            a[i] = num1[len1 - 1 - i] - '0';
        }

        for (int i = 0; i < len2; ++i) {
            b[i] = num2[len2 - 1 - i] - '0';
        }

        ntt(a, false);
        ntt(b, false);

        for (int i = 0; i < n; i++) {
            a[i] = (a[i] * b[i]) % MOD;
        }

        ntt(a, true);

        // The inverse transform still has to be scaled by n^-1 (mod MOD);
        // the inverse is obtained as n^(MOD - 2).
        const long long inver = quickPower(n, MOD - 2);

        std::string res;
        res.reserve(n + 1);

        long long carry = 0;
        for (int i = 0; i < n; i++) {
            const long long sum = a[i] * inver % MOD + carry;
            res += static_cast<char>('0' + sum % 10);
            carry = sum / 10;
        }

        while (carry) {
            res += static_cast<char>('0' + carry % 10);
            carry /= 10;
        }

        // Drop the high-order zeros (stored at the end) but always keep one
        // digit, so an all-zero product yields "0" instead of "".
        std::string::size_type keep = res.size();
        while (keep > 1 && res[keep - 1] == '0') {
            --keep;
        }
        res.resize(keep);

        std::reverse(res.begin(), res.end());
        return res;
    }
};

迭代(Iteration)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/**
 * LeetCode 43 "Multiply Strings" solved with an iterative (in-place) Number
 * Theoretic Transform (NTT) modulo 998244353.
 *
 * The decimal digits of each operand are treated as polynomial coefficients
 * (least significant digit first). The polynomials are multiplied through
 * NTT -> point-wise product -> inverse NTT, and the coefficients are turned
 * back into decimal digits by carry propagation.
 */
class Solution {
    static constexpr long long MOD = 998244353;
    static constexpr long long G = 3;             // generator used to derive the roots of unity
    static constexpr long long G_INV = 332748118; // modular inverse of G (3 * 332748118 = MOD + 1)

    // rev[k] = bit-reversal of k for the transform size rev.size().
    std::vector<int> rev;

    // Rebuilds the bit-reversal permutation for a transform of size n (a power of two).
    void buildReversal(int n) {
        rev.assign(n, 0);
        for (int i = 1; i < n; i++) {
            // Reuse the reversal of i / 2 and place the lowest bit of i on top.
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
        }
    }

public:
    /**
     * Modular exponentiation by repeated squaring.
     *
     * @param a base
     * @param b exponent; values <= 0 yield 1
     * @return a^b mod MOD
     */
    long long quickPower(long long a, long long b) {
        long long res = 1;

        // Reduce the base up front so that a * a cannot overflow.
        a %= MOD;

        while (b > 0) {
            if (b & 1) {
                res = (res * a) % MOD;
            }

            a = (a * a) % MOD;
            b >>= 1;
        }

        return res % MOD;
    }

    /**
     * Iterative in-place NTT.
     *
     * @param a      coefficients in [0, MOD), ordered from the lowest degree to
     *               the highest; overwritten with the transform. The size must
     *               be a power of two.
     * @param invert true selects the inverse transform, which also multiplies
     *               every element by n^-1 (mod MOD); false (default) selects
     *               the forward transform
     */
    void ntt(std::vector<long long> &a, bool invert = false) {
        const int n = static_cast<int>(a.size());

        // Make sure the permutation table matches this input; previously a
        // call that did not go through multiply() read rev out of bounds.
        if (static_cast<int>(rev.size()) != n) {
            buildReversal(n);
        }

        // Reorder into bit-reversed order; i < rev[i] swaps each pair once.
        for (int i = 0; i < n; i++) {
            if (i < rev[i]) {
                std::swap(a[i], a[rev[i]]);
            }
        }

        long long root = G;
        if (invert) {
            root = G_INV;
        }

        // Bottom-up butterflies on blocks of length 2, 4, ..., n.
        for (int len = 2; len <= n; len <<= 1) {
            const int half = len / 2;
            const long long wlen = quickPower(root, (MOD - 1) / len); // primitive len-th root

            for (int i = 0; i < n; i += len) {
                long long w = 1;
                for (int j = 0; j < half; j++) {
                    const long long u = a[i + j];
                    const long long v = (w * a[i + j + half]) % MOD;
                    a[i + j] = (u + v) % MOD;
                    a[i + j + half] = (MOD + u - v) % MOD; // + MOD keeps the value non-negative
                    w = (w * wlen) % MOD;
                }
            }
        }

        if (invert) {
            // n^-1 (mod MOD), obtained as n^(MOD - 2).
            const long long inver = quickPower(n, MOD - 2);
            for (int i = 0; i < n; i++) {
                a[i] = a[i] * inver % MOD;
            }
        }
    }

    /**
     * Multiplies two non-negative decimal integers given as strings.
     *
     * @param num1 first factor, decimal digits only
     * @param num2 second factor, decimal digits only
     * @return the decimal representation of num1 * num2 without leading zeros;
     *         an all-zero product is returned as "0"
     */
    std::string multiply(std::string num1, std::string num2) {
        if (num1 == "0" || num2 == "0") {
            return "0";
        }

        const int len1 = static_cast<int>(num1.size());
        const int len2 = static_cast<int>(num2.size());

        // Smallest power of two that is >= max(2, len1 + len2).
        int n = 2;
        while (n < (len1 + len2)) {
            n <<= 1;
        }

        buildReversal(n);

        std::vector<long long> a(n, 0), b(n, 0);

        // Store the digits least significant first.
        for (int i = 0; i < len1; ++i) {
            a[i] = num1[len1 - 1 - i] - '0';
        }

        for (int i = 0; i < len2; ++i) {
            b[i] = num2[len2 - 1 - i] - '0';
        }

        ntt(a);
        ntt(b);

        for (int i = 0; i < n; i++) {
            a[i] = (a[i] * b[i]) % MOD;
        }

        ntt(a, true);

        std::string res;
        res.reserve(n + 1);

        long long carry = 0;
        for (int i = 0; i < n; ++i) {
            const long long curr = a[i] + carry;
            res += static_cast<char>('0' + curr % 10);
            carry = curr / 10;
        }

        while (carry) {
            res += static_cast<char>('0' + carry % 10);
            carry /= 10;
        }

        // Drop the high-order zeros (stored at the end) but always keep one
        // digit; this also covers operands written with leading zeros.
        std::string::size_type keep = res.size();
        while (keep > 1 && res[keep - 1] == '0') {
            --keep;
        }
        res.resize(keep);

        std::reverse(res.begin(), res.end());
        return res;
    }
};

复杂度分析

  • 时间复杂度:$O((m+n)\log(m+n))$,其中 $m$ 和 $n$ 分别为 num1 和 num2 的长度。
  • 空间复杂度:$O(m+n)$。

All suggestions are welcome.
If you have any query or suggestion please comment below.
Please upvote👍 if you like💗 it. Thank you:-)

Explore More Leetcode Solutions. 😉😃💗