【BZOJ 1016】【JSOI 2008】最小生成树计数

时间:2023-11-24 17:39:08

http://www.lydsy.com/JudgeOnline/problem.php?id=1016

统计每一个边权在最小生成树中使用的次数,这个次数在任何一个最小生成树中都是固定的(归纳证明)。

在同一个边权上对所有边权为这个的边暴力统计(可以用矩阵树定理),然后用并查集把这个边权的所有边贡献的连通性都加上,再统计下一个边权。

最后把答案乘起来。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 103;
const int M = 1003;
const int p = 31011; struct Edge {
int u, v, e;
bool operator < (const Edge &A) const {
return e < A.e;
}
} E[M];
int fa[N], n, m, sz[N], val[N], tot[N], l[N], r[N]; int find(int x) {return fa[x] == x ? x : find(fa[x]);} void merge(int x, int y) {
fa[x] = y; sz[y] += sz[x];
while (fa[y] != y) {
y = fa[y];
sz[y] += sz[x];
}
} void cut(int x, int y) {
if (fa[x] == y) {
fa[x] = x;
sz[y] -= sz[x];
while (fa[y] != y) {
y = fa[y];
sz[y] -= sz[x];
}
} else {
fa[y] = y;
sz[x] -= sz[y];
while (fa[x] != x) {
x = fa[x];
sz[x] -= sz[y];
}
}
} int dfsl, dfsr, dfstot, sum; void dfs(int tmp, int nowtot) {
if (nowtot == dfstot) {++sum; if (sum == p) sum = 0; return;}
if (tmp > dfsr || dfstot - nowtot > dfsr - tmp + 1) return;
dfs(tmp + 1, nowtot);
int u = find(E[tmp].u), v = find(E[tmp].v);
if (u != v) {
if (sz[u] < sz[v]) merge(u, v); else merge(v, u);
dfs(tmp + 1, nowtot + 1);
cut(u, v);
}
} int in() {
int k = 0; char c = getchar();
for (; c < '0' || c > '9'; c = getchar());
for (; c >= '0' && c <= '9'; c = getchar())
k = k * 10 + c - 48;
return k;
} int main() {
n = in(); m = in();
int i;
for (i = 1; i <= m; ++i) {E[i].u = in(); E[i].v = in(); E[i].e = in();}
stable_sort(E + 1, E + m + 1); int x, y, num = 0, cnt = 0; val[0] = -1;
for (i = 1; i <= n; ++i) fa[i] = i, sz[i] = 1;
for (i = 1; i <= m; ++i) {
x = find(E[i].u); y = find(E[i].v);
if (E[i].e != val[num]) {
r[num] = i - 1;
val[++num] = E[i].e;
l[num] = i;
}
if (x != y) {
++tot[num];
if (sz[x] < sz[y]) merge(x, y); else merge(y, x);
++cnt;
if (cnt == n - 1)
break;
}
}
if (cnt < n - 1) {puts("0"); return 0;}
for (; i <= m && E[i].e == val[num]; ++i);
r[num] = i - 1; for (i = 1; i <= n; ++i) fa[i] = i, sz[i] = 1;
ll ans = 1;
for (i = 1; i <= num; ++i) {
sum = 0; dfsl = l[i]; dfsr = r[i]; dfstot = tot[i];
dfs(dfsl, 0);
for (int j = dfsl; j <= dfsr; ++j) {
x = find(E[j].u); y = find(E[j].v);
if (x != y) if (sz[x] < sz[y]) merge(x, y); else merge(y, x);
}
ans = ans * sum % p;
}
printf("%lld\n", ans);
return 0;
}