BZOJ3772 精神污染 主席树 dfs序

时间:2022-07-11 21:33:51

欢迎访问~原文出处——博客园-zhouzhendong

去博客园看该题解


题目传送门 - BZOJ3772


题意概括

  给出一个树,共n个节点。

  有m条互不相同的树上路径。

  现在让你随机选择2条路径,问两条路径存在包含关系的概率(输出最简分数)。

  n,m<=100000


题解

  首先,暴力肯定过不去的。

  然后,我们发现总选择的方案数是C(m,2)

  然后重点是统计包含关系的。

  现在,我们有一个做法。

  我们先把整个树的dfs序搞出来。

  然后,相当于某一个子树就是连续的一段。对于输入的每一个路径(x,y),我们在x处打一个y标记,在y处打一个x标记。然后比如我们要搜寻包含路径(a,b)的路径,那么只需要在保证其他的x,y分别处于ab两侧,我们只需要统计在a一侧的标记在b一侧有多少对应的标记即可。

  于是我们用到了主席树。(你要写线段树套线段树我也不拦你)。

  主席树的时间和区间各表示一种标记。

  比如在路径(x,y),那么就在时间x的时候区间[y,y]加1。

  如果不大懂可以参见其他大佬的博客。

  标记打好之后是关键部分。

  对于每一条路径(a,b),我们分类讨论。

  设LCA(a,b)=c

  情况1:

    a≠c且b≠c:

  如图:

    BZOJ3772 精神污染 主席树 dfs序

  那么,只需要统计在a的子树中的节点所对应的标记在b的子树中有几个即可。别忘记减掉它本身。(-1即可)

  情况2:a,b中有一个=c,不妨设b=c

  如图:

BZOJ3772 精神污染 主席树 dfs序

  那么,我们发现,从子树a的节点出发,既要统计b的爸爸延伸出去的(绿色路径),又要统计b除了到a路径上的儿子以外的其他儿子的(如蓝色路径),貌似很麻烦。

  实际上,就是全局的减去b到a路径上的b的儿子的。至于这个儿子,倍增就可以求了。别忘了-1。

  情况3:a=b=c

  这个很明显就是就a的子树节点出发,统计全局除了a子树的答案。

  

  一切的一切,主席树统统搞定。

  最后,提供一组数据:

11 7
1 2
1 3
1 8
2 4
2 5
3 6
3 7
3 11
6 9
6 10

2 2
4 8
2 1
3 2
3 10
9 10
9 11

ans=5/21


代码

#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <vector>
using namespace std;
typedef long long LL;
LL gcd(LL a,LL b){return b?gcd(b,a%b):a;}
const int N=100005;
struct Gragh{
int cnt,y[N*2],nxt[N*2],fst[N];
void clear(){
cnt=0;
memset(fst,0,sizeof fst);
}
void add(int a,int b){
y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt;
}
}g;
int n,m,time;
int dfn[N],in[N],out[N],fa[N][20],depth[N];
vector <int> v[N];
struct Que{
int a,b;
}q[N];
void dfs(int rt,int pre){
depth[rt]=depth[pre]+1;
fa[rt][0]=pre;
for (int i=1;i<20;i++)
fa[rt][i]=fa[fa[rt][i-1]][i-1];
dfn[in[rt]=++time]=rt;
for (int i=g.fst[rt];i;i=g.nxt[i])
if (g.y[i]!=pre)
dfs(g.y[i],rt);
out[rt]=time;
}
bool isfa(int a,int b){
return in[a]<=in[b]&&out[b]<=out[a];
}
int LCS(int a,int b){
for (int i=19;i>=0;i--)
if (fa[a][i]&&!isfa(fa[a][i],b))
a=fa[a][i];
return a;
}
int LCA(int a,int b){
if (isfa(a,b))
return a;
if (isfa(b,a))
return b;
return fa[LCS(a,b)][0];
}
const int S=N*2*20;
int ls[S],rs[S],sum[S],total=0,root[N];
void build(int &rt,int L,int R){
rt=++total;
sum[rt]=0;
if (L==R)
return;
int mid=(L+R)>>1;
build(ls[rt],L,mid);
build(rs[rt],mid+1,R);
}
void add(int prt,int &rt,int L,int R,int pos){
if (!rt||rt==prt)
rt=++total,sum[rt]=sum[prt];
sum[rt]++;
if (L==R)
return;
if (!ls[rt])
ls[rt]=ls[prt];
if (!rs[rt])
rs[rt]=rs[prt];
int mid=(L+R)>>1;
if (pos<=mid)
add(ls[prt],ls[rt],L,mid,pos);
else
add(rs[prt],rs[rt],mid+1,R,pos);
}
int query(int prt,int rt,int L,int R,int xL,int xR){
if (xL>R||xR<L)
return 0;
if (xL<=L&&R<=xR)
return sum[rt]-sum[prt];
int mid=(L+R)>>1;
return query(ls[prt],ls[rt],L,mid,xL,xR)
+query(rs[prt],rs[rt],mid+1,R,xL,xR);
}
int main(){
g.clear();
scanf("%d%d",&n,&m);
for (int i=1,a,b;i<n;i++){
scanf("%d%d",&a,&b);
g.add(a,b);
g.add(b,a);
}
time=0;
dfs(1,0);
for (int i=1;i<=n;i++)
v[i].clear();
for (int i=1,a,b;i<=m;i++){
scanf("%d%d",&a,&b);
if (in[a]>in[b])
swap(a,b);
v[in[a]].push_back(in[b]);
v[in[b]].push_back(in[a]);
q[i].a=a,q[i].b=b;
}
build(root[0],1,n);
for (int i=1;i<=n;i++){
root[i]=root[i-1];
for (int j=0;j<v[i].size();j++)
add(root[i-1],root[i],1,n,v[i][j]);
}
LL x=0,y=1LL*m*(m-1)/2;
for (int i=1;i<=m;i++){
int a=q[i].a,b=q[i].b,c=LCA(a,b);
if (a!=c&&b!=c){
x+=query(root[in[a]-1],root[out[a]],1,n,in[b],out[b]);
x--;
}
else if (a!=c||b!=c){
if (b!=c)
swap(a,b);
int d=LCS(a,b);
x+=query(root[in[a]-1],root[out[a]],1,n,1,n);
x-=query(root[in[a]-1],root[out[a]],1,n,in[d],out[d]);
x--;
}
else {
x+=query(root[in[a]-1],root[out[a]],1,n,1,n);
x-=query(root[in[a]-1],root[out[a]],1,n,in[a],out[a]);
}
}
LL g=gcd(y,x);
x/=g,y/=g;
if (x==0)
puts("0");
else
printf("%lld/%lld\n",x,y);
return 0;
}