BZOJ1906树上的蚂蚁&BZOJ3700发展城市——RMQ求LCA+树链的交

时间:2021-11-03 22:05:04

题目描述

众所周知,Hzwer学长是一名高富帅,他打算投入巨资发展一些小城市。
 Hzwer打算在城市中开N个宾馆,由于Hzwer非常壕,所以宾馆必须建在空中,但是这样就必须建立宾馆之间的连接通道。机智的Hzwer在宾馆中修建了N-1条隧道,也就是说,宾馆和隧道形成了一个树形结构。
 Hzwer有时候会花一天时间去视察某个城市,当来到一个城市之后,Hzwer会分析这些宾馆的顾客情况。对于每个顾客,Hzwer用三个数值描述他:(S, T, V)表示该顾客这天想要从宾馆S走到宾馆T,他的速度是V。
 Hzwer需要做一些收集一些数据,这样他就可以规划他接下来的投资。
 其中有一项数据就是收集所有顾客可能的碰面次数。
 每天清晨,顾客同时从S出发以V的速度前往T(注意S可能等于T),当到达了宾馆T的时候,顾客显然要找个房间住下,那么别的顾客再经过这里就不会碰面了。特别的,两个顾客同时到达一个宾馆是可以碰面的。同样,两个顾客同时从某宾馆出发也会碰面。

输入

第一行一个正整数T(1<=T<=20),表示Hzwer发展了T个城市,并且在这T个城市分别视察一次。
 对于每个T,第一行有一个正整数N(1<=N<=10^5)表示Hzwer在这个城市开了N个宾馆。
 接下来N-1行,每行三个整数X,Y,Z表示宾馆X和宾馆Y之间有一条长度为Z的隧道
 再接下来一行M表示这天顾客的数量。
 紧跟着M行每行三个整数(S, T, V)表示该顾客会从宾馆S走到宾馆T,速度为v

输出

对于每个T,输出一行,表示顾客的碰面次数。

样例输入

1
3
1 2 1
2 3 1
3
1 3 2
3 1 1
1 2 3

样例输出

2
0

提示

【数据规模】

1<=T<=20   1<=N<=10^5   0<=M<=10^3   1<=V<=10^6   1<=Z<=10^3

这题细节好多啊,蒟蒻的我调了一下午。

考虑到m的范围比较小,因此可以两两枚举判断是否相遇。

对于两个路径,如果能够相遇,相遇点一定在两个路径的交路径上。

如何求树上路径交?

对于两个路径A(a.u,a.v)与B(b.u,b.v)求出lca(a.u,b.u),lca(a.v,b.v),lca(a.v,b.u),lca(a.u,b.v)

去掉这四个点中不在A或B路径上的点,再去重后按dfs序排序,取后两个(如果只有一个说明路径只交于一点)就是交路径的两个端点

判断出两个路径起点先到达的交路径的端点是否是同一个,如果是就说明两个顾客是同向运动,反之则是相向运动。

如果两顾客是同向运动:只要先进入交路径的顾客后走出交路径就一定相遇。

如果两顾客是相向运动:分别求出两顾客进入和走出交路径的时间,判断只要两时间段有交集就能相遇,因为除法较慢,所以转成交叉相乘判断。

在判断和求路径过程中多次求lca,用O(logn)的方法求显然会TLE,在这里采用RMQ求lca:

在dfs时求出欧拉遍历序(就是遍历到一个点存一次)及每个点第一次被遍历的位置

对于x,y两点的lca就是欧拉序上两点第一次被遍历位置之间深度最小的点,用ST表即可O(1)查询

这道题有点卡常,注意涉及到乘速度时可能会爆longlong。

#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
inline char _read()
{
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int read()
{
int x=0,f=1;char ch=_read();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=_read();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=_read();}
return x*f;
}
int T,n,m;
int head[100010];
int s[100010];
int to[200010];
int next[200010];
int val[200010];
int d[100010];
int dep[100010];
int f[200010][18];
int g[200010][18];
int tot;
int num;
int x,y,z;
int ans;
int p[5];
int cnt;
int b[200010];
struct miku
{
int u,v,w;
}a[1010];
inline void add(int x,int y,int z)
{
tot++;
next[tot]=head[x];
head[x]=tot;
to[tot]=y;
val[tot]=z;
}
inline void dfs(int x,int fa)
{
d[x]=d[fa]+1;
s[x]=++num;
f[num][0]=d[x];
g[num][0]=x;
for(int i=head[x];i;i=next[i])
{
if(to[i]!=fa)
{
dep[to[i]]=dep[x]+val[i];
dfs(to[i],x);
f[++num][0]=d[x];
g[num][0]=x;
}
}
}
inline void ST()
{
for(int j=1;j<=17;j++)
{
for(int i=1;i<=num;i++)
{
if(i+(1<<j)-1>num)
{
break;
}
if(f[i][j-1]<f[i+(1<<(j-1))][j-1])
{
f[i][j]=f[i][j-1];
g[i][j]=g[i][j-1];
}
else
{
f[i][j]=f[i+(1<<(j-1))][j-1];
g[i][j]=g[i+(1<<(j-1))][j-1];
}
}
}
}
inline int lca(int x,int y)
{
x=s[x];
y=s[y];
if(x>y)
{
swap(x,y);
}
int len=b[y-x+1];
if(f[x][len]<f[y-(1<<len)+1][len])
{
return g[x][len];
}
else
{
return g[y-(1<<len)+1][len];
}
}
inline bool find(int anc,int x,int y)
{
int fx=lca(a[x].u,a[x].v);
int fy=lca(a[y].u,a[y].v);
if(lca(fx,anc)!=fx||lca(fy,anc)!=fy)
{
return false;
}
if(fx!=lca(fx,a[x].u)&&fx!=lca(fx,a[x].v))
{
return false;
}
if(fy!=lca(fy,a[y].u)&&fy!=lca(fy,a[y].v))
{
return false;
}
return true;
}
inline int dis(int x,int y)
{
int anc=lca(x,y);
return dep[x]+dep[y]-2*dep[anc];
}
inline bool cmp(int x,int y)
{
return s[x]<s[y];
}
inline bool cpr(ll a,ll b,ll c)
{
if(a<=b&&b<=c)
{
return 1;
}
else
{
return 0;
}
}
inline int check(int x,int y)
{
if(a[x].u==a[y].u)
{
return 1;
}
int res;
cnt=0;
res=lca(a[x].u,a[y].u);
if(find(res,x,y)){p[++cnt]=res;}
res=lca(a[x].v,a[y].v);
if(find(res,x,y)){p[++cnt]=res;}
res=lca(a[x].u,a[y].v);
if(find(res,x,y)){p[++cnt]=res;}
res=lca(a[y].u,a[x].v);
if(find(res,x,y)){p[++cnt]=res;}
if(cnt==0)
{
return 0;
}
sort(p+1,p+1+cnt,cmp);
cnt=unique(p+1,p+1+cnt)-p-1;
if(cnt==1)
{
if(1ll*dis(a[x].u,p[1])*a[y].w==1ll*dis(a[y].u,p[1])*a[x].w)
{
return 1;
}
else
{
return false;
}
}
int st=p[cnt];
int ed=p[cnt-1];
int A1,A2,B1,B2;
ll a1,a2,b1,b2;
if(dis(a[x].u,st)<dis(a[x].u,ed))
{
A1=st;
A2=ed;
}
else
{
A1=ed;
A2=st;
}
if(dis(a[y].u,st)<dis(a[y].u,ed))
{
B1=st;
B2=ed;
}
else
{
B1=ed;
B2=st;
}
a1=1ll*dis(a[x].u,A1)*a[y].w;
a2=1ll*dis(a[x].u,A2)*a[y].w;
b1=1ll*dis(a[y].u,B1)*a[x].w;
b2=1ll*dis(a[y].u,B2)*a[x].w;
if(A1==B1)
{
if(a1==b1)
{
return 1;
}
if(a1<b1)
{
return b2<=a2;
}
else
{
return a2<=b2;
}
}
else
{
if(cpr(a1,b1,a2))return 1;
if(cpr(a1,b2,a2))return 1;
if(cpr(b1,a1,b2))return 1;
if(cpr(b1,a2,b2))return 1;
return 0;
}
}
int main()
{
T=read();
b[0]=-1;
for(int i=1;i<=200010;i++)
{
b[i]=b[i>>1]+1;
}
while(T--)
{
memset(head,0,sizeof(head));
num=0;
tot=0;
ans=0;
n=read();
for(int i=1;i<n;i++)
{
x=read();
y=read();
z=read();
add(x,y,z);
add(y,x,z);
}
dfs(1,0);
ST();
m=read();
for(int i=1;i<=m;i++)
{
a[i].u=read();
a[i].v=read();
a[i].w=read();
}
for(int i=1;i<=m;i++)
{
for(int j=i+1;j<=m;j++)
{
ans+=check(i,j);
}
}
printf("%d\n",ans);
}
}