可恶啊,又让我想起这玩意儿了,然后又忘记怎么推了,只能回去查一查了。其实我困扰的是,CNN的卷积为啥叫卷积啊,卷积不是h(t)=k∑f(k)g(t−k)吗,那个卷积核,分明可以直接对应元素相乘吗,网上有些图居然也是直接乘了。我以为卷积也有什么快速算法啊,可是,普通的卷积又不具有傅里叶变换的性质。
其实傅里叶变换也忘得差不多了,不过,就先这样推着吧。
参考:The Fast Fourier Transform and its Applications

Notation
WN=exp{N2πi}
X(j)=n=0∑N−1A(n)WNjn
A(n)=N1n=0∑N−1X(j)WN−jn
性质
WNN=1,WNj+N=WNj
WNj=WNjmodN
X(j)=X(jmodN)
A(n)=A(nmodN)
总而言之,有个周期N
n=0∑N−1WNnjWNmj=Nifn=mmodN否则0
FFT推导
假设N=r×s,那么存在j1,j0,对于任意的j可用下列式子表示:
j=j1r+j0j1=0,1,…,s−1,j0=0,1,…,r−1
同样,对于n来说,存在n1,n0:
n=n1s+n0n1=0,1,…,r−1,n0=0,1,…,s−1
WNjn=WNj1n1rsWNj1rn0WNj0n1sWNj0n0=Wsj1n0Wrj0n1WNj0n0
X(j)=X(j1,j0)=n0=0∑s−1n1=0∑r−1A(n1,n0)Wrj0n1WNj0n0Wsj1n0
令:A1(j0,n0)=WNj0n0n1=0∑r−1A(n1,n0)Wrj0n1
X(j1,j0)=n0=0∑s−1A1(n1,n0)Wsj1n0
让我们来看看这第一次分解所消耗的计算量,计算一个A1(j0,n0)消耗r次,那么计算全部的A1就是rsr=Nr次。在A1全部知道的情况下,计算一个X(j1,j0)是s次,计算所有的X需要rss=Ns次,故本次分解消耗N(r+s)次计算。如果不分解的话,大概是N2级别。
更加恐怖的是这是第一次分解,可以看到,X(j1,j0)成了周期为s的傅里叶变换,A1(j0,n0)成了周期为r的傅里叶变换,所以,如果,r,s还能够分解的话,计算量还能进一步减少。
假设:N=rm,s=rm−1,那么,计算A1总共消耗Nr=rm+1次,这时X变为周期为s=rm−1的傅里叶变换,可以预见,将s分解为rm−2×r,计算A2应当消耗rm次,但第二次需要计算r个这种情况(因为X(j1,j0)总共有N个,分成了r个周期为s的傅里叶变换),所以第二次消耗的计算量也为r^{m+1},以此类推,最后结果为:
m∗rm+1=mNr
又N=rm,所以m=logrN,多么牛啊,原本二次的计算量近似线性了!
一般情况下,N=r1×r2×⋯rm,最后的计算量为N(r1+r2+…+rm)
论文里的推导过程

代码
import numpy as np
import time
from scipy.fftpack import fft, ifft
def number_fc(N):
for i in range(2, int(np.sqrt(N)) + 1):
if N % i == 0:
s = i
r = N / i
return int(s), int(r)
return 0, 0
def conv(x, k):
N = len(x)
w = np.array([np.exp(-2 * i * k * np.pi * 1j / N) for i in range(N)])
return x @ w
def w_s_N(j0, j1, s, N):
w_s_j1 = np.array([np.exp(-2 * n0 * j1 * np.pi * 1j / s) for n0 in range(s)])
w_N_n0 = np.array([np.exp(-2 * n0 * j0 * np.pi * 1j / N) for n0 in range(s)])
return w_s_j1 * w_N_n0
def FFT(x):
N = len(x)
s, r = number_fc(N)
if not s:
return np.array([conv(x, k) for k in range(N)])
else:
A0 = np.zeros(N, dtype=complex)
A1 = np.array([FFT(x[n0::s]) for n0 in range(s)])
for j1 in range(s):
for j0 in range(r):
A0[j1 * r + j0] = A1[:, j0] @ w_s_N(j0, j1, s, N)
return A0
测试代码:
N = 4096
x = np.arange(N)
t1 = time.time()
y1 = FFT(x)
t2 = time.time()
print(t2 - t1)
t1 = time.time()
y2 = [conv(x, k) for k in range(N)]
t2 = time.time()
print(t2-t1)
t1 = time.time()
y3 = fft(x)
t2 = time.time()
print(t2 - t1)
从上面可以看到,当N = 4096的时候,二者的差距已经十分明显了。第三个方案是,scipy库里面的,即便N都这么大了,依然动不了他分毫。大概是我写代码的水平还是太low了吧。
在写代码的时候,对于时间复杂的计算也有了新的认识,不是那么想当然的,果然实践才是检验真理的唯一标准。