UOJ 241. 【UR #16】破坏发射台 [矩阵乘法]

时间:2023-03-08 16:31:59
UOJ 241. 【UR #16】破坏发射台 [矩阵乘法]

UOJ 241. 【UR #16】破坏发射台

题意:长度为 n 的环,每个点染色,有 m 种颜色,要求相邻相对不能同色,求方案数。(定义两个点相对为去掉这两个点后环能被分成相同大小的两段)


只想到一个奇怪的线性递推,无法写成矩乘的形式...

正解用状态记录了颜色是否相同

奇环,只考虑相邻,确定第一个的颜色,\(f[i][0/1]\)表示i个与第一个不同/同色的方案数

偶环,再考虑相对,分成两段,同时递推\(i,\frac{n}{2}+i\),\(f[i][0..6]\)来表示

构造矩阵讨论好烦啊

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long ll;
const int mo = 998244353;
inline int read() {
char c=getchar(); int x=0,f=1;
while(c<'0'||c>'9') {if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9') {x=x*10+c-'0';c=getchar();}
return x*f;
} int n, m, r, g[7][7], f[7][7];
void mul(int a[7][7], int b[7][7]) {
static int c[7][7];
memset(c, 0, sizeof(c));
for(int i=0; i<r; i++)
for(int k=0; k<r; k++) if(a[i][k])
for(int j=0; j<r; j++) if(b[k][j])
c[i][j] = (c[i][j] + (ll) a[i][k] * b[k][j]) %mo;
memcpy(a, c, sizeof(c));
}
void Pow(int a[7][7], int b) {
static int c[7][7];
memset(c, 0, sizeof(c));
for(int i=0; i<r; i++) c[i][i] = 1;
for(; b; b >>= 1, mul(a, a)) if(b & 1) mul(c, a);
memcpy(a, c, sizeof(c));
}
void print(int a[7][7]) {
for(int i=0; i<r; i++) for(int j=0; j<r; j++) printf("%d%c", a[i][j], j==r-1 ? '\n' : ' ');
puts("");
}
namespace odd {
void solve() {
r = 2;
g[0][0] = m-2; g[0][1] = m-1;
g[1][0] = 1; g[1][1] = 0;
Pow(g, n-2);
//print(g);
f[0][0] = m-1; f[1][0] = 0;
mul(g, f);
//print(g);
int ans = (ll) g[0][0] * m %mo;
printf("%d\n", ans);
}
}
namespace even {
int id[5][5];
inline ll cal(int a, int c) {
if(a == 0) return c==0 ? m-3 : m-2;
else return 1;
}
void solve() {
r = 7;
memset(id, -1, sizeof(id));
id[0][0] = 0; id[0][1] = 1; id[0][2] = 2;
id[1][0] = 3; id[1][2] = 4;
id[2][0] = 5; id[2][1] = 6;
for(int a=0; a<3; a++) for(int b=0; b<3; b++) if(~id[a][b])
for(int c=0; c<3; c++) for(int d=0; d<3; d++) if(~id[c][d]) {
int i = id[a][b], j = id[c][d];
if((a && a==c) || (b && b==d)) continue;
if(a == 0 && b == 0) { //printf("hi\n");
if(c && d) g[i][j] = (ll) (m-2) * max(0, m-3) %mo;
else if(!c && !d) g[i][j] = ((ll) max(0, m-4) * max(0, m-4) + max(0, m-3)) %mo;
else if(c || d) g[i][j] = ((ll) max(0, m-3) * max(0, m-3)) %mo;
g[i][j] = max(0, g[i][j]);
} else g[i][j] = cal(a, c) * cal(b, d) %mo;
}
//print(g);
f[id[0][0]][0] = max(0LL, (ll)(m-2) * (m-3)) %mo;
f[id[0][1]][0] = m-2;
f[id[2][0]][0] = m-2;
f[id[2][1]][0] = 1;
n = n/2 - 1;
Pow(g, n-1);
//print(g);
mul(g, f);
//print(g);
int ans = (ll) ((ll) g[0][0] + g[id[1][0]][0] + g[id[0][2]][0] + g[id[1][2]][0]) %mo * m %mo * (m-1) %mo;
printf("%d\n", ans);
}
}
int main() {
freopen("in", "r", stdin);
n = read(); m = read();
if(n == 1) printf("%d\n", m);
else if(n == 2) printf("%lld\n", (ll) m * (m-1) %mo);
if(n & 1) odd::solve();
else even::solve();
}