Strassen 算法优化矩阵乘法(复杂度 $O(n^{\lg 7}) \approx O(n^{2.81})$)

时间:2025-01-03 21:35:38

按照《算法导论》(第 3 版第 4.2 节)中的描述实现

还没有实测它的实际运行时间到底怎么样

不过这个真的很卡内存,挖个坑,以后写空间优化

另外,Matthew Anderson 和 Siddharth Barman 写过一篇关于矩阵乘法的文章

《The Coppersmith-Winograd Matrix Multiplication Algorithm》

介绍了由 Coppersmith 和 Winograd 提出的 $O(n^{2.376})$ 矩阵乘法算法,有时间再仔细拜读吧 orz

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <vector>
using namespace std;
/// Dense matrix of doubles used by the Strassen driver below.
///
/// Storage is heap-allocated and sized to exactly n x m, so the many
/// temporary matrices created per Strassen recursion level stay small and
/// off the stack. Element (i, j) is v[i][j], 0-indexed.
struct Matrix
{
    int n, m;                               // rows, columns
    std::vector<std::vector<double> > v;    // v[i][j], 0-indexed

    /// Creates a rows x cols matrix filled with zeros (0 x 0 by default).
    explicit Matrix(int rows = 0, int cols = 0)
        : n(rows), m(cols), v(rows, std::vector<double>(cols, 0.0)) {}

    /// Element-wise sum; `rhs` must have the same shape as *this.
    Matrix operator +(const Matrix& rhs) const
    {
        Matrix C(n, m);
        for(int i = 0; i < n; i++)
            for(int j = 0; j < m; j++)      // all m columns: the matrix need not be square
                C.v[i][j] = v[i][j] + rhs.v[i][j];
        return C;
    }

    /// Element-wise difference; `rhs` must have the same shape as *this.
    Matrix operator -(const Matrix& rhs) const
    {
        Matrix C(n, m);
        for(int i = 0; i < n; i++)
            for(int j = 0; j < m; j++)
                C.v[i][j] = v[i][j] - rhs.v[i][j];
        return C;
    }

    /// Naive product (*this) * rhs. Requires m == rhs.n; the result is n x rhs.m.
    Matrix operator *(const Matrix& rhs) const
    {
        Matrix C(n, rhs.m);
        for(int i = 0; i < n; i++)
            for(int j = 0; j < m; j++)
            {
                const double a = v[i][j];
                if(a == 0) continue;        // skip zero entries; prepare()'s padding makes these common
                for(int k = 0; k < rhs.m; k++)   // iterate over the result's columns
                    C.v[i][k] += a * rhs.v[j][k];
            }
        return C;
    }

    /// Zero-pads the matrix in place to a square size x size, where size is the
    /// smallest power of two >= max(n, m, atLeast). Passing the same `atLeast`
    /// to both operands of a product guarantees they end up the same size,
    /// which the recursive splitting in Strassen requires.
    void prepare(int atLeast = 0)
    {
        int size = 1;
        while(size < n || size < m || size < atLeast) size <<= 1;
        for(std::size_t i = 0; i < v.size(); i++)
            v[i].resize(size, 0.0);         // pad existing rows with zero columns
        v.resize(size, std::vector<double>(size, 0.0));   // append zero rows
        n = m = size;
    }

    /// Reads "rows cols" followed by rows*cols entries (row-major) from stdin.
    /// Returns false and leaves a 0 x 0 matrix on malformed or truncated input.
    bool read()
    {
        int rows = 0, cols = 0;
        if(!(std::cin >> rows >> cols) || rows < 0 || cols < 0)
        {
            *this = Matrix();
            return false;
        }
        *this = Matrix(rows, cols);
        for(int i = 0; i < n; i++)
            for(int j = 0; j < m; j++)
                if(!(std::cin >> v[i][j]))
                {
                    *this = Matrix();
                    return false;
                }
        return true;
    }

    /// Returns a copy of the sub-block spanning rows i1..i2 and columns j1..j2
    /// (1-indexed, inclusive).
    Matrix get(int i1, int j1, int i2, int j2) const
    {
        Matrix C(i2 - i1 + 1, j2 - j1 + 1);
        for(int i = 0; i < C.n; i++)
            for(int j = 0; j < C.m; j++)
                C.v[i][j] = v[i1 - 1 + i][j1 - 1 + j];
        return C;
    }

    /// Copies `src` into rows i1..i2, columns j1..j2 of *this (1-indexed, inclusive).
    void give(const Matrix &src, int i1, int j1, int i2, int j2)
    {
        for(int i = i1 - 1; i < i2; i++)
            for(int j = j1 - 1; j < j2; j++)
                v[i][j] = src.v[i - i1 + 1][j - j1 + 1];
    }

    /// Prints the top-left n x m entries, one row per line, each right-aligned
    /// in a field of `width` characters.
    void print(int width = 10) const
    {
        for(int i = 0; i < n; i++)
        {
            for(int j = 0; j < m; j++)
                std::cout << std::setw(width) << v[i][j];
            std::cout << '\n';
        }
    }
} A, B;

/// Multiplies X by Y with Strassen's divide-and-conquer scheme (CLRS 3rd ed.,
/// Section 4.2): 7 recursive half-size products plus 18 additions/subtractions
/// per level, O(n^lg7) ~ O(n^2.81) overall.
///
/// The recursive path needs X and Y square and of the same size
/// (Matrix::prepare pads to a power of two). Sizes <= leaf, odd sizes, and
/// non-square or mismatched shapes fall back to the naive product, which is
/// correct whenever X.m == Y.n. Neither argument is modified.
///
/// @param leaf  sub-problems of this size or smaller use the naive product
///              (default 1 = pure Strassen down to 1 x 1 blocks).
/// @return      a new X.n x Y.m matrix holding X * Y.
Matrix Strassen(const Matrix &X, const Matrix &Y, int leaf = 1)
{
    const int n = X.n;
    if(n <= 1 || n <= leaf || n % 2 != 0 || X.m != n || Y.n != n || Y.m != n)
        return X * Y;
    const int h = n / 2;

    // Quadrants: a[r][c] / b[r][c] correspond to CLRS A_{r+1,c+1} / B_{r+1,c+1}.
    Matrix a[2][2], b[2][2];
    a[0][0] = X.get(1, 1, h, h);        a[0][1] = X.get(1, h + 1, h, n);
    a[1][0] = X.get(h + 1, 1, n, h);    a[1][1] = X.get(h + 1, h + 1, n, n);
    b[0][0] = Y.get(1, 1, h, h);        b[0][1] = Y.get(1, h + 1, h, n);
    b[1][0] = Y.get(h + 1, 1, n, h);    b[1][1] = Y.get(h + 1, h + 1, n, n);

    // The ten sums/differences S1..S10 (stored in s[0]..s[9]).
    Matrix s[10];
    s[0] = b[0][1] - b[1][1];   // S1  = B12 - B22
    s[1] = a[0][0] + a[0][1];   // S2  = A11 + A12
    s[2] = a[1][0] + a[1][1];   // S3  = A21 + A22
    s[3] = b[1][0] - b[0][0];   // S4  = B21 - B11
    s[4] = a[0][0] + a[1][1];   // S5  = A11 + A22
    s[5] = b[0][0] + b[1][1];   // S6  = B11 + B22
    s[6] = a[0][1] - a[1][1];   // S7  = A12 - A22
    s[7] = b[1][0] + b[1][1];   // S8  = B21 + B22
    s[8] = a[0][0] - a[1][0];   // S9  = A11 - A21
    s[9] = b[0][0] + b[0][1];   // S10 = B11 + B12

    // The seven recursive products P1..P7 (stored in p[0]..p[6]).
    Matrix p[7];
    p[0] = Strassen(a[0][0], s[0], leaf);   // P1 = A11 * S1
    p[1] = Strassen(s[1], b[1][1], leaf);   // P2 = S2  * B22
    p[2] = Strassen(s[2], b[0][0], leaf);   // P3 = S3  * B11
    p[3] = Strassen(a[1][1], s[3], leaf);   // P4 = A22 * S4
    p[4] = Strassen(s[4], s[5], leaf);      // P5 = S5  * S6
    p[5] = Strassen(s[6], s[7], leaf);      // P6 = S7  * S8
    p[6] = Strassen(s[8], s[9], leaf);      // P7 = S9  * S10

    // Assemble the result quadrants into a fresh matrix.
    Matrix C(n, n);
    C.give(p[4] + p[3] - p[1] + p[5], 1, 1, h, h);            // C11 = P5 + P4 - P2 + P6
    C.give(p[0] + p[1], 1, h + 1, h, n);                      // C12 = P1 + P2
    C.give(p[2] + p[3], h + 1, 1, n, h);                      // C21 = P3 + P4
    C.give(p[4] + p[0] - p[2] - p[6], h + 1, h + 1, n, n);    // C22 = P5 + P1 - P3 - P7
    return C;
}

/// Reads matrices A and B from stdin (each as "rows cols" followed by its
/// entries), multiplies them with Strassen, and prints the A.n x B.m product.
/// Returns 1 with a message on stderr for malformed or incompatible input.
int main()
{
    if(!A.read() || !B.read())
    {
        std::cerr << "invalid input: expected \"rows cols\" followed by rows*cols numbers for each matrix\n";
        return 1;
    }
    if(A.m != B.n)
    {
        std::cerr << "cannot multiply a " << A.n << "x" << A.m
                  << " matrix by a " << B.n << "x" << B.m << " matrix\n";
        return 1;
    }
    const int n = A.n, m = B.m;
    // Pad both operands to one common power-of-two size. Padding each to its
    // own size could leave A and B with different dimensions.
    const int size = std::max(std::max(A.n, A.m), B.m);
    A.prepare(size);
    B.prepare(size);
    Matrix C = Strassen(A, B);
    C.n = n;   // print only the meaningful n x m corner of the padded product
    C.m = m;
    C.print();
    return 0;
}