【传送门】
我们还是先将一下算法的步骤,待会再解释起来方便一点。
算法步骤
首先我们算出每个子树的\(size\)。
我们就设当前访问的节点
然后我们就得到了当前这个节点的答案是这个树整个的\(size\)的两倍\(-1\),再加上两两子树的\(size\)的积。
也就是$ans=size[u] \times 2 - 1 +\sum size[i] * size[j] $
其中\(i\)和\(j\)是u的儿子。
一下就是核心代码:
register LL ans = sz[p] * 2 - 1;
for (Ri i = head[p] ; i ; i = edge[i].next)
{
if ( edge[i].to == fa[p]) continue;
sz_rest = sz[p] - sz[edge[i].to] - 1 ;
ans = (ans + (sz_rest * sz[edge[i].to] ) ) % Mod;
}
解释一下
首先,我们要明白,如果要以当前这个节点为\(LCA\)的话,那么两个节点必须不能是是这个当前这个要查询的节点的父亲,这个应该没有问题。
那么这两个节点必须是在当前查询节点上,或者是在这个节点的不同子树中。(注:这个一定是要在不同的子树中,不然\(LCA\)就会是在这个子树当中,而不是当前的节点)
那么我们要考虑两种情况
- 如果其中一个点是我们要访问的节点,那么另外一个点就可以是这个这棵子树中任意的点,但是要注意两个细节,这些点对是有顺序的,所以答案需要\(\times 2\),第二个细节是我们要注意根节点是只算一次的,因为相同的点放前放后都是一样的,只算一种。
- 如果两个点是在不同的子树中,那么一棵子树中每一个节点都一定可以和另外一课子树中的任意一点相组成一个合法的点对,那么就根据乘法原理,将这些子树中的节点数相乘在相加就是我们得到的答案了。
那么我们就得到了答案呀,答案就是上文中所提到的$ans=size[u] \times 2 - 1 +\sum size[i] * size[j] $
以下是代码,但是是错的,会TLE一个点。
# include <bits/stdc++.h>
# define Ri register int
# define for1(i,a,b) for (Ri i(a) ; i <= b; i++ )
# define for2(i,a,b) for (Ri i(a) ; i >= b; i-- )
using namespace std;
typedef long long LL;
const int Mod = 1e9 + 7;
const int Maxn = 50005;
struct Edge {
int to , next ;
} edge[Maxn<<1] ;
int head[Maxn<<1];
int sz[Maxn], fa[Maxn];
int root, n ,m , Nedge;
inline int read ()
{
int w = 0,x = 0; char ch = 0;
while ( !isdigit(ch) ) { w |= ch == '-'; ch = getchar() ; }
while ( isdigit(ch) ) { x = (x<<1) + (x<<3) + (ch^48); ch = getchar() ; }
return w ? -x : x ;
}
inline void Add_Edge (int u, int v)
{
edge[Nedge] = (Edge) {v , head[u]};
head[u] = Nedge++;
}
inline void dfs (int u , int fat)
{
fa[u] = fat;
sz[u] = 1;
for (int i = head[u] ; i != -1; i = edge[i].next )
{
int v = edge[i].to;
if ( v == fa[u] ) continue;
dfs (v ,u );
sz[u] += sz[v];
}
}
int main ()
{
memset( head, -1 ,sizeof(head) );
n = read(), root = read(), m = read();
for1 (i , 1 , n-1)
{
int u = read(),v = read();
Add_Edge(u, v); Add_Edge(v, u);
}
dfs (root , -1);
for1 (T ,1 , m )
{
int p = read();
LL ans = sz[p] * 2 - 1;
for (int i = head[p] ; i != -1 ; i = edge[i].next)
{
for (int j = head[p] ; j != -1 ; j = edge[j].next)
{
if (i == j || edge[i].to == fa[p] || edge[j].to == fa[p]) continue;
ans = (ans + sz[edge[i].to] * sz[edge[j].to]) % Mod;
}
}
printf ("%lld\n", ans);
}
return 0 ;
}
用尽了卡常技巧,都没法A掉最后一个点,为什么?
因为最后一个点太大了,如果按照我们的做法,是\(O(n^3+n^2+ m + n)\)的复杂度,这样实在是太不优美了。
我们来想一想优化。
其实我们并不需要枚举两个儿子,因为我们还有一个父亲没有用到,因为任意一个儿子需要\(\times\)不属于这个子树的节点的数量,那么这个数量就是父亲的\(size-\)儿子\(size-1\)。\(-1\)是减去父亲的节点
那么也就是优化成\(O(2\times n^2 + n + m)\)
其实也就是优化了主程序的一小部分。
for1 (T ,1 , m )
{
register int p = read() ;
register int sz_rest;
register LL ans = sz[p] * 2 - 1;
for (Ri i = head[p] ; i != -1 ; i = edge[i].next)
{
if ( edge[i].to == fa[p]) continue;
sz_rest = sz[p] - sz[edge[i].to] - 1 ;
ans = (ans + (sz_rest * sz[edge[i].to] ) ) % Mod;
}
printf ("%lld\n", ans);
}
但是还是做不了,还是会T掉一个点,那么我们就需要更加牛逼的优化了,因为答案不大,我们就用一个桶记录这个答案,这样我们就把超大的\(m\)缩小成了\(n\)了,这样我就刚好过掉了,而且速度老快惹QAQ!
以下是正解代码
# include <cstdio>
# include <cstring>
# include <ctype.h>
# define Ri register int
# define for1(i,a,b) for (Ri i(a) ; i <= b; i++ )
# define for2(i,a,b) for (Ri i(a) ; i >= b; i-- )
using namespace std;
typedef long long LL;
const int Mod = 1e9 + 7;
const int Maxn = 50005;
struct Edge {
int to , next ;
} edge[Maxn<<1] ;
int head[Maxn<<1];
int sz[Maxn], fa[Maxn];
int root, n ,m , Nedge;
LL Ans[Maxn];
inline int read ()
{
int w = 0,x = 0; char ch = 0;
while ( !isdigit(ch) ) { w |= ch == '-'; ch = getchar() ; }
while ( isdigit(ch) ) { x = (x<<1) + (x<<3) + (ch^48); ch = getchar() ; }
return w ? -x : x ;
}
inline void Add_Edge (int u, int v)
{
edge[Nedge] = (Edge) {v,head[u]};
head[u] = Nedge++;
}
inline void dfs (int u , int fat) //计算size,并标记父亲
{
fa[u] = fat; sz[u] = 1;
for (Ri i = head[u] ; i ; i = edge[i].next )
{
int v = edge[i].to;
if ( v == fa[u] ) continue;
dfs (v ,u );
sz[u] += sz[v];//累加size
}
}
int main ()
{
n = read(), root = read(), m = read(); Nedge=1;
for1 (i , 1 , n-1)
{
int u = read(), v = read();
Add_Edge(u, v); Add_Edge(v, u);
}
dfs (root , -1);
for1 (T ,1 , m )
{
register int p = read() ;
if ( Ans[p] != 0 ) {printf("%lld\n", Ans[p] ); continue; } //如果当前的答案已经算过了,直接输出
register int sz_rest; //表示剩下的
register LL ans = sz[p] * 2 - 1; //算第一部分的答案
for (Ri i = head[p] ; i ; i = edge[i].next)
{
if ( edge[i].to == fa[p]) continue;
sz_rest = sz[p] - sz[edge[i].to] - 1 ;//算第二部分的答案。
ans = (ans + (sz_rest * sz[edge[i].to] ) ) % Mod;//累加答案
}
Ans[p] = ans ;
printf ("%lld\n", ans);
}
return 0 ;
}