Autostrady
https://szkopul.edu.pl/problemset/problem/f2dSBM7JteWHqtmVejMWe1bW/site/?key=statement
题意:
给定一棵 n 个点的树(共 n-1 条树边),另外还有 m 条非树边。每次询问给出两个点 u、v,求满足以下条件的路径条数:
- 不能走树上u到v的简单路径的边。
- 只能走一条非树边。
分析:
RMQ求LCA + 线段树合并。
问题转化为:统计有多少条非树边满足一个端点在 u 的子树内、另一个端点在 v 的子树内。
每个询问只挂在深度较大的端点上,并记录深度较小的那个端点。对每个节点建立一棵以 DFS 序为下标的权值线段树,记录"一个端点在该子树内的非树边"另一端点的 DFS 序;DFS 时从叶子往上合并,每到一个节点,就对挂在它上面的询问查询一段 DFS 序区间(另一端点子树对应的区间)内的计数。
如果询问的一个点是另一个点的祖先(设 v 是 u 的祖先,t 是 v 通往 u 方向的那个儿子),需要特判:答案为 u 的线段树中的总数减去落在 t 子树区间内的数目。
代码:
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<iostream>
#include<cctype>
#include<set>
#include<vector>
#include<queue>
#include<map>
#define pa pair<int,int>
#define mp(a,b) make_pair(a,b)
using namespace std;
typedef long long LL;

// Reads one signed decimal integer from stdin. Any non-digit characters
// before the first digit are skipped, and a '-' among them negates the
// result. The character that ends the number is consumed. Returns 0 when
// input is exhausted before a digit is found.
inline int read() {
    int x = 0, f = 1;
    // Keep getchar()'s int result: a plain char cannot represent EOF
    // reliably, and passing a negative char to isdigit is undefined.
    int ch = std::getchar();
    for (; ch != EOF && !std::isdigit(ch); ch = std::getchar())
        if (ch == '-') f = -1;
    for (; ch != EOF && std::isdigit(ch); ch = std::getchar())
        x = x * 10 + (ch - '0');
    return x * f;
}

// NOTE(review): the original bound was lost in extraction; 100005 is a
// reconstruction -- confirm against the problem's limit on n.
const int N = 100005;
int deth[N], id[N], siz[N];  // per vertex: depth, DFS entry index, subtree size
// Answers indexed by query number (1-based). Sized 5N, not N: the query
// count may exceed the vertex count. Keep in sync with the limit on Q.
int ans[N * 5];
int Time_Index, n;
std::vector<int> adj[N], ext[N];  // tree adjacency / non-tree edge adjacency
vector< pa > q[N]; namespace LCA{
int ord[N << ], d[N << ], p[N], f[N << ][], Log[N << ], tot = ;
void dfs(int u,int fa,int dep) {
ord[++ tot] = u; d[tot] = dep;
p[u] = tot;
for (int i=,sz=adj[u].size(); i<sz; ++i) {
int v = adj[u][i];
if (v == fa) continue;
dfs(v, u, dep + );
ord[++tot] = u; d[tot] = dep;
}
}
void init() {
Log[] = -;
for (int i=; i<=tot; ++i)
f[i][] = i, Log[i] = Log[i >> ] + ; // i<<1 !!!
for (int j=; j<=Log[tot]; ++j) {
for (int i=; (i+(<<j)-)<=tot; ++i) {
int x = f[i][j - ], y = f[i+(<<(j-))][j - ]; // j-1 !!!
f[i][j] = d[x] < d[y] ? x : y;
}
}
}
int Lca(int u,int v) {
u = p[u], v = p[v];
if (u > v) swap(u, v);
int k = Log[v - u + ];
int x = f[u][k], y = f[v-(<<k)+][k];
return d[x] < d[y] ? ord[x] : ord[y];
}
void Main() {
tot = ; dfs(, , ); init();
}
} namespace SegmentTree{
queue<int> s;
int sum[N * ], ls[N * ], rs[N * ], tot;
int NewNode() {
int k;
if (!s.empty()) k = s.front(), s.pop();
else k = ++tot;
sum[k] = ls[k] = rs[k] = ;
return k;
}
void Insert(int l,int r,int &rt,int p) {
if (!rt) rt = NewNode();
if (l == r) {
sum[rt] ++; return ;
}
int mid = (l + r) >> ;
if (p <= mid) Insert(l, mid, ls[rt], p);
else Insert(mid + , r, rs[rt], p);
sum[rt] = sum[ls[rt]] + sum[rs[rt]];
}
int query(int l,int r,int rt,int L,int R) {
if (!rt) return ; // !!!
if (L <= l && r <= R) return sum[rt];
int mid = (l + r) >> , res = ;
if (L <= mid) res = query(l, mid, ls[rt], L, R);
if (R > mid) res += query(mid + , r, rs[rt], L, R);
return res;
}
int Merge(int x,int y) {
if (!x || !y) return x + y;
ls[x] = Merge(ls[x], ls[y]);
rs[x] = Merge(rs[x], rs[y]);
sum[x] = sum[x] + sum[y]; // sum[x] = sum[ls[x]] + sum[rs[x]]; !!!
s.push(y);
return x;
}
int solve(int u,int fa) {
int rt = NewNode();
for (int i=,sz=adj[u].size(); i<sz; ++i)
if (adj[u][i] != fa) rt = Merge(rt, solve(adj[u][i], u));
for (int i=,sz=ext[u].size(); i<sz; ++i)
Insert(, n, rt, id[ext[u][i]]);
for (int i=,sz=q[u].size(); i<sz; ++i) {
int v = q[u][i].first;
if (LCA::Lca(u, v) == v) {
for (int t,j=; j<adj[v].size(); ++j)
if ((t=adj[v][j]) == LCA::Lca(u, t) && deth[t] > deth[v]) {// deth[t] > deth[v] !!!
ans[q[u][i].second] = query(, n, rt, , n) - query(, n, rt, id[t], id[t] + siz[t] - );
break;
}
}
else ans[q[u][i].second] = query(, n, rt, id[v], id[v] + siz[v] - );
}
return rt;
}
}
// Fills deth (the root's parent 0 has depth 0, so the root gets depth 1),
// the DFS entry index id (1-based) and the subtree size siz for the
// subtree of u.
void dfs(int u, int fa) {
    deth[u] = deth[fa] + 1;
    siz[u] = 1;
    id[u] = ++Time_Index;
    for (size_t i = 0; i < adj[u].size(); ++i) {
        int v = adj[u][i];
        if (v == fa) continue;
        dfs(v, u);
        siz[u] += siz[v];
    }
}

// Input format:
//   n, then n-1 tree edges;
//   m, then m non-tree edges;
//   Q, then Q queries (u, v).
// Each query is stored at its deeper endpoint (ties go to the second vertex).
int main() {
    n = read();
    for (int i = 1; i < n; ++i) {
        int u = read(), v = read();
        adj[u].push_back(v), adj[v].push_back(u);
    }
    int m = read();
    for (int i = 1; i <= m; ++i) {
        int u = read(), v = read();
        ext[u].push_back(v), ext[v].push_back(u);
    }
    dfs(1, 0);
    int Q = read();
    for (int i = 1; i <= Q; ++i) {
        int u = read(), v = read();
        if (deth[u] > deth[v]) q[u].push_back(std::make_pair(v, i));
        else q[v].push_back(std::make_pair(u, i));
    }
    LCA::Main();
    SegmentTree::solve(1, 0);
    // NOTE(review): the added constant was lost in extraction; restored as 1.
    // Confirm against the problem statement.
    for (int i = 1; i <= Q; ++i)
        printf("%d\n", ans[i] + 1);
    return 0;
}