1004: [HNOI2008]Cards - burnside + DP

时间:2022-04-03 14:20:48

# Description

小春现在很清闲, 面对书桌上的 \(N\) 张牌, 他决定给每张染色, 目前小春只有 \(3\) 种颜色: 红色, 蓝色, 绿色. 他询问 Sun 有

多少种染色方案, Sun 很快就给出了答案. 进一步, 小春要求染出 \(Sr\) 张红色, \(Sb\) 张蓝色, \(Sg\) 张绿色. 他又询问有多少种方

案, Sun 想了一下, 又给出了正确答案. 最后小春发明了 \(M\) 种不同的洗牌法, 这里他又问 Sun 有多少种不同的染色方案.

两种染色方法相同当且仅当其中一种可以通过任意的洗牌法 (即可以使用多种洗牌法, 而每种方法可以使用多次) 洗

成另一种. Sun 发现这个问题有点难度, 决定交给你, 答案可能很大, 只要求出答案除以 \(P\) 的余数 (\(P\) 为质数).

并且数据满足 :

输入数据保证任意多次洗牌都可用这 m 种洗牌法中的一种代替,且对每种洗牌法,都存在一种洗牌法使得能回到原状态。

\(Max{(Sr,Sb, Sg)}<=20\)

Solution

群论中的\(burnside\)定理

参考资料: 《组合数学》 或者 网上博客, 感觉《组合数学》讲的很棒QuQ

首先看\(m\)种洗牌, 任意多次都可以用 一种洗牌法来代替, 满足置换群中 合成运算的封闭性

对每种洗牌法, 都存在一种洗牌法使得回到原状态, 满足置换群中 逆元的封闭性

接下来只需要出现单位元就能让洗牌法构成一个置换群

单位元 : 保持所有位置不变的洗牌法(相当于没洗牌)

如果输入的洗牌法没有单位元, 则手动加上一种洗牌法。 就构成了一个置换群。

然后就可以引用 \(burnside\)定理

设 \(f\) 为一种洗牌法,\(c\) 为一种染色法。

我们可以找到有多少个 \(c\) 使得, \(c\) 通过洗牌法\(f\) 洗牌后不变, 记为\(cnt(f)\)。

则最终答案 :

\[ans= \sum cnt(f) \ \div m
\]

证明略(懒得打噜, 网上有)

接下来就要求出 \(cnt(f)\)了:

对于一种洗牌法, 把每个位置都看作一个点, 洗牌的移动看成一条边。 则每个点的出度和入度都是\(1\)

于是构成了若干个不相交的环。要使洗牌后颜色不变, 那么必有同一个环上的颜色相同。

可以通过简单的\(DP\)求出\(cnt(f)\)了!

总复杂度\(O(MN^3)\) ,BZOJ上跑了\(36ms\)

Code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define ll long long
#define rd read()
using namespace std; const int N = 62; int n, m, sa, sb, sc, mod;
int nxt[N][N], vis[N], f[22][22];
vector<int> v[N]; int read() {
int X = 0, p = 1; char c = getchar();
for (; c > '9' || c < '0'; c = getchar())
if (c == '-') p = -1;
for (; c >= '0' && c <= '9'; c = getchar())
X = X * 10 + c - '0';
return X * p;
} int fpow(int a, int b) {
int res = 1;
for (; b; b >>= 1, a = a * a % mod)
if (b & 1) res = res * a % mod;
return res;
} void init(int x) {
memset(vis, 0, sizeof(vis));
for (int i = 1; i <= n; ++i) if (!vis[i]) {
vis[i] = 1;
int now = nxt[x][i], tmp = 1;
while (now != i) vis[now] = 1, tmp++, now = nxt[x][now];
v[x].push_back(tmp);
}
} int cal(int x) {
memset(f, 0, sizeof(f));
f[sa][sb] = 1;
int tot = 0;
for (int i = 0, up = v[x].size(); i < up; ++i) {
for (int j = 0; j <= sa; ++j)
for (int k = 0; k <= sb; ++k) {
if (n - tot - j - k < v[x][i]) f[j][k] = 0;
if (j + v[x][i] <= sa) (f[j][k] += f[j + v[x][i]][k]) %= mod;
if (k + v[x][i] <= sb) (f[j][k] += f[j][k + v[x][i]]) %= mod;
}
tot += v[x][i];
}
return f[0][0];
} int main()
{
sa = rd; sb = rd; sc = rd; m = rd; mod = rd;
n = sa + sb + sc;
bool flag = false;
for (int i = 1; i <= m; ++i) {
bool is = true;
for (int j = 1; j <= n; ++j) {
nxt[i][j] = rd;
if (nxt[i][j] != j) is = false;
}
if (is) flag = true;
}
if (!flag) {
m++;
for (int i = 1; i <= n; ++i)
nxt[m][i] = i;
}
for (int i = 1; i <= m; ++i)
init(i);
int ans = 0;
for (int i = 1; i <= m; ++i)
ans = (ans + cal(i)) % mod;
ans = ans * fpow(m, mod - 2);
ans = (ans + mod) % mod;
printf("%d\n", ans);
}