【Hihocoder 1167】 高等理论计算机科学 (树链的交,线段树或树状数组维护区间和)

时间:2023-03-08 17:25:23

【题意】

时间限制:20000ms
单点时限:1000ms
内存限制:256MB

描述

少女幽香这几天正在学习高等理论计算机科学,然而她什么也没有学会,非常痛苦。所以她出去晃了一晃,做起了一些没什么意义的事情来放松自己。
门前有一颗n个节点树,幽香发现这个树上有n个小精灵。然而这些小精灵都比较害羞,只会在一条特定的路径上活动。第i个小精灵会在ai到bi的路径上活动。
两个小精灵是朋友,当且仅当它们的路径是有公共点的。
于是幽香想要知道,有多少对小精灵a和b,a和b是朋友呢?其中a不等于b,a,b和b,a看做一对。

输入

第一行n和P (1 <= n, P <=100000),表示树的大小和小精灵的个数。树的节点从1到n标号。
接下来n-1行,每行两个数a,b,表示a到b之间有一条边。
接下来P行,第i行两个数ai,bi,表示小精灵i的活动范围是ai到bi,其中ai不等于bi。

输出

一行答案,表示对数。

样例输入
6 3
1 2
2 3
2 4
4 5
4 6
1 3
1 5
5 6
样例输出
2

【分析】

   ORZ 。。。

   我好蠢。。一直想树剖以及线段的交。。。【并且不是线段

  大神题解here:

  两条树链相交,当且仅当一条树链的lca在另一条树链上,对于每个树链,统计有多少个树链的lca在被他包含,有时两条树链互相满足这个条件,但仅仅当这两条树链的lca相等时才会有,所以特判一下

  上面说得很清楚了,区间的和维护用树状数组就可以了。

 #include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
using namespace std;
#define Maxn 100010
#define LL long long struct node
{
int x,y,next;
}t[Maxn*];int len; int first[Maxn],px[Maxn],py[Maxn]; void ins(int x,int y)
{
t[++len].x=x;t[len].y=y;
t[len].next=first[x];first[x]=len;
} int son[Maxn],dfn[Maxn],sm[Maxn],dep[Maxn],fa[Maxn];
void dfs1(int x,int f)
{
sm[x]=;son[x]=;dep[x]=dep[f]+;fa[x]=f;
for(int i=first[x];i;i=t[i].next) if(t[i].y!=f)
{
int y=t[i].y;
dfs1(y,x);
sm[x]+=sm[y];
if(sm[y]>sm[son[x]]) son[x]=y;
}
} int tp[Maxn],cnt;
void dfs2(int x,int f,int tpp)
{
dfn[x]=++cnt;tp[x]=tpp;
if(son[x]) dfs2(son[x],x,tpp);
for(int i=first[x];i;i=t[i].next) if(t[i].y!=f&&t[i].y!=son[x])
dfs2(t[i].y,x,t[i].y);
} int c[Maxn],n;
bool lca[Maxn]; void add(int x,int y)
{
for(int i=x;i<=n;i+=i&(-i))
c[i]+=y;
} int query(int l,int r)
{
int ans=;
for(int i=r;i>=;i-=i&(-i))
ans+=c[i];
l--;
for(int i=l;i>=;i-=i&(-i))
ans-=c[i];
return ans;
} int gans(int x,int y,int p)
{
int ans=,tt;
while(tp[x]!=tp[y])
{
if(dep[tp[x]]<dep[tp[y]]) tt=x,x=y,y=tt;
if(p==) ans+=query(dfn[tp[x]],dfn[x]);
x=fa[tp[x]];
}
if(dep[x]<dep[y]) tt=x,x=y,y=tt;
if(p==)
{
ans+=query(dfn[y],dfn[x]);
return ans;
}
else return y;
} int main()
{
int p;
LL ans=;
scanf("%d%d",&n,&p);
len=;
memset(first,,sizeof(first));
for(int i=;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
ins(x,y);ins(y,x);
}
for(int i=;i<=p;i++) scanf("%d%d",&px[i],&py[i]);
sm[]=;dep[]=;
dfs1(,);cnt=;
dfs2(,,);
memset(c,,sizeof(c));
memset(lca,,sizeof(lca));
for(int i=;i<=p;i++)
{
int x=gans(px[i],py[i],);
lca[x]=;
add(dfn[x],);
}
for(int i=;i<=p;i++)
{
int x=gans(px[i],py[i],);
ans+=x;
}
for(int i=;i<=n;i++) if(lca[i])
{
int x=query(dfn[i],dfn[i]);
ans-=x*(x-)/+x;
}
printf("%lld\n",ans);
return ;
}

2016-11-10 18:17:12