Strassen算法


$$`
\left[
\begin{matrix}
A ,B\
C,D
\end{matrix}
\right]
\times
\left[
\begin{matrix}
E,F\
G,H
\end{matrix}
\right]

\left[
\begin{matrix}
AE+BG, AF+BH\
CE+DG,CF+DH
\end{matrix}
\right]
`$$

Strassen算法於1969年由德國數學家Strassen提出,該方法引入七個中間變量,每個中間變量都只需要進行一次乘法運算。而朴素算法卻需要進行8次乘法運算。

原理

Strassen算法的原理如下所示,使用sympy驗證Strassen算法的正確性

import sympy as s

A = s.Symbol("A")
B = s.Symbol("B")
C = s.Symbol("C")
D = s.Symbol("D")
E = s.Symbol("E")
F = s.Symbol("F")
G = s.Symbol("G")
H = s.Symbol("H")
p1 = A * (F - H)
p2 = (A + B) * H
p3 = (C + D) * E
p4 = D * (G - E)
p5 = (A + D) * (E + H)
p6 = (B - D) * (G + H)
p7 = (A - C) * (E + F)

print(A * E + B * G, (p5 + p4 - p2 + p6).simplify())
print(A * F + B * H, (p1 + p2).simplify())
print(C * E + D * G, (p3 + p4).simplify())
print(C * F + D * H, (p1 + p5 - p3 - p7).simplify())

復雜度分析

$$f(N)=7\times f(\frac{N}{2})=7^2\times f(\frac{N}{4})=...=7^k\times f(\frac{N}{2^k})$$
最終復雜度為$7^{log_2 N}=N^{log_2 7}$

驗證有效性

使用numpy驗證Strassen算法的有效性:

import timeit

import numpy as np

N = 5000
M = 5000
a = np.random.random((N, M))
test_count = 10


def use_numpy():
    ans = np.matmul(a, a.T)
    return ans


def use_numpy_strassen():
    # numpy使用strassen方法
    NN = N // 2
    MM = M // 2
    A, B, C, D = a[:NN, :MM], a[:NN, MM:], a[NN:, :MM], a[NN:, MM:]
    b = a.T
    E, F, G, H = b[:NN, :MM], b[:NN, MM:], b[NN:, :MM], b[NN:, MM:]
    p1 = np.matmul(A, F - H)
    p2 = np.matmul(A + B, H)
    p3 = np.matmul(C + D, E)
    p4 = np.matmul(D, (G - E))
    p5 = np.matmul(A + D, E + H)
    p6 = np.matmul(B - D, G + H)
    p7 = np.matmul(A - C, E + F)
    ans = np.hstack((np.vstack((p5 + p4 - p2 + p6, p3 + p4)), np.vstack((p1 + p2, p1 + p5 - p3 - p7))))
    return ans


print(timeit.timeit(use_numpy, number=10))
print(timeit.timeit(use_numpy_strassen, number=10))
one = use_numpy()
three = use_numpy_strassen()
print(one.reshape(-1)[:5])
print(three.reshape(-1)[:5])

實驗說明:這里只使用了一層Strassen,正常的Strassen應該是遞歸的並且需要在空間上進行優化,從而避免太多的空間復制。
實驗結果出人意料:不采用strassen算法僅需要17秒,采用strassen算法需要40秒。
在進行實驗驗證時,最好不要使用python,因為python隱藏的細節太多了。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM