POJ 3233 Matrix Power Series 二分+矩阵乘法

时间:2022-02-27 16:23:44

链接:http://poj.org/problem?id=3233

题意:给一个N*N的矩阵(N<=30),求S = A + A^2 + A^3 +
… + A^k(k<=10^9)。

思路:非常明显直接用矩阵高速幂暴力求和的方法复杂度O(klogk)。肯定会超时。我採用的是二分的方法, A + A^2 + A^3 + … + A^k=(1+A^(k/2)) *(A + A^2 + A^3 + … + A^(k/2))。这样就能够提出一个(1+A^(k/2)),假设k是奇数,单独处理A^k。

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <map>
#include <cstdlib>
#include <queue>
#include <stack>
#include <vector>
#include <ctype.h>
#include <algorithm>
#include <string>
#include <set>
#define PI acos(-1.0)
#define maxn 35
#define maxm 35
#define INF 10005
#define eps 1e-8
typedef long long LL;
typedef unsigned long long ULL;
using namespace std;
int k,mm;
struct Matrix
{
int n,m;
int a[maxn][maxm];
void init()
{
n=m=0;
memset(a,0,sizeof(a));
}
Matrix operator +(const Matrix &b) const
{
Matrix tmp;
tmp.n=n;
tmp.m=m;
for(int i=0; i<n; i++)
for(int j=0; j<m; j++)
{
tmp.a[i][j]=a[i][j]+b.a[i][j];
tmp.a[i][j]=(tmp.a[i][j]+mm)%mm;
}
return tmp;
}
Matrix operator -(const Matrix &b) const
{
Matrix tmp;
tmp.n=n;
tmp.m=m;
for(int i=0; i<n; i++)
for(int j=0; j<m; j++)
tmp.a[i][j]=a[i][j]-b.a[i][j];
return tmp;
}
Matrix operator *(const Matrix &b) const
{
Matrix tmp;
tmp.init();
tmp.n=n;
tmp.m=b.m;
for(int i=0; i<n; i++)
for(int j=0; j<b.m; j++)
for(int k=0; k<m; k++)
{
tmp.a[i][j]+=a[i][k]*b.a[k][j];
tmp.a[i][j]=(tmp.a[i][j]+mm)%mm;
} return tmp;
}
};//仅仅有当矩阵A的列数与矩阵B的行数相等时A×B才有意义
Matrix M_quick_pow(Matrix m,int k)
{
Matrix tmp;
tmp.n=m.n;
tmp.m=m.m;//m=n才干做高速幂
for(int i=0; i<tmp.n; i++)
{
for(int j=0; j<tmp.n; j++)
{
if(i==j)
tmp.a[i][j]=1;
else tmp.a[i][j]=0;
}
}
while(k)
{
if(k&1)
tmp=tmp*m;
k>>=1;
m=m*m;
}
return tmp;
}
int main()
{
Matrix A,ans,In,res;
while(~scanf("%d%d%d",&A.m,&k,&mm))
{
ans.init();
res.init();
res.m=res.n=ans.m=ans.n=In.m=In.n=A.n=A.m;
for(int i=0; i<In.m; i++)
{
In.a[i][i]=1;
res.a[i][i]=1;
}
for(int i=0; i<A.m; i++)
for(int j=0; j<A.n; j++)
{
scanf("%d",&A.a[i][j]);
A.a[i][j]%=mm;
}
while(k)
{
if(k==1)
{
res=res*A;
}
else
{
if(k%2)
ans=ans+res*M_quick_pow(A,k);
res=res*(In+M_quick_pow(A,k/2));
}
k/=2;
}
ans=ans+res;
for(int i=0; i<ans.m; i++)
{
for(int j=0; j<ans.m; j++)
{
if(j!=0)
printf(" ");
printf("%d",ans.a[i][j]);
}
printf("\n");
}
}
return 0;
}