[POJ3233]Matrix Power Series 分治+矩阵

时间:2023-12-20 22:10:08

本文为博主原创文章,欢迎转载,请注明出处 www.cnblogs.com/yangyaojia

[POJ3233]Matrix Power Series 分治+矩阵

题目大意

A为n×n(n<=30)的矩阵,让你求

\(\sum\limits_{i=1}^{k}A^i\)

并将答案对取模p

输入格式:

有多组测试数据,其中第一行有3个正整数,为n,k(k<=\(10^9\)),p(p<=\(10^4\))

后面有n行,每行n个数。

输出格式:

输出最后答案的矩阵。

输入输出样例

input

2 2 4

0 1

1 1

output

1 2

2 3

解题分析

首先,这道题k有10^9,肯定不能O(n)推过去,但我们发现这个随着k递增ans是严格递增的,并且满足一定的规律,在这里我们想到了分治。

我们先取答案的一半,尝试怎么凑出这个答案

我们让 \(a=\sum\limits_{i=1}^{k\over2}A^i,b=A^{k\over2}\)

我们发现,

如果 $k\mod2==0 $ 则

\(\sum\limits_{i=1}^{k}A^i=\sum\limits_{i=1}^{k\over2}A^i+\sum\limits_{i={k\over2}+1}^{k}A^i=\sum\limits_{i=1}^{k\over2}A^i+(\sum\limits_{i=1}^{k\over2}A^i*A^{k\over2})=a+a*b\)

怎么理解呢,可以认为我们用\(A^{k\over2}\)将\(\sum\limits_{i=1}^{k\over2}A^i\)的指数补成\(\sum\limits_{i={k\over2}+1}^{k}A^i\)。

所以,我们可以很容易推出当k为奇数的情况,只要将最后一位补上就可以。

\(\sum\limits_{i=1}^{k}A^i=\sum\limits_{i=1}^{k\over2}A^i+(\sum\limits_{i=1}^{k\over2}A^i*A^{k\over2})+A^k=a+a*b+b^2*A\)

这样,就可以在\(O(log_2n)\)出解了。

矩阵乘法加法就不多说了。

#include <cstdio>
#include <iostream>
#include <cmath>
#include <queue>
#include <algorithm>
#include <cstring>
#include <climits>
#define MAXN 50+10 using namespace std;
int p,n,k;
struct matrix{
int n,m;
int data[MAXN][MAXN];
void read(int x,int y)
{
n=x,m=y;
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
scanf("%d",&data[i][j]);
}
void print()
{
for(int i=1;i<=n;i++)
{
for(int j=1;j<=m;j++)
printf("%d ",data[i][j]);
printf("\n");
}
}
matrix operator * (matrix b)
{
matrix ans;
memset(ans.data,0,sizeof(ans.data));
ans.n=n;ans.m=b.m;
for(int i=1;i<=ans.n;i++) for(int j=1;j<=ans.m;j++) for(int k=1;k<=m;k++)
ans.data[i][j]+=data[i][k]*b.data[k][j],ans.data[i][j]%=p;
return ans;
}
matrix operator + (matrix b)
{
matrix ans;
memset(ans.data,0,sizeof(ans.data));
ans.n=n;ans.m=m;
for(int i=1;i<=ans.n;i++) for(int j=1;j<=ans.m;j++)
ans.data[i][j]=(data[i][j]+b.data[i][j])%p;
return ans;
}
}a,b,c,z;
void halfsort(matrix x,int y)
{
if(y==1) {c=x;b=x;return ;}
else
{
halfsort(x,y/2);
if(y%2==0)
{
b=b+(b*c);
c=c*c;
}else
{
b=b+(b*c);
b=b+(x*c*c);
c=c*c;c=c*x;
}
}
}
int main()
{
while(scanf("%d%d%d",&n,&k,&p)!=EOF)
{
a=z;b=z;c=z;
a.read(n,n);
halfsort(a,k);
b.print();
}
return 0;
}