$$`
\left[
\begin{matrix}
A ,B\
C,D
\end{matrix}
\right]
\times
\left[
\begin{matrix}
E,F\
G,H
\end{matrix}
\right]
\left[
\begin{matrix}
AE+BG, AF+BH\
CE+DG,CF+DH
\end{matrix}
\right]
`$$
Strassen算法於1969年由德國數學家Strassen提出,該方法引入七個中間變量,每個中間變量都只需要進行一次乘法運算。而朴素算法卻需要進行8次乘法運算。
原理
Strassen算法的原理如下所示,使用sympy驗證Strassen算法的正確性
import sympy as s
A = s.Symbol("A")
B = s.Symbol("B")
C = s.Symbol("C")
D = s.Symbol("D")
E = s.Symbol("E")
F = s.Symbol("F")
G = s.Symbol("G")
H = s.Symbol("H")
p1 = A * (F - H)
p2 = (A + B) * H
p3 = (C + D) * E
p4 = D * (G - E)
p5 = (A + D) * (E + H)
p6 = (B - D) * (G + H)
p7 = (A - C) * (E + F)
print(A * E + B * G, (p5 + p4 - p2 + p6).simplify())
print(A * F + B * H, (p1 + p2).simplify())
print(C * E + D * G, (p3 + p4).simplify())
print(C * F + D * H, (p1 + p5 - p3 - p7).simplify())
復雜度分析
$$f(N)=7\times f(\frac{N}{2})=7^2\times f(\frac{N}{4})=...=7^k\times f(\frac{N}{2^k})
$$
最終復雜度為$7^{log_2 N}=N^{log_2 7}
$
驗證有效性
使用numpy驗證Strassen算法的有效性:
import timeit
import numpy as np
N = 5000
M = 5000
a = np.random.random((N, M))
test_count = 10
def use_numpy():
ans = np.matmul(a, a.T)
return ans
def use_numpy_strassen():
# numpy使用strassen方法
NN = N // 2
MM = M // 2
A, B, C, D = a[:NN, :MM], a[:NN, MM:], a[NN:, :MM], a[NN:, MM:]
b = a.T
E, F, G, H = b[:NN, :MM], b[:NN, MM:], b[NN:, :MM], b[NN:, MM:]
p1 = np.matmul(A, F - H)
p2 = np.matmul(A + B, H)
p3 = np.matmul(C + D, E)
p4 = np.matmul(D, (G - E))
p5 = np.matmul(A + D, E + H)
p6 = np.matmul(B - D, G + H)
p7 = np.matmul(A - C, E + F)
ans = np.hstack((np.vstack((p5 + p4 - p2 + p6, p3 + p4)), np.vstack((p1 + p2, p1 + p5 - p3 - p7))))
return ans
print(timeit.timeit(use_numpy, number=10))
print(timeit.timeit(use_numpy_strassen, number=10))
one = use_numpy()
three = use_numpy_strassen()
print(one.reshape(-1)[:5])
print(three.reshape(-1)[:5])
實驗說明:這里只使用了一層Strassen,正常的Strassen應該是遞歸的並且需要在空間上進行優化,從而避免太多的空間復制。
實驗結果出人意料:不采用strassen算法僅需要17秒,采用strassen算法需要40秒。
在進行實驗驗證時,最好不要使用python,因為python隱藏的細節太多了。