luogu4365 秘密袭击 (生成函数+线段树合并+拉格朗日插值)

时间:2024-04-30 08:11:45

求所有可能联通块的第k大值的和,考虑枚举这个值:

$ans=\sum\limits_{i=1}^{W}{i\sum\limits_{S}{[i是第K大]}}$

设cnt[i]为连通块中值>=i的个数

$ans=\sum\limits_{i=1}^{W}{i\sum\limits_{S}{[cnt[i]>=K]-[cnt[i+1]>=K]}}$

$ans=\sum\limits_{i=1}^{W}{\sum\limits_{S}{[cnt[i]>=K]}}$

于是先考虑树上dp,设f[i][j][k]表示以i为根的连通块中,值>=j的数量为k的情况数

然后$ans=\sum\limits_{i=1}^{N}{\sum\limits_{j=1}^{W}{\sum\limits_{k=K}^{N}{f[i][j][k]}}}$

转移和背包类似,所以这样做是$O(N^2W)$的

考虑使用生成函数优化,设$F[i][j]=\sum{f[i][j][k]x^k}$,再设$G[i][j]=\sum{F[s][j]},i是s的祖先$

于是转移就变成了$F[i][j]*=(F[s][j]+1),G[i][j]+=G[s][j],G[i][j]+=F[i][j]$,其中s是i的孩子

同时有初值$F[i][j]=(d[i]>=j?x:1)$,答案就是G[1][*]的K~N项系数的和

然后当然不能真的去乘了..

考虑先将F和G用点值表达,最后再插回来

首先枚举x=1..N+1,然后给每个点i开动态开点的线段树维护F[i][j]和G[i][j]的值

然后用线段树合并来做对应位置的相乘和相加

具体来说,我们让线段树上的结点维护一个作用在$(f,g)$上的变换$(a,b,c,d)$,使得最终得到$(af+b,cf+d+g)$

然后也不难得到变换的乘法(有结合律但没有交换律)

然后就可以做了 复杂度我也不会分析 反正有可能跑的比暴力还慢

别忘了回收掉不用的点

 #include<bits/stdc++.h>
#define pa pair<int,int>
#define CLR(a,x) memset(a,x,sizeof(a))
#define MP make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int ui;
typedef long double ld;
const int maxn=,maxp=3e6;
const int P=; inline char gc(){
return getchar();
static const int maxs=<<;static char buf[maxs],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,,maxs,stdin),p1==p2)?EOF:*p1++;
}
inline ll rd(){
ll x=;char c=gc();bool neg=;
while(c<''||c>''){if(c=='-') neg=;c=gc();}
while(c>=''&&c<='') x=(x<<)+(x<<)+c-'',c=gc();
return neg?(~x+):x;
} struct Node{
int a,b,c,d;
Node(int _a=,int _b=,int _c=,int _d=){a=_a,b=_b,c=_c,d=_d;}
}val[maxp];
Node operator *(Node x,Node y){
return Node(1ll*x.a*y.a%P,(1ll*x.b*y.a+y.b)%P,(1ll*x.a*y.c+x.c)%P,(1ll*x.b*y.c+x.d+y.d)%P);
} int N,K,W,dan[maxn],eg[maxn*][],egh[maxn],ect;
int ch[maxp][],stk[maxp],sh,rt[maxn];
int yy[maxn]; inline void adeg(int a,int b){
eg[++ect][]=b,eg[ect][]=egh[a],egh[a]=ect;
} inline int newnode(){
int p=stk[sh--];
assert(sh>=);
ch[p][]=ch[p][]=;
val[p]=Node();
return p;
} inline void delall(int &p){
if(!p) return;
delall(ch[p][]);delall(ch[p][]);
stk[++sh]=p;p=;
} inline void pushdown(int p){
if(!ch[p][]) ch[p][]=newnode();
if(!ch[p][]) ch[p][]=newnode();
val[ch[p][]]=val[ch[p][]]*val[p];
val[ch[p][]]=val[ch[p][]]*val[p];
val[p]=Node();
} void mul(int &p,int l,int r,int x,int y,Node z){
if(!p) p=newnode();
if(x<=l&&r<=y){
val[p]=val[p]*z;
}else{
int m=(l+r)>>;pushdown(p);
if(x<=m) mul(ch[p][],l,m,x,y,z);
if(y>=m+) mul(ch[p][],m+,r,x,y,z);
}
} int merge(int &p,int &q){
if(!p||!q) return p|q;
if(!ch[p][]&&!ch[p][]) swap(p,q);
if(!ch[q][]&&!ch[q][]){
val[p]=val[p]*Node(val[q].b,,,val[q].d);
return p;
}
pushdown(p),pushdown(q);
ch[p][]=merge(ch[p][],ch[q][]);
ch[p][]=merge(ch[p][],ch[q][]);
return p;
} void dfs(int x,int f,int id){
mul(rt[x],,W,,W,Node(,,,));
for(int i=egh[x];i;i=eg[i][]){
int b=eg[i][];if(b==f) continue;
dfs(b,x,id);
merge(rt[x],rt[b]);
delall(rt[b]);
}
mul(rt[x],,W,,dan[x],Node(id,,,));
mul(rt[x],,W,,W,Node(,,,));
mul(rt[x],,W,,W,Node(,,,));
} int query(int p,int l,int r){
if(!p) return ;
if(l==r) return val[p].d;
int m=(l+r)>>;pushdown(p);
return (query(ch[p][],l,m)+query(ch[p][],m+,r))%P;
} int fpow(int x,int y){
int r=;
while(y){
if(y&) r=1ll*r*x%P;
x=1ll*x*x%P,y>>=;
}return r;
} int l[maxn],tmp[maxn],ans[maxn];
void calc(){
l[]=;
for(int i=;i<=N+;i++){
for(int j=i-;j>=;j--){
l[j+]=(l[j+]+l[j])%P;
l[j]=-1ll*i*l[j]%P;
}
}
for(int i=;i<=N+;i++){
int ib=-fpow(i,P-);
tmp[]=1ll*l[]*ib%P;
for(int j=;j<=N;j++){
tmp[j]=1ll*(l[j]-tmp[j-])*ib%P;
}
int k=,x=;
for(int j=;j<=N;j++){
k=(1ll*x*tmp[j]+k)%P;
x=1ll*x*i%P;
}
k=1ll*fpow(k,P-)*yy[i]%P;
for(int j=;j<=N;j++){
ans[j]=(1ll*tmp[j]*k+ans[j])%P;
}
}
} int main(){
//freopen("","r",stdin);
N=rd(),K=rd(),W=rd();
for(int i=;i<=N;i++) dan[i]=rd();
for(int i=;i<N;i++){
int a=rd(),b=rd();
adeg(a,b);adeg(b,a);
}
for(int i=;i<maxp-;i++) stk[++sh]=i; for(int i=;i<=N+;i++){
dfs(,,i);
yy[i]=query(rt[],,W);
delall(rt[]);
}
calc();
int a=;
for(int i=K;i<=N;i++) a=(a+ans[i])%P;
printf("%d\n",(a+P)%P);
return ;
}