题目要求将树分为k个部分,并且每种颜色恰好在同一个部分内,问有多少种方案。
第一步显然我们需要知道哪些点一定是要在一个部分内的,也就是说要求每一个最小的将所有颜色i的点连通的子树。
这一步我们可以将所有有颜色的点丢入优先队列,然后另深度最深的点优先出队。
如果此时这个点的颜色有不只一个点在队列中,那么我们必须要考虑将它的父亲染色,这样才能与其他的该颜色的点连通。
此时有3种情况:
1.如果它的父亲已经被染色且颜色与该点不同,那么此时显然无解;
2.如果它的父亲与它颜色相同,那么此时不做任何操作。
3.如果它的父亲无色,那么将其染色并入队。
经过这样的一番操作后我们已经将必须染色的点染色,那么现在方案数就来自与现在仍然无色的点。
第二步,方案数可以用树形dp来求得。
我们将每个点分为两种状态,记dp[now][0]为点now已经确定颜色的方案数,dp[now][1]为未确定颜色的方案数。
接下来分类讨论如何求这两个状态的dp值:
1.如果这个点原本就有颜色
那么此时显然dp[now][1]=0,dp[now][0]=所有子节点i的(dp[i][0]+dp[i][1])的乘积,因为如果子节点已经染色,那显然状态可以继承,如果未染色,那么显然此时必须被点now染色。
2.如果这个点未被染色
此时的dp[now][1]就等于情况1的dp[now][0],而dp[now][0]则要在所有子节点中选择一个子节点,令点now被这个子节点i染色,那首先前提显然是i节点已经确定颜色,所以此时枚举每个子节点,
对dp[i][0]*dp[now][1]/(dp[i][0]+dp[i][1])求和。
以下为代码:
#include<bits/stdc++.h>
using namespace std;
const long long mod=998244353;
int i,i0,n,m,k,col[300005],dep[300005],fa[300005],cnt[300005];
vector<int>mp[300005];
void dfs(int now,int d)
{
dep[now]=d;
for(int i:mp[now])if(!dep[i])dfs(i,d+1),fa[i]=now;
return;
}
struct node
{
int x,d;
bool operator<(node a)const{return d<a.d;}
};
priority_queue<node>q;
long long dp[300005][2];
void extgcd(long long a,long long b,long long& d,long long& x,long long& y)
{
if(!b){d=a;x=1;y=0;}
else{extgcd(b,a%b,d,y,x);y-=x*(a/b);}
}
long long inv(long long a,long long n)
{
long long d,x,y;
extgcd(a,n,d,x,y);
return d==1?(x+n)%n:-1;
}
void dfs0(int now)
{
dp[now][0]=dp[now][1]=1;
for(auto i:mp[now])
{
if(i==fa[now])continue;
dfs0(i);
dp[now][1]*=(dp[i][0]+dp[i][1]);
dp[now][1]%=mod;
}
if(col[now])
{
dp[now][0]=dp[now][1];
dp[now][1]=0;
}
if(!col[now])
{
dp[now][0]=0;
for(auto i:mp[now])
{
if(i==fa[now])continue;
dp[now][0]+=dp[now][1]*inv(dp[i][0]+dp[i][1],mod)%mod*dp[i][0]%mod;
dp[now][0]%=mod;
}
}
return;
}
int main()
{
scanf("%d %d",&n,&k);
for(i=1;i<=n;i++)scanf("%d",&col[i]),cnt[col[i]]++;
for(i=1;i<n;i++)
{
int x,y;
scanf("%d %d",&x,&y);
mp[x].push_back(y);
mp[y].push_back(x);
}
dfs(1,1);
for(i=1;i<=n;i++)if(col[i])q.push({i,dep[i]});
while(!q.empty())
{
node tmp=q.top();
q.pop();
if(col[fa[tmp.x]]==col[tmp.x])cnt[col[tmp.x]]--;
else
{
if(cnt[col[tmp.x]]!=1)
{
if(!col[fa[tmp.x]])
{
col[fa[tmp.x]]=col[tmp.x];
q.push({fa[tmp.x],dep[fa[tmp.x]]});
}
else
{
printf("0\n");
return 0;
}
}
}
}
dfs0(1);
printf("%lld\n",dp[1][0]);
return 0;
}