题意:
有一棵n个点的无根树,节点依次编号为1到n,其中节点i的权值为vi,
定义一棵树的价值为它所有点的权值的异或和。
现在对于每个[0,m)的整数k,请统计有多少T的非空连通子树的价值等于k。
Sample Input
2
4 4
2 0 1 3
1 2
1 3
1 4
4 4
0 1 3 1
1 2
1 3
1 4
Sample Output
3 3 2 3
2 4 2 3
令f[i][j]表示以i为根的子树中异或和为j的联通块个数,v为i儿子
f[i][j]+=f[i][k]*f[v][l] (k^l==j)
发现转移其实可以写成这种形式:
$C_i=\sum_{j^k=i}A_j*B_k$
这和卷积有点类似,不过运算改成了异或
这里就要用到FWT(快速沃尔什变换)
就可以做到nlogn转移
转移完后记得在加上原来的f[i][j],因为你可以不选v
复杂度为$O(n^{2}logn)$
卡常,少取模,不要定义long long变量
这题还可以点分治
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
using namespace std;
struct Node
{
int next,to;
}edge[];
int num,head[],Mod=1e9+,inv2,tmp[],a[][],ans[],n,m;
int gi()
{
char ch=getchar();
int x=;
while (ch<''||ch>'') ch=getchar();
while (ch>=''&&ch<='')
{
x=x*+ch-'';
ch=getchar();
}
return x;
}
void add(int u,int v)
{
num++;
edge[num].next=head[u];
head[u]=num;
edge[num].to=v;
}
int qpow(int x,int y)
{
int res=;
while (y)
{
if (y&) res=1ll*res*x%Mod;
x=1ll*x*x%Mod;
y/=;
}
return res;
}
void FWT(int *A,int len)
{int i,j,k;
for (i=;i<m;i<<=)
{
for (j=;j<m;j+=(i<<))
{
for (k=;k<i;k++)
{
int x=A[j+k],y=A[j+k+i];
A[j+k]=x+y;
if (A[j+k]>=Mod) A[j+k]-=Mod;
A[j+k+i]=x-y+Mod;
if (A[j+k+i]>=Mod) A[j+k+i]-=Mod;
}
}
}
}
void UFWT(int *A,int len)
{int i,j,k;
for (i=;i<m;i<<=)
{
for (j=;j<m;j+=(i<<))
{
for (k=;k<i;k++)
{
int x=A[j+k],y=A[j+k+i];
A[j+k]=1ll*(x+y)*inv2%Mod;
A[j+k+i]=1ll*(x-y+Mod)*inv2%Mod;
}
}
}
}
void DP(int x,int y)
{int i;
for (i=;i<m;i++)
tmp[i]=a[x][i];
FWT(a[x],m);
FWT(a[y],m);
for (i=;i<m;i++)
a[x][i]=1ll*a[x][i]*a[y][i]%Mod;
UFWT(a[x],m);
for (i=;i<m;i++)
{
a[x][i]=a[x][i]+tmp[i];
if (a[x][i]>=Mod) a[x][i]-=Mod;
}
}
void dfs(int x,int pa)
{int i;
for (i=head[x];i;i=edge[i].next)
{
int v=edge[i].to;
if (v!=pa)
{
dfs(v,x);
DP(x,v);
}
}
for (i=;i<m;i++)
{
ans[i]=ans[i]+a[x][i];
if (ans[i]>=Mod) ans[i]-=Mod;
}
}
int main()
{int T,i,x,u,v,j;
cin>>T;
inv2=qpow(,Mod-);
while (T--)
{
memset(head,,sizeof(head));
num=;
memset(a,,sizeof(a));
memset(ans,,sizeof(ans));
scanf("%d%d",&n,&m);
for (i=;i<=n;i++)
{
x=gi();
a[i][x]=;
}
for (i=;i<=n-;i++)
{
u=gi();v=gi();
add(u,v);add(v,u);
}
dfs(,);
for (i=;i<m-;i++)
printf("%d ",ans[i]);
printf("%d\n",ans[m-]);
}
}