矩阵乘法——Strassen算法

《算法导论》——Strassen算法

矩阵乘法

接触过线性代数的读者,对于矩阵乘法想必一定不陌生。若A=(aij)B=(bij)nn的方阵,则对i,j,,n,定义乘积C=AB中的元素cij为:

cij=k=1naikbkj

因此,我们可以根据矩阵乘法的定义给出矩阵乘法的伪代码。它接收nn的矩阵AB,返回它们的乘积——nn的矩阵C,并且假设每个矩阵都有一个属性rows,表示矩阵的行数。


矩阵乘法——Strassen算法

不难看出,由于三重for循环都恰好执行n步,而第7行每次执行都花费常量时间。因此,SQUARE-MATRIX-MULTIPLY的时间复杂度为θ(n3),即矩阵乘法的朴素实现需要花费θ(n3)时间。你可能因此认为任何矩阵乘法都要花费Ω(n3)时间,因为矩阵乘法的自然定义就需要进行这么多次的标量乘法。而在学术界,也的确在很长一段时间内,很少人敢设想一个算法能渐近快于平凡算法SQUARE-MATRIX-MULTIPLY,直至Strassen大神的出现。

算法流程

Strassen算法采用分治法解决矩阵乘积问题,并通过排列组合的技巧使得分治法产生的递归树不那么“茂盛”以减少矩阵乘法的次数。Strassen算法并不直观,它包含4个步骤:

  1. 将输入矩阵AB和输出矩阵C通过以下方式分解为n2n2的子矩阵;

    A=[A11A12A21A22],B=[B11B12B21B22],C=[C11C12C21C22]

  2. 创建10个n2n2的矩阵S1,S2,,S10,每个矩阵保存步骤1中创建的两个子矩阵的和或差,时间复杂度为Θ(n2)

  3. 用步骤1中创建的子矩阵和步骤2中创建的10个矩阵,递归地计算7个矩阵积P1,P2,,P7。每个矩阵Pi都是n2n2的;
  4. 通过Pi矩阵的不同组合进行加减计算,计算出矩阵C的子矩阵C11,C12,C21,C22,时间复杂度为Θ(n2)

是不是感觉很抽象?一顿猛如虎的操作,就能完成矩阵乘积计算了?没错,就是这么。接下来,为了帮助大家掌握这种操作,就再看看Strassen算法的细节。在步骤2中,创建如下10个矩阵:

S1=B12B22

S2=A11+A12

S3=A21+A22

S4=B21B11

S5=A11+A22

S6=B11+B22

S7=A12A22

S8=B21+B22

S9=A11A21

S10=B11+B22

由于必须进行10次n2n2的加减法,因此,该步骤花费Θ(n2)

在步骤三中,递归地计算7次n2n2矩阵的乘法,如下所示:

P1=A11S1=A11B12A11B22

P2=S2B22=A11B22+A12B22

P3=S3B11=A21B11+A22B11

P4=A22S4=A22B21A22B11

P5=S5S6=A11B11+A11B22+A22B11+A22B22

P6=S7S8=A12B21+A12B22A22B21A22B22

P7=S9S10=A11B11+A11B12A21B11A21B12

步骤4对步骤3创建的Pi矩阵进行加减法运算,计算出C的4个n2n2的子矩阵。

C11=P5+P4P2+P6=A11B11+A12B21

C12=P1+P2=A11B12+A12B22

C21=P3+P4=A21B11+A22B21

C22=P5+P1P3P7=A22B22+A21B12

如此,我们便获得矩阵AB的乘积——矩阵C

算法分析

之前说过,Strassen算法的时间复杂度是优于朴素计算的,可是,它到底是多少呢?我们不妨再回到Strassen算法的流程。当n>1时,步骤1、2和4共花费θ(n2)时间,步骤3要求7次n2n2矩阵的乘法。因此,我们得到如下描述Strassen算法运行时间T(n)的递归式:

T(n)={θ(1)n=17T(n/2)+θ(n2)n>1

求解上式可得,T(n)=θ(nlg7)

算法实现

废话千句,不如代码两行,接下来直接上Strassen算法的实现。(注意,如果n不是2的幂,可以采取对原矩阵填充0的方式,使n扩展到2的幂)。


矩阵乘法——Strassen算法

算法总结

Strassen算法发表于1969年,它的发表引起了很大的轰动。在此之前,很少人敢设想一个算法能渐近快于平凡算法SQUARE-MATRIX-MULTIPLY。矩阵乘法的上界自此被改进了。到目前为止,nn矩阵相乘的渐近复杂性最优的算法是Coppersmith和Winograd提出的,运行时间是O(n2.376)