
首先,让每一个叶节点做一次树根的话,每个路径一定至少有一次会变成直上直下的
于是对于每个叶节点作为根产生的20个trie树,把它们建到同一个广义SAM里
建法是对每个trie dfs去建,last就是父亲的那个节点;每次做一个新trie时,last给成root
然后答案就是每个节点表示的长度和
#include<bits/stdc++.h>
#define pa pair<int,int>
#define CLR(a,x) memset(a,x,sizeof(a))
#define MP make_pair
using namespace std;
typedef long long ll;
const int maxn=1e5+,maxp=4e6+; inline char gc(){
return getchar();
static const int maxs=<<;static char buf[maxs],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,,maxs,stdin),p1==p2)?EOF:*p1++;
}
inline ll rd(){
ll x=;char c=gc();bool neg=;
while(c<''||c>''){if(c=='-') neg=;c=gc();}
while(c>=''&&c<='') x=(x<<)+(x<<)+c-'',c=gc();
return neg?(~x+):x;
} int N,C,eg[maxn*][],egh[maxn],ect,dgr[maxn];
int col[maxn];
int tr[maxp][],fa[maxp],len[maxp],pct=,rt=; inline void adeg(int a,int b){
eg[++ect][]=b,eg[ect][]=egh[a],egh[a]=ect;
dgr[a]++;
} inline int insert(int x,int o){
int p=++pct;
len[p]=len[o]+;
for(;o&&!tr[o][x];o=fa[o]) tr[o][x]=p;
if(!o){fa[p]=rt;return p;}
int q=tr[o][x];
if(len[q]==len[o]+){fa[p]=q;return p;}
int qq=++pct;
memcpy(tr[qq],tr[q],sizeof(tr[qq]));
len[qq]=len[o]+;fa[qq]=fa[q],fa[q]=fa[p]=qq;
for(;o&&tr[o][x]==q;o=fa[o]) tr[o][x]=qq;
return p;
} inline void dfs(int x,int f,int p){
p=insert(col[x],p);
for(int i=egh[x];i;i=eg[i][]){
int b=eg[i][];if(b==f) continue;
dfs(b,x,p);
}
} int main(){
//freopen("","r",stdin);
int i,j,k;
N=rd(),C=rd();
for(i=;i<=N;i++) col[i]=rd();
for(i=;i<N;i++){
int a=rd(),b=rd();
adeg(a,b);adeg(b,a);
}
for(i=;i<=N;i++){
if(dgr[i]==) dfs(i,,rt);
}
ll ans=;
for(i=;i<=pct;i++) ans+=len[i]-len[fa[i]];
printf("%lld\n",ans);
return ;
}