题意:有一颗树,每个点有一个权值和一个字符串,要求计算出以每个点的子树的贡献,贡献的定义是两个点权值的xor*两个点字符串的lcp。n<=1e5
其实这题我第一眼就想到trie,但是我trie基本上没做过多少题,不会xor统计的那种科技(这个太基础了吧喂),然后就异想天开用了个SA,结果爆炸,调了半天调不出来,心情复杂。
正解是trie合并。开两颗trie,一颗记录lcp,一颗记录每个子树内有多少个点的权值在第i位为1或者0.那么计算贡献和合并都很好写了,lcp的话,因为是trie,我每次转移同一条边就可以直接lcp+1了。
至于贡献,trie的那个东西是符合前缀和性质,我把一颗子树减掉然后就是其他子树的了,那么我们直接算就可以,这个东西自己手推就出来了,就是每个子树的01之间互相乘起来。
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
const int N=6e5+5;
int n,m;
typedef long long ll;
int mi[20],sz,cnt,root[N],ret[20];
ll tot,ans[N];
char ch[N];
int head[N],next[N],go[N],a[N];
struct node
{
int size,ch[26],cnt[17],tag;
}t[N];
inline void add(int x,int y)
{
go[++cnt]=y;
next[cnt]=head[x];
head[x]=cnt;
}
inline void ins(int &x,int l,int r,int v)
{
x=++sz;
t[x].size++;
fo(i,0,16)
if (v&mi[i])t[x].cnt[i]++;
if (l!=r)ins(t[x].ch[ch[l]-'a'],l+1,r,v);
else t[x].tag++;
}
inline int merge(int x,int y,int z)
{
if (!x||!y)return x^y;
fo(i,0,16)ret[i]=t[x].cnt[i];
fo(i,0,25)
if (t[x].ch[i])
{
int v=t[x].ch[i];
fo(j,0,16)
{
tot+=1ll*mi[j]*z*(1ll*t[v].cnt[j]*(t[y].size-t[t[y].ch[i]].size-t[y].cnt[j]+t[t[y].ch[i]].cnt[j]));
tot+=1ll*mi[j]*z*(1ll*(t[v].size-t[v].cnt[j])*(t[y].cnt[j]-t[t[y].ch[i]].cnt[j]));
ret[j]-=t[v].cnt[j];
}
}
fo(i,0,16)
{
tot+=1ll*mi[i]*z*ret[i]*(t[y].size-t[y].cnt[i]);
tot+=1ll*mi[i]*z*(t[x].tag-ret[i])*t[y].cnt[i];
}
t[x].size+=t[y].size;
t[x].tag+=t[y].tag;
fo(i,0,16)t[x].cnt[i]+=t[y].cnt[i];
fo(i,0,25)t[x].ch[i]=merge(t[x].ch[i],t[y].ch[i],z+1);
return x;
}
inline void solve(int x,int fa)
{
for(int i=head[x];i;i=next[i])
{
int v=go[i];
if (v==fa)continue;
solve(v,x);
ans[x]+=ans[v];
tot=0;
root[x]=merge(root[x],root[v],0);
ans[x]+=tot;
}
}
int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
scanf("%d",&n);
fo(i,1,n)scanf("%d",&a[i]);
mi[0]=1;
fo(i,1,16)mi[i]=1ll*mi[i-1]*2;
fo(i,1,n)
{
scanf("%s",ch);
int len=strlen(ch);
ins(root[i],0,len,a[i]);
}
fo(i,1,n-1)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
solve(1,0);
fo(i,1,n)printf("%lld\n",ans[i]);
}