[LOJ3088][GXOI/GZOI2019]旧词——树链剖分+线段树

时间:2022-04-14 16:07:10

题目链接:

[GXOI/GZOI2019]旧词

对于$k=1$的情况,可以参见[LNOI2014]LCA,将询问离线然后从$1$号点开始对这个点到根的路径链修改,每次询问就是对询问点到根路径链查询即可。

可以发现,如果一个点的贡献被记入答案,那么这个点到根的路径上所有点的贡献都会被记入答案。

那么对于$k>1$的情况,只要每次将路径上点$u$的权值都$+1$变成每次将路径上点$u$的权值都$+(dep[u]^k-(dep[u]-1)^k)$即可。

同样用线段树维护树剖序的区间权值和即可。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<bitset>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int mod=998244353;
int n,m,k;
int ans[50010];
int p[50010];
int son[50010];
int size[50010];
int f[50010];
int tot;
int head[50010];
int to[50010];
int nex[50010];
int dep[50010];
int top[50010];
int s[50010];
int q[50010];
int dfn;
int sum[400010];
int num[400010];
int tag[400010];
struct lty
{
int x,y,id;
}a[50010];
bool cmp(lty a,lty b)
{
return a.x<b.x;
}
int quick(int x,int y)
{
int res=1;
while(y)
{
if(y&1)
{
res=1ll*res*x%mod;
}
x=1ll*x*x%mod;
y>>=1;
}
return res;
}
void add_edge(int x,int y)
{
nex[++tot]=head[x];
head[x]=tot;
to[tot]=y;
}
int add(int x,int y)
{
if(x+y<mod)
{
return x+y;
}
else
{
return x+y-mod;
}
}
void dfs(int x)
{
size[x]=1;
for(int i=head[x];i;i=nex[i])
{
dep[to[i]]=dep[x]+1;
dfs(to[i]);
size[x]+=size[to[i]];
if(size[to[i]]>size[son[x]])
{
son[x]=to[i];
}
}
}
void dfs2(int x,int tp)
{
top[x]=tp;
s[x]=++dfn;
q[dfn]=x;
if(son[x])
{
dfs2(son[x],tp);
}
for(int i=head[x];i;i=nex[i])
{
if(to[i]!=son[x])
{
dfs2(to[i],to[i]);
}
}
}
void pushup(int rt)
{
sum[rt]=add(sum[rt<<1],sum[rt<<1|1]);
num[rt]=add(num[rt<<1],num[rt<<1|1]);
}
void pushdown(int rt)
{
if(tag[rt])
{
tag[rt<<1]=add(tag[rt],tag[rt<<1]);
tag[rt<<1|1]=add(tag[rt],tag[rt<<1|1]);
sum[rt<<1]=add(sum[rt<<1],1ll*tag[rt]*num[rt<<1]%mod);
sum[rt<<1|1]=add(sum[rt<<1|1],1ll*tag[rt]*num[rt<<1|1]%mod);
tag[rt]=0;
}
}
void build(int rt,int l,int r)
{
if(l==r)
{
num[rt]=p[dep[q[l]]];
return ;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void change(int rt,int l,int r,int L,int R)
{
if(L<=l&&r<=R)
{
tag[rt]=add(tag[rt],1);
sum[rt]=add(sum[rt],num[rt]);
return ;
}
int mid=(l+r)>>1;
pushdown(rt);
if(L<=mid)
{
change(rt<<1,l,mid,L,R);
}
if(R>mid)
{
change(rt<<1|1,mid+1,r,L,R);
}
pushup(rt);
}
int query(int rt,int l,int r,int L,int R)
{
if(L<=l&&r<=R)
{
return sum[rt];
}
int mid=(l+r)>>1;
int res=0;
pushdown(rt);
if(L<=mid)
{
res=add(res,query(rt<<1,l,mid,L,R));
}
if(R>mid)
{
res=add(res,query(rt<<1|1,mid+1,r,L,R));
}
return res;
}
void modify(int x)
{
while(top[x]!=1)
{
change(1,1,n,s[top[x]],s[x]);
x=f[top[x]];
}
change(1,1,n,1,s[x]);
}
int ask(int x)
{
int res=0;
while(top[x]!=1)
{
res=add(res,query(1,1,n,s[top[x]],s[x]));
x=f[top[x]];
}
res=add(res,query(1,1,n,1,s[x]));
return res;
}
int main()
{
scanf("%d%d%d",&n,&m,&k);
for(int i=1;i<=n;i++)
{
p[i]=(quick(i,k)-quick(i-1,k)+mod)%mod;
}
dep[1]=1;
for(int i=2;i<=n;i++)
{
scanf("%d",&f[i]);
add_edge(f[i],i);
}
dfs(1);
dfs2(1,1);
build(1,1,n);
for(int i=1;i<=m;i++)
{
scanf("%d%d",&a[i].x,&a[i].y);
a[i].id=i;
}
sort(a+1,a+1+m,cmp);
int now=0;
for(int i=1;i<=m;i++)
{
while(now<a[i].x)
{
now++;
modify(now);
}
ans[a[i].id]=ask(a[i].y);
}
for(int i=1;i<=m;i++)
{
printf("%d\n",ans[i]);
}
}