5分钟掌握矩阵乘法的Strassen算法
By Long Luo
机器学习中需要训练大量数据,涉及大量复杂运算,例如卷积、矩阵等。这些复杂运算不仅多,而且每次计算的数据量很大,如果能针对这些运算进行优化,可以大幅提高性能。
一、矩阵乘法
假设 \(A\) 为 \(m \times p\) 的矩阵,\(B\) 为 \(p \times n\) 的矩阵,那么称 \(m \times n\) 的矩阵 \(C\) 为矩阵 \(A\) 与 \(B\) 的乘积,记作 \(C = AB\),称为矩阵积(\(\textit{Matrix Product}\))。
其中矩阵 \(C\) 中的第 \(i\) 行第 \(j\) 列元素可以表示为:
\[ (AB)_{ij} = \sum_{k=1}^{p}{a_{ik}b_{kj}} = a_{i1}b_{1j} + a_{i2}b_{2j} + \cdots + a_{ip}b_{pj} \]
如下图所示:
假如在矩阵 \(A\) 和矩阵 \(B\) 中,\(m=p=n=N\),那么完成 \(C=AB\) 需要多少次乘法呢?
- 对于每一个行向量 \(r\) ,总共有 \(N\) 行;
- 对于每一个列向量 \(c\) ,总共有 \(N\) 列;
- 计算它们的内积,总共有 \(N\) 次乘法计算。
综合可以看出,矩阵乘法的算法复杂度是:\(O(N^3)\)。
二、Strassen算法
那么有没有比 \(O(N^{3})\) 更快的算法呢?
1969年,Volker Strassen 提出了第一个算法时间复杂度低于 \(O(N^{3})\) 矩阵乘法算法,算法复杂度为 \(O(n^{log_{2}^{7}})=O(n^{2.807})\) 。从下图可知,\(\textit{Strassen}\) 算法只有在对于维数比较大的矩阵( \(N > 300\) ) ,性能上才有很大的优势,可以减少很多乘法计算。
\(\textit{Strassen}\) 算法证明了矩阵乘法存在时间复杂度低于 \(O(N^{3})\) 的算法的存在,后续学者不断研究发现新的更快的算法,截止目前时间复杂度最低的矩阵乘法算法是 Coppersmith-Winograd 方法的一种扩展方法,其算法复杂度为 \(O(n^{2.375})\) 。
三、Strassen原理详解
假设矩阵 \(A\) 和矩阵 \(B\) 都是 \(N \times N (N = 2^{n})\) 的方矩阵,求 \(C = AB\) ,如下所示:
\[ \begin{aligned} A = \left [\begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{matrix} \right] \\ B = \left [\begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \\ \end{matrix} \right] \\ C = \left [\begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \\ \end{matrix} \right] \end{aligned} \]
其中
\[ \begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix} = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix} \cdot \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix} \]
矩阵 \(C\) 可以通过下列公式求出:
\[ \begin{aligned} C_{11} = A_{11} \cdot B_{11} + A_{12} \cdot B_{21} \\ C_{12} = A_{11} \cdot B_{12} + A_{22} \cdot B_{21} \\ C_{21} = A_{21} \cdot B_{11} + A_{22} \cdot B_{21} \\ C_{22} = A_{21} \cdot B_{12} + A_{22} \cdot B_{22} \\ \end{aligned} \]
从上述公式我们可以得出,计算 \(2\) 个 \(n \times n\) 的矩阵相乘需要 \(2\) 个 \(\frac{n}{2} \times \frac{n}{2}\) 的矩阵 \(8\) 次乘法和 \(4\) 次加法。
我们使用 \(T(n)\) 表示 \(n \times n\) 矩阵乘法的时间复杂度,那么我们可以根据上面的分解得到下面的递推公式:
\[ T(n) = 8 \times T(\frac{n}{2}) + O(n^{2}) \]
其中:
- \(8T(\frac{n}{2})\) 表示 \(8\) 次矩阵乘法,而且相乘的矩阵规模降到了 \(\frac{n}{2}\)。
- \(O(n^{2})\) 表示 \(4\) 次矩阵加法的时间复杂度以及合并矩阵 \(C\) 的时间复杂度。
最终可计算得到 \(T(n)=O(n^{log_{2}^{8}})=O(n^{3})\) 。
可以看出每次递归操作都需要 \(8\) 次矩阵相乘,而这正是瓶颈的来源。相比加法,矩阵乘法是非常慢的,于是我们想到能不能减少矩阵相乘的次数呢?
答案是当然可以!!!
\(\textit{Strassen}\) 算法正是从这个角度出发,实现了降低算法复杂度!