Loading [MathJax]/extensions/TeX/newcommand.js
\newcommand{\ord}[1]{\mathcal{O}\left(#1\right)} \newcommand{\abs}[1]{\lvert #1 \rvert} \newcommand{\floor}[1]{\lfloor #1 \rfloor} \newcommand{\ceil}[1]{\lceil #1 \rceil} \newcommand{\opord}{\operatorname{\mathcal{O}}} \newcommand{\argmax}{\operatorname{arg\,max}} \newcommand{\str}[1]{\texttt{"#1"}}

2024年12月24日 星期二

[The radix-2 Cooley-Tukey FFT / FNTT Algorithm] 庫利-圖基 快速 (傅立葉/數論) 變換演算法

離散捲積 (Discrete Convolution)

給定兩個數列 A = (a_0, a_1, \dots a_{n-1}),\ B = (b_0, b_1, \dots, b_{m-1})
求兩數列的離散捲積 C = (c_0, c_1, \dots, c_{n+m-2}),其中 c_k = \sum_{i + j = k}a_ib_j

我們可以將數列轉換成多項式: 

\begin{align} A(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}\\B(x)=b_0+b_1x+b_2x^2+\dots+b_{m-1}x^{m-1}\end{align}

這樣一來,c_i = (A * B)(x)x^i 項的係數,如果用最 naive 的做法,總共要花 O(n\times m) 的時間。
這裡的目標是要在 O((n+m)\log (n+m)) 的時間算出 C

多項式的表示法

係數表示法 Coefficient Representation

對於一個 n-1 次多項式 F(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}
我們可以用 Coefficient Representation 來表示他:

F(x) := [a_0, a_1, \dots, a_{n - 1}]

點值表示法 Point-value Representation

除此之外,令 x_0, x_1, \dots, x_{n-1}n 個不同的數字,
我們也能用這些點在 F 中的取值來表示他

y_i = F(x_i),\ i = 0, 1, \dots, n-1
F(x):= [y_0, y_1, \dots, y_{n - 1}]

這種表示法又叫做 Point-value Representation

新的思路

給定 Coefficient Representation,我們現在只會 O(n\times m) 來做多項式乘法。
那如果換成 Point-value Representation 呢?

C(x_i) = A(x_i) \times B(x_i),~i = 0, 1, \dots, n+m-2

我們只要能抓 n+m-1 個不同數字的取值,最後一個對一個再相乘起來就好了!只需要 O(n+m) 的時間。

我們可以把計算多項式乘法的任務轉換成:

  1. 選擇 n+m-1 個不同的數字 X=(x_0, x_1, \dots, x_{n+m-2})
  2. 將原本是 Coefficient Representation 的多項式 A,B 轉為 Point-value Representation
  3. O(n+m) 的時間計算 C(x_i) = A(x_i) \times B(x_i),得到用 Point-value Representation 表示的多項式 C
  4. 將多項式 C 轉換回 Coefficient Representation

只要好好的選擇 X,就可以用分治法 (divide and conquer) 加速步驟 2, 4。

圖片與文字內容皆參考自 NTHUCPP FFT 單元
圖片與文字內容皆參考自 NTHUCPP FFT 單元

X 的選擇

n=2^r,r\ge 0 的時候,假設有個 \omega(n) 函數有以下性質:

  1. \omega(n)^0, \omega(n)^1,...,\omega(n)^{n-1} 皆為不同數值
  2. \omega(n)^n=1
  3. \omega(n)^{\frac{n}{2}}=-1,其實條件 1, 2 同時滿足的話這點會自動成立
  4. \omega(n)^2=\omega(\frac{n}{2})

X=(x_0, x_1, \dots, x_{n-1}),~x_i=\omega(n)^i 則原本是 Coefficient Representation 的多項式 F(x) := [a_0, a_1, \dots, a_{n - 1}] 其 Point-value Representation F(x) := [y_0, y_1, \dots, y_{n - 1}],~y_i=F(x_i)

透過 \omega(n) 函數的性質可以利用分治法在 O(n \log n) 的時間遞迴求出來。

n 不是 2 的冪次,我們可以找到一個 n'=2^r,n'>n,將 a_n,a_{n+1},...,a_{n'-1} 都設為 0,則 F'(x)=a_0+a_1x+...+a_{n'-1}x^{n'-1} 就能滿足使用分治法的條件。

分治法 (divide and conquer) 求 Point-value Representation

假設有個函數 DC(F, n) 輸入一個 n-1 次多項式 F(x) 的係數表示法,回傳 [F(\omega(n)^0), F(\omega(n)^1),..., F(\omega(n)^{n-1})]

\begin{align} G(x)=a_0+a_2x+a_4x^2+...+a_{n-2}x^{\frac{n}{2}-1}\\ H(x)=a_1+a_3x+a_5x^2+...+a_{n-1}x^{\frac{n}{2}-1} \end{align}F 的係數得到 H,G 的係數只需要 O(n) 的時間。

我們可以把 F(x)G(x)H(x) 表示:

F(x)=G(x^2)+x\times H(x^2) 透過 DC 函數可以遞迴得到 \begin{align} DC(H,n/2)&=[H(\omega(n/2)^0), H(\omega(n/2)^1),..., H(\omega(n/2)^{n/2-1})]\\ DC(G,n/2)&=[G(\omega(n/2)^0), G(\omega(n/2)^1),..., G(\omega(n/2)^{n/2-1})] \end{align}

對於 0\le k<\frac{n}{2},透過性質 2, 3, 4 可以知道:

\begin{align} F(\omega(n)^k)&=G(\omega(n)^{2k})+\omega(n)^k\times H(\omega(n)^{2k})\\ &=G(\omega(n/2)^k)+\omega(n)^k\times H(\omega(n/2)^k)\\ \\ F(\omega(n)^{\frac{n}{2}+k})&=G(\omega(n)^{n+2k})+\omega(n)^{\frac{n}{2}+k}\times H(\omega(n)^{n+2k})\\ &=G(\omega(n)^{2k})-\omega(n)^k\times H(\omega(n)^{2k}) \\ &=G(\omega(n/2)^k)-\omega(n)^k\times H(\omega(n/2)^k) \end{align}

這樣有了 DC(H,n/2),DC(G,n/2) 就可以在 O(n) 的時間做出 DC(F,n) 的結果。得到遞迴的時間複雜度 T(n)=O(n) + 2T(n/2) + O(n) = O(n\log n)

舉例來說 n=8

F(x)=a_0+a_1x+a_2x^2+...+a_7x^7

\begin{align} G(x)=a_0+a_2x+a_4x^2+a_6x^3\\ H(x)=a_1+a_3x+a_5x^2+a_7x^3 \end{align}

想要用遞迴方法求出 F(\omega(8)^0), F(\omega(8)^1),..., F(\omega(8)^7)

首先可以遞迴求出

\begin{align} G(\omega(4)^0), G(\omega(4)^1), G(\omega(4)^2), G(\omega(4)^3)\\ H(\omega(4)^0), H(\omega(4)^1), H(\omega(4)^2), H(\omega(4)^3) \end{align}

接著可以在 O(n) 得到:

  • 用加的
    • F(\omega(8)^0)=G(\omega(4)^0)+\omega(8)^0\times H(\omega(4)^0)
    • F(\omega(8)^1)=G(\omega(4)^1)+\omega(8)^1\times H(\omega(4)^1)
    • F(\omega(8)^2)=G(\omega(4)^2)+\omega(8)^2\times H(\omega(4)^2)
    • F(\omega(8)^3)=G(\omega(4)^3)+\omega(8)^3\times H(\omega(4)^3)
  • 用減的
    • F(\omega(8)^4)=G(\omega(4)^0)-\omega(8)^0\times H(\omega(4)^0)
    • F(\omega(8)^5)=G(\omega(4)^1)-\omega(8)^1\times H(\omega(4)^1)
    • F(\omega(8)^6)=G(\omega(4)^2)-\omega(8)^2\times H(\omega(4)^2)
    • F(\omega(8)^7)=G(\omega(4)^3)-\omega(8)^3\times H(\omega(4)^3)

程式碼的部分等講完逆變換後在介紹。 

逆變換

(y_0, y_1, \dots, y_{n - 1}),~y_i=F(x_i),令多項式 Z(x)=y_0+y_1x+y_2x^2+y_{n-1}x^{n-1},也就是將 F(x) 的 Point-value Representation 作為多項式 Z(x) 的 Coefficient Representation。

\omega(n)^k 帶入 Z(x) 可以發現

\begin{align} Z(\omega(n)^k)&=\sum_{i=0}^{n-1} F(\omega(n)^i)\omega(n)^{ik} \\ &=\sum_{i=0}^{n-1} \left(\left(\sum_{j=0}^{n-1} a_j\omega(n)^{ij}\right)\omega(n)^{ik}\right)\\ &=\sum_{j=0}^{n-1}a_j\left(\sum_{i=0}^{n-1} \left(\omega(n)^{j+k}\right)^i\right) \end{align}

這裡等比數列的和只有兩種可能

\sum_{i=0}^{n-1} (\omega(n)^{j+k})^i = \left\{ \begin{aligned} &n&,&~~~j+k\equiv 0\ (mod\ n) \\ &\frac{\omega(n)^{n(j+k)}-1}{\omega(n)^{j+k} -1} = 0&, &~~~\text{else} \end{aligned} \right.

因此得到結論

  1. Z(\omega(n)^0)=a_0\times n
  2. Z(\omega(n)^k)=a_{n-k}\times n, ~~~0<k<n

這表示我們可以將 y_0\sim y_{n-1} 使用同樣的分治法輕鬆地在 O(n \log n) 得到原本多項式 F(x) 的係數 a_0\sim a_{n-1}

遞迴版本程式碼

#include <algorithm>
#include <cassert>
#include <cstddef>
template <typename T, typename Policy>
class CooleyTukeyAlgorithmRecursive {
public:
using policy = Policy;
using vector_type = typename Policy::vector_type;
private:
// Input: 係數表示法 F(x) := [f[0], f[1], ..., f[n-1]]
// Output: 點值表示法 F(x) := [fY[0], fY[1], ..., fY[n-1]], fY[i] = F(w(n)^i)
auto divide_and_conquer(vector_type f) {
size_t n = f.size();
if (n <= 1) return f;
vector_type g(n / 2), h(n / 2);
for (size_t i = 0; i < n; ++i) { // 根據奇偶分類
if (i % 2 == 0)
g[i / 2] = f[i];
else
h[i / 2] = f[i];
}
auto gY = divide_and_conquer(g); // 得到 gY[i] = G(w(n/2)^i)
auto hY = divide_and_conquer(h); // 得到 hY[i] = H(w(n/2)^i)
vector_type fY(n);
auto wn = Policy::omega(n), wk = Policy::one();
for (size_t k = 0; k < n / 2; ++k) {
auto u = gY[k], t = Policy::mul(wk, hY[k]);
fY[k] = Policy::add(u, t);
fY[k + n / 2] = Policy::sub(u, t);
wk = Policy::mul(wk, wn);
}
return fY; // 得到 fY[i] = F(w(n)^i)
}
public:
auto run(const vector_type& in, bool is_inv) {
size_t N = in.size();
assert((N & (N - 1)) == 0 && Policy::check(N)); // N 必須是 2 的冪次
auto out = divide_and_conquer(in);
if (is_inv) { // 逆變換
std::reverse(out.begin() + 1, out.end());
auto inv_N = Policy::inverse(N);
for (size_t i = 0; i < N; ++i) out[i] = Policy::mul(out[i], inv_N);
}
return out;
}
};

由於我們還不知道 \omega(n) 究竟是個怎樣的函數,實作使用 template 的方式,使用者要將與 \omega(n) 有關的操作寫成 class 後填入 `Policy` 這個欄位。


\omega(n) 的選擇

可以觀察到 \omega(n)^k 有非常明顯的循環性質,這在一般人常見的實數領域中很少見,有這種性質的東西經常出現在:

  1. 複數運算的單位根
  2. 同餘運算下的有限體 (finite field)

快速傅立葉變換 (Fast Fourier Transform, FFT)

\omega(n)=e^{i\frac{2\pi}{n}}。透過 Euler's formula 可以知道 e^{i\frac{2\pi}{n}}=\cos(\frac{2\pi}{n})+i\sin(\frac{2\pi}{n})

這樣 \omega(n) 的數學含意就是複數的 n單位根

  1. \omega(n)^0, \omega(n)^1,...,\omega(n)^{n-1} 的值皆不相同
  2. \omega(n)^n=e^{i\times 2\pi}=1
  3. \omega(n)^{\frac{n}{2}}=e^{i\pi}=-1
  4. \omega(n)^2=e^{i\frac{2\times 2\pi}{n}}=e^{i\frac{2\pi}{n/2}}=\omega(\frac{n}{2})

複數以及 `exp` 函數都是 C++ STL 有提供的東西:

#include <cmath>
#include <complex>
#include <vector>
template <typename T, typename ComplexTy = std::complex<T>>
struct FFT_Policy {
using vector_type = std::vector<ComplexTy>;
static constexpr T pi = std::acos((T)-1);
static bool check(size_t N) { return true; }
static auto one() { return ComplexTy(1, 0); }
static auto omega(size_t N) {
return std::exp(ComplexTy(0, 2 * pi / N));
}
static auto inverse(T value) { return T(1) / value; }
static auto add(ComplexTy a, ComplexTy b) { return a + b; }
static auto sub(ComplexTy a, ComplexTy b) { return a - b; }
static auto mul(ComplexTy a, ComplexTy b) { return a * b; }
};
view raw FFT_policy.cpp hosted with ❤ by GitHub

不過使用 FFT 計算多項式乘法會產生浮點數誤差,因此有些人會考慮使用待會會介紹的 FNTT

快速數論變換 (Fast Number-Theoretic Transform, FNTT)

\omega(n)=g^{\frac{P-1}{n}}\mod P,這裡的 P 是滿足某性質的質數且 g\mod P 的原根。因此首先我們要來認識什麼是原根。

什麼是原根

假設 g, m 互質, 使得 g^d \equiv 1\ (mod\ m) 成立的最小正整數 d 定義為 \delta_m(g)

根據歐拉定理 \delta_m(g)|\phi(m),若 \delta_m(g) = \phi(m) ,則稱 g\mod m 的原根 (primitive root)。

如果 m 是個質數,則最小的 g 通常是個很小的數字 (g\ll P^{5/\log\log P} by Least Prime Primitive Roots),zerojudge 上剛好有一題 [b435. 尋找原根]

對於任意質數 P>2 其原根 g 有一些直觀的性質:

  1. \phi(P)=P-1,~g^{\phi(P)}\equiv g^{P-1}\equiv 1\ (mod\ P),這其實就是費馬小定理
  2. g^1,...,g^{P-2},g^{P-1}\mod P 的結果皆不相同,這是原根本來的性質
  3. g^{(P-1)/2}\equiv -1\ (mod\ P) ,由性質 1,2 可以得到

如何選擇質數 P

P-1 可以被 n 整除,則所有 \omega(n) 的性質都能滿足(所有運算皆是同餘運算):

  1. \omega(n)^0, \omega(n)^1,...,\omega(n)^{n-1} 的值皆不相同
  2. \omega(n)^n=g^{\frac{P-1}{n}n}=g^{P-1}= 1
  3. \omega(n)^{\frac{n}{2}}=g^{(P-1)/2}=-1
  4. \omega(n)^2=g^{\frac{2(P-1)}{n}}=g^{\frac{P-1}{n/2}}=\omega(\frac{n}{2})

為了滿足 P-1 可以被 n 整除,因為 n 是 2 的冪次,FNTT 需要一個特殊構造的質數 P=r\times 2^k+1,~2^k\ge n,已經有中國人整理出一些常用的質數:

P=998244353=7\times 17\times 2^{23}+1 是個經常被使用的質數,其原根 g=3

這樣我們就可以輕鬆地根據定義寫出 FNTT 的實作:

// It is recommended that T at least be long long
#include <vector>
template <typename T, T P, T G>
struct NTT_Policy {
using vector_type = std::vector<T>;
static T pow_mod(T n, T k, T m) {
T ans = 1;
for (n %= m; k; k >>= 1) {
if (k & 1) ans = ans * n % m;
n = n * n % m;
}
return ans;
}
static bool check(size_t N) { return N <= 1 || P % N == 1; }
static auto one() { return T(1); }
static auto omega(size_t N) {
return pow_mod(G, (P - 1) / N, P);
}
static auto inverse(T value) { return pow_mod(value, P - 2, P); }
static auto add(T a, T b) { return (a + b) % P; }
static auto sub(T a, T b) { return ((a - b) % P + P) % P; }
static auto mul(T a, T b) { return a * b % P; }
};
view raw NTT_policy.cpp hosted with ❤ by GitHub

注意 FNTT 的所有運算皆是同餘運算,也就是說 FNTT 的計算多項式乘法的結果是原本的數字\mod P 的值,因此若需要得到精確的結果需要用不同質數執行多次 FNTT 使用中國剩餘定理將結果合併。

假設有個 n-1 次多項式要和一個 m-1 次多項式做乘法,這兩個多項式的所有係數皆小於一個正整數 q

那麼這樣任何多項式係數的範圍就是 [0,q-1],係數兩兩相乘不會超過 (q-1)^2,一共最多 \min(n,m) 項相加,不會超過 \min(n,m)\times(q-1)^2

我們可以選 k 個可以進行 FNTT 的不同質數使得以下條件成立:

\prod_{i=1}^{k}p_i>\min(n,m)\times(q-1)^2

這樣分別使用這些質數執行 FNTT 後再使用中國剩餘定理將結果合併就可以得到完全精確的係數,但要注意計算範圍可能會超過 `long long`,甚至有可能會需要 `__int128_t`。

非遞迴版 Cooley-Tukey Algorithm

我們將係數遞迴的狀況畫出來,注意到葉節點係數的順序會是 (0, 4, 2, 6, 1, 5, 3, 7)

觀察這棵樹,由上往下的第 i 次分層時,是按照其 index 在第 i 個 bit 的奇偶分兩邊的,並且第 i 次分層會決定其最後位子的第 \log_2 n - i - 1 個 bit。

可以推論出,index i 的換置後的位子就會將是 i 的 binary representation 給 reverse。

Reverse Bit 的方法

遞推建表法,建立 O(n) 大小表,總時間複雜度也是 O(n)

auto get_bit_reverse_table(size_t N){
std::vector<size_t> table(N, 0);
for(size_t i = 1; i < N; ++i){
table[i] = table[i >> 1] >> 1;
if(i & 1) table[i] += N >> 1;
}
return table;
}

直接換置法,一次反轉一個數字 n,只要 O(1) 空間,但時間複雜度是 O(\log\log n)

#include <type_traits>
template <typename T>
inline T reverse_bits(T n) {
using unsigned_T = typename std::make_unsigned<T>::type;
unsigned_T v = (unsigned_T)n;
v = ((v & 0xAAAAAAAAAAAAAAAA) >> 1) | ((v & 0x5555555555555555) << 1);
v = ((v & 0xCCCCCCCCCCCCCCCC) >> 2) | ((v & 0x3333333333333333) << 2);
v = ((v & 0xF0F0F0F0F0F0F0F0) >> 4) | ((v & 0x0F0F0F0F0F0F0F0F) << 4);
if constexpr (sizeof(T) == 1) return v;
v = ((v & 0xFF00FF00FF00FF00) >> 8) | ((v & 0x00FF00FF00FF00FF) << 8);
if constexpr (sizeof(T) == 2) return v;
v = ((v & 0xFFFF0000FFFF0000) >> 16) | ((v & 0x0000FFFF0000FFFF) << 16);
if constexpr (sizeof(T) <= 4)
return v;
else
v = ((v & 0xFFFFFFFF00000000) >> 32) | ((v & 0x00000000FFFFFFFF) << 32);
return (T)v;
}

如果 index i 的位置是 j,那麼 index j 的位置也會是 i

想要節省空間的話,可以考慮用直接換置法 in-place 進行換置:

template<typename VectorTy>
void displacement(VectorTy &V){
size_t N = V.size();
for(int i = 0; i < N; ++i){
size_t rev_i = reverse_bits(i) >> (sizeof(i) * 8 - N);
if(i < rev_i) std::swap(V[i], V[rev_i]);
}
}
view raw in_place.cpp hosted with ❤ by GitHub

蝶形網路 Butterfly Diagram

我們一開始就把係數的順序透過 bit reverse 換置,可以寫出非遞迴版本的程式碼:

#include <algorithm>
#include <cassert>
#include <cstddef>
template <typename T, typename Policy>
class CooleyTukeyAlgorithm {
size_t reverse_bits_len(size_t N, size_t len) {
return ::reverse_bits(N) >> (sizeof(N) * 8 - len);
}
public:
using policy = Policy;
using vector_type = typename Policy::vector_type;
auto run(const vector_type& in, bool is_inv) {
size_t N = in.size();
assert((N & (N - 1)) == 0 && Policy::check(N));
vector_type out(N);
for (size_t i = 0; i < N; ++i)
out[reverse_bits_len(i, std::__lg(N))] = in[i];
for (size_t step = 2; step <= N; step *= 2) {
auto wn = Policy::omega(step), wk = Policy::one();
const size_t helf_step = step / 2;
for (size_t i = 0; i < helf_step; ++i) {
for (size_t k = i; k < N; k += step) {
size_t j = k + helf_step;
auto u = out[k], t = Policy::mul(wk, out[j]);
out[k] = Policy::add(u, t);
out[j] = Policy::sub(u, t);
}
wk = Policy::mul(wk, wn);
}
}
if (is_inv) {
std::reverse(out.begin() + 1, out.end());
auto inv_N = Policy::inverse(N);
for (size_t i = 0; i < N; ++i) out[i] = Policy::mul(out[i], inv_N);
}
return out;
}
};

將計算流程畫成圖形,可以看到有很多長得像蝴蝶的形狀,因此被稱之為蝶形網路:

離散捲積程式碼

template <typename AlgorithmTy>
auto convolution(typename AlgorithmTy::vector_type A,
typename AlgorithmTy::vector_type B) {
using Policy = typename AlgorithmTy::policy;
using vector_type = typename AlgorithmTy::vector_type;
if (A.empty() || B.empty()) return vector_type{};
size_t C_size = A.size() + B.size() - 1, N = C_size;
while (N & (N - 1)) ++N;
A.resize(N), B.resize(N);
A = AlgorithmTy().run(A, false), B = AlgorithmTy().run(B, false);
vector_type C(N);
for (size_t i = 0; i < N; ++i) C[i] = Policy::mul(A[i], B[i]);
C = AlgorithmTy().run(C, true);
C.resize(C_size);
return C;
}
view raw convolution.cpp hosted with ❤ by GitHub

測試程式碼

#include <initializer_list>
#include <iostream>
template <typename ValueTy>
auto naiveMethod(std::vector<ValueTy> A, std::vector<ValueTy> B) {
if (A.empty() || B.empty()) return std::vector<ValueTy>{};
std::vector<ValueTy> C(A.size() + B.size() - 1);
for (size_t i = 0; i < A.size(); ++i) {
for (size_t j = 0; j < B.size(); ++j) {
C[i + j] += A[i] * B[j];
}
}
return C;
}
template <typename AlgorithmTy>
void test(typename AlgorithmTy::vector_type A,
typename AlgorithmTy::vector_type B) {
auto Res = convolution<AlgorithmTy>(A, B);
for (auto x : Res) std::cout << x << ' ';
std::cout << std::endl;
}
int main() {
std::cout << std::fixed;
std::cout.precision(1);
using NTT =
CooleyTukeyAlgorithm<long long,
NTT_Policy<long long, (1 << 23) * 7 * 17 + 1, 3>>;
test<NTT>({1, 2, 3, 4}, {5, 6, 7, 8, 9});
using FFT = CooleyTukeyAlgorithm<double, FFT_Policy<double>>;
test<FFT>({1, 2, 3, 4}, {5, 6, 7, 8, 9});
auto C = naiveMethod<long long>({1, 2, 3, 4}, {5, 6, 7, 8, 9});
for (auto x : C) std::cout << x << ' ';
std::cout << std::endl;
return 0;
}
view raw test.cpp hosted with ❤ by GitHub

Output:

5 16 34 60 70 70 59 36
(5.0,0.0) (16.0,0.0) (34.0,0.0) (60.0,-0.0) (70.0,-0.0) (70.0,-0.0) (59.0,-0.0) (36.0,0.0)
5 16 34 60 70 70 59 36