POJ-3233 Matrix Power Series 矩阵A^1+A^2+A^3...求和转化

时间:2023-12-20 22:10:20

S(k)=A^1+A^2...+A^k.

保利求解就超时了,我们考虑一下当k为偶数的情况,A^1+A^2+A^3+A^4...+A^k,取其中前一半A^1+A^2...A^k/2,后一半提取公共矩阵A^k/2后可以发现也是前一半A^1+A^2...A^k/2。因此我们可以考虑只算其中一半,然后A^k/2用矩阵快速幂处理。对于k为奇数,只要转化为k-1+A^k即可。n为矩阵数量,m为矩阵大小,复杂度O[(logn*logn)*m^3]

#include <iostream>
#include <algorithm>
#include <string>
#include <map>
#include <set>
#include <vector>
#include <cmath>
#define LL long long
using namespace std; struct mx
{
    LL n, m;
    LL c[][];//需要根据题目开大
    void initMath(LL _n)//初始化方阵
    {
        m = n = _n;
    }
    void initOne(LL _n)//初始化单位矩阵
    {
        m = n = _n;
        for (LL i = ; i<n; i++)
            for (LL j = ; j<m; j++)
                c[i][j] = (i == j);
    }
    void print()//测试打印
    {
        for (LL i = ; i<n; i++)
        {
            for (LL j = ; j < m; j++)
            {
                cout << c[i][j];
                if (j != m - )cout << ' ';
            }
                
            cout << endl;
        }
    }
};
int mod = ;
mx Mut(mx a, mx b)
{
    mx c;
    c.n = a.n, c.m = b.m;
    for (LL i = ; i<a.n; i++)
        for (LL j = ; j<b.m; j++)
        {
            LL sum = ;
            for (LL k = ; k<b.n; k++)
                sum += a.c[i][k] * b.c[k][j], sum %= mod;
            c.c[i][j] = sum;
        }
    return c;
}
mx fastMi(mx a, LL b)
{
    mx mut; mut.initOne(a.n);
    while (b)
    {
        if (b % != )
            mut = Mut(mut, a);
        a = Mut(a, a);
        b /= ;
    }
    return mut;
}
LL n, k;
mx a, ans, b;
mx s(LL kx)
{
    if (kx == )
    {
        return a;
    }
    if (kx % ==)
    {
        mx p = s(kx / );
        mx y = fastMi(a, kx/);
        y = Mut(y,p);
        for (int i = ; i < n; i++)for (int j = ; j < n; j++)
        {
            y.c[i][j] += p.c[i][j];
            y.c[i][j] %= mod;
        }
        return y;
    }
    else
    {
        mx p = s(kx-);
        mx y = fastMi(a, kx);
        for (int i = ; i < n; i++)for (int j = ; j < n; j++)
        {
            y.c[i][j] += p.c[i][j];
            y.c[i][j] %= mod;
        }
        return y;
    }
}
int main(int argc, const char * argv[]) {
    cin.sync_with_stdio(false);
    int t;
    cin >> t;
    while (t--)
    {
        cin >> n >> k;
        b.initMath(n);
        ans.initMath(n);
        a.initMath(n);
        for(int i=;i<n;i++)
            for (int j = ; j < n; j++)
            {
                cin >> a.c[i][j];
            }
        ans = s(k);
        ans.print();
    }
    return ;
}