Tree Cutting
Byteasar has a tree $T$ with $n$ vertices, conveniently labeled $1,2,\dots,n$. Each vertex $i$ of the tree has an integer value $v_i$.
The value of a non-empty tree $T$ is $v_1\oplus v_2\oplus \dots\oplus v_n$, where $\oplus$ denotes bitwise XOR.
Now, for every integer $k$ in $[0,m)$, calculate the number of non-empty subtrees of $T$ whose value equals $k$.
A subtree of $T$ is a subgraph of $T$ that is also a tree.
The first line of the input contains an integer $T$ ($1\leq T\leq 10$), the number of test cases.
In each test case, the first line contains two integers $n$ ($n\leq 1000$) and $m$ ($1\leq m\leq 2^{10}$): the size of the tree $T$ and the (exclusive) upper bound of the values $v_i$.
The second line contains $n$ integers $v_1,v_2,v_3,\dots,v_n$ ($0\leq v_i < m$), the value of each vertex.
Each of the following $n-1$ lines contains two integers $a_i,b_i$ ($1\leq a_i,b_i\leq n$), denoting an edge between vertices $a_i$ and $b_i$.
It is guaranteed that $m$ can be represented as $2^k$, where $k$ is a non-negative integer.
For each test case, print a line with $m$ integers; the $i$-th number (counting from $0$) is the number of non-empty subtrees of $T$ whose value equals $i$.
The answer can be huge, so print it modulo $10^9+7$.
2
4 4
2 0 1 3
1 2
1 3
1 4
4 4
0 1 3 1
1 2
1 3
1 4
3 3 2 3
2 4 2 3
分析:dp[i][j] 表示以 i 为根(包含 i,且位于 i 的子树内)、异或值为 j 的连通子树方案数;
在加入 i 的儿子 x 的子树方案时, dp[i][j] = dp[i][j] + Σ_{k⊕t=j} dp[i][k]·dp[x][t](右侧的 dp[i] 为合并前的旧值);
其中 Σ_{k⊕t=j} dp[i][k]·dp[x][t] 的朴素复杂度为 O(m²),可以用异或卷积(FWT)加速到 O(m log m);
代码:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <climits>
#include <cstring>
#include <string>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <vector>
#include <list>
#define vi vector<int>
#define rep(i,m,n) for(i=m;i<=n;i++)
#define rsp(it,s) for(set<int>::iterator it=s.begin();it!=s.end();it++)
#define mod 1000000007
// Modular inverse of 2 modulo mod. Parenthesized so that it expands safely
// inside larger expressions such as x*rev%mod.
#define rev ((mod+1)/2)
#define inf 0x3f3f3f3f
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define ll long long
#define pi acos(-1.0)
using namespace std;
// Vertex capacity: n <= 1000 and vertices are 1-indexed.
const int maxn=1000+10;
// Value capacity: m <= 2^10, so every xor value lies in [0, 1024).
// This must be separate from maxn, because 1024 > maxn.
const int maxm=1<<10;
// a[v]: value of vertex v.
// dp[v][x]: number of connected vertex sets containing v, inside v's DFS subtree, with xor x.
// tmp: scratch copy of dp[now] used while merging a child.
// ans[x]: running count of subtrees with xor x.
int n,m,k,t,a[maxn],dp[maxn][maxm],tmp[maxm],ans[maxm];
// e[v]: adjacency list of vertex v.
vi e[maxn];
// In-place xor (Walsh-Hadamard) transform of a[0..n-1], with entries reduced modulo mod.
// n must be a power of two. After the transform, the xor-convolution of two
// arrays becomes their pointwise product.
void fwt(int *a,int n)
{
    for(int d=1;d<n;d<<=1)            // d: half-length of the current butterfly
        for(int m=d<<1,i=0;i<n;i+=m)  // i: start of each block of length 2d
            for(int j=0;j<d;j++)
            {
                int x=a[i+j],y=a[i+j+d];
                // x and y are both < mod, so x+y < 2^31 still fits in int.
                a[i+j]=(x+y)%mod,a[i+j+d]=(x-y+mod)%mod;
            }
}
// Inverse of fwt. It uses the same butterflies, but each output is multiplied by
// rev (the inverse of 2), so over log2(n) levels the result is divided by n.
// n must be a power of two.
void ufwt(int *a,int n)
{
    for(int d=1;d<n;d<<=1)
        for(int m=d<<1,i=0;i<n;i+=m)
            for(int j=0;j<d;j++)
            {
                int x=a[i+j],y=a[i+j+d];
                // (rev) is parenthesized here so the expression stays correct no
                // matter how the rev macro itself is defined. The difference is
                // normalized into [0, mod) before multiplying to avoid a negative product.
                a[i+j]=1LL*(x+y)*(rev)%mod;
                a[i+j+d]=1LL*((x-y+mod)%mod)*(rev)%mod;
            }
}
// Xor-convolution in place: a[j] <- sum over k^t==j of a[k]*b[t] (mod mod).
// n is the (power-of-two) array length.
// Side effect: b is overwritten with its forward transform.
void solve(int *a,int *b,int n)
{
    fwt(a,n);
    fwt(b,n);
    for(int i=0;i<n;i++)a[i]=1LL*a[i]*b[i]%mod;
    ufwt(a,n);
}
// Tree DP rooted at `now`, whose parent is `pre`.
// On return, dp[now][j] is the number of connected vertex sets that contain `now`,
// lie inside now's subtree, and have xor j. Each such set is counted once in ans,
// at its topmost vertex.
// Requires dp[now][0..m-1] to be zero on entry (main clears it).
void dfs(int now,int pre)
{
    dp[now][a[now]]=1;                      // the single-vertex set {now}
    for(int x:e[now])
    {
        if(x==pre)continue;
        dfs(x,now);
        // Save the old dp[now]: the sets that do not extend into x's subtree.
        // tmp is filled only after the recursive call returns, so the
        // recursion cannot clobber it.
        for(int i=0;i<m;i++)
            tmp[i]=dp[now][i];
        // Sets that do extend into x: old dp[now] xor-convolved with dp[x].
        // dp[x] is destroyed here, but it has already been added to ans.
        solve(dp[now],dp[x],m);
        for(int i=0;i<m;i++)
            dp[now][i]=(dp[now][i]+tmp[i])%mod;
    }
    for(int i=0;i<m;i++)
        ans[i]=(ans[i]+dp[now][i])%mod;
}
// Reads T test cases. For each one it clears per-vertex state, reads the tree,
// runs the DP from vertex 1 (using 0 as a sentinel parent), and prints ans[0..m-1]
// separated by spaces.
int main()
{
    int i,j;
    scanf("%d",&t);
    while(t--)
    {
        scanf("%d%d",&n,&m);
        rep(i,1,n)
        {
            scanf("%d",&a[i]),e[i].clear();
            rep(j,0,m-1)dp[i][j]=0;
        }
        rep(i,0,m-1)ans[i]=0;
        rep(i,1,n-1)
        {
            scanf("%d%d",&j,&k);
            e[j].pb(k),e[k].pb(j);
        }
        // With n == 0 there is no vertex 1; dp[1] would be stale, so skip the DP
        // and print all zeros.
        if(n>=1)dfs(1,0);
        rep(i,0,m-1)printf("%d%c",ans[i],i<m-1?' ':'\n');
    }
    //system ("pause");
    return 0;
}