给一棵 \(n\) 个点的树加上 \(m\) 条非树边 , 现在需要断开一条树边和一条非树边使得图不连通 , 求方案数 .
$n \le 10^5 , m \le 2*10^5 $ , 保证答案在 \(int\) 范围内.
对于每条非树边 , 覆盖 \(x\) 到 \(LCA\) 和 \(y\)到 \(LCA\) 的边 , 即差分算出每个点和父亲的连边被覆盖了多少次 .
被覆盖 \(0\) 次的边可以和 \(m\) 条非树边搭配 , 被覆盖 \(1\) 次的边可以和唯一的非树边搭配 , \(2\) 次以上的不能产生贡献 .
时间复杂度 \(O(n+m)\)
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define Debug(x) cout<<#x<<"="<<x<<endl
#define log2 LLLLog2
using namespace std;
typedef long long LL;
const int INF=1e9+7;
inline LL read(){
register LL x=0,f=1;register char c=getchar();
while(c<48||c>57){if(c=='-')f=-1;c=getchar();}
while(c>=48&&c<=57)x=(x<<3)+(x<<1)+(c&15),c=getchar();
return f*x;
}
const int N = 3e5 + 5;
const int M = 6e5 + 5;
const int logN = 20;
struct Edge{
int v,w,nxt;
}e[M];
int first[N],Ecnt=0;
inline void Add_edge(int u,int v,int w=0){
e[++Ecnt]=(Edge){v,w,first[u]};
first[u]=Ecnt;
}
int fa[N][logN], dep[N], tag[N], log2[N];
int n, m, ans;
inline void dfs1(int u, int pre){
fa[u][0] = pre, dep[u] = dep[pre] + 1;
for(int i = 1; fa[u][i - 1]; ++i){
fa[u][i] = fa[fa[u][i - 1]][i - 1];
}
for(int i = first[u]; i; i = e[i].nxt){
int v = e[i].v;
if(v != pre) dfs1(v, u);
}
}
inline int LCA(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
for(int i = log2[dep[x] - dep[y]]; i >= 0; --i){
if(dep[fa[x][i]] >= dep[y]){
x = fa[x][i];
}
}
if(x == y) return x;
for(int i = log2[dep[x]]; i >= 0; --i){
if(fa[x][i] != fa[y][i]){
x = fa[x][i], y = fa[y][i];
}
}
return fa[x][0];
}
inline void dfs2(int u, int pre){
for(int i = first[u]; i; i = e[i].nxt){
int v = e[i].v;
if(v == pre) continue;
dfs2(v, u);
tag[u] += tag[v];
}
}
int main(){
n = read(), m = read();
log2[0] = -1;
for(int i = 1; i <= n; ++i) log2[i] = log2[i >> 1] + 1;
for(int i = 1; i < n; ++i){
int x = read(), y = read();
Add_edge(x, y);
Add_edge(y, x);
}
dfs1(1, 0);
for(int i = 1; i <= m; ++i){
int x = read(), y = read(), p = LCA(x, y);
++tag[x], ++tag[y], tag[p] -= 2;
}
dfs2(1, 0);
for(int i = 2; i <= n; ++i){
if(tag[i] == 0) ans += m;
if(tag[i] == 1) ans += 1;
}
printf("%d\n", ans);
}