POJ - 3233 Matrix Power Series (矩阵等比二分求和)

时间:2023-03-10 05:49:47
POJ - 3233 Matrix Power Series (矩阵等比二分求和)

Description

Given a n × n matrix A and a positive integer k, find the sum
S = A + A2 + A3 + … +
Ak
.

Input

The input contains exactly one test case. The first line of input contains three positive integers
n (n ≤ 30), k (k ≤ 109) and m (m < 104). Then follow
n lines each containing n nonnegative integers below 32,768, giving
A’s elements in row-major order.

Output

Output the elements of S modulo m in the same way as A is given.

Sample Input

2 2 4
0 1
1 1

Sample Output

1 2
2 3

题意:求矩阵总和

思路:矩阵高速幂取模,和等比数列矩阵求和,这里说一下怎么二分求矩阵的等比序列和,设矩阵为A。次数为k

设sum(k) = A^1 + A^2 + A^3 + ..... + A^K,那么为了可以二分递归下去我们首先拆出个( A^1 + A^2 + ... + A^(k/2) ,可以非常easy得到:sum(k) = sum(k/2) * (A^(k/2) + 1)

在代码里是假设这个次数是奇数的话会少算一个A^k,所以要记得加上

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>
typedef long long ll;
using namespace std;
const int maxn = 32; int m, n;
struct Matrix {
int v[maxn][maxn];
Matrix() {}
Matrix(int x) {
init();
for (int i = 0; i < maxn; i++)
v[i][i] = x;
}
void init() {
memset(v, 0, sizeof(v));
}
Matrix operator *(Matrix const &b) const {
Matrix c;
c.init();
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
for (int k = 0; k < n; k++)
c.v[i][j] = (c.v[i][j] + (v[i][k]*b.v[k][j])) % m;
return c;
}
Matrix operator ^(int b) {
Matrix a = *this, res(1);
while (b) {
if (b & 1)
res = res * a;
a = a * a;
b >>= 1;
}
return res;
}
} u(1); Matrix Add(Matrix a, Matrix b) {
for (int i = 0; i < maxn; i++)
for (int j = 0; j < maxn; j++)
a.v[i][j] = (a.v[i][j]+b.v[i][j]) % m;
return a;
} Matrix BinarySum(Matrix a, int n) {
if (n == 1)
return a;
if (n & 1)
return Add(BinarySum(a, n-1), a^n);
else return BinarySum(a, n>>1) * Add(u, a^(n>>1));
} int main() {
int k;
Matrix a, ans;
while (scanf("%d%d%d", &n, &k, &m) != EOF) {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
scanf("%d", &a.v[i][j]);
ans = BinarySum(a, k);
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
printf("%d%c", ans.v[i][j], (j==n-1)? '\n':' ');
}
return 0;
}