bzoj 3697

时间:2023-03-09 17:42:39
bzoj 3697

题目描述:这里

发现还是点对之间的问题,于是还是上点分

只不过是怎么做的问题

首先对每条边边权给成1和-1(即把原来边权为0的边边权改为-1),那么合法的路径总权值一定为0!

还是将路径分为经过当前根节点和不经过当前根节点的,对不经过当前点的递归处理

那么我们讨论经过当前根节点的路径算法即可

可以发现,如果一条路径经过当前根节点,那么当前根节点可以将这条路径分成两部分,且这两部分权值互为相反数!

(这是很显然的,如果不互为相反数的话那么加起来肯定不是0啊)

接下来我们分析中转站的问题:

其实中转站的含义就是在路径上找到一个点,使得这个点左右两部分权值和均为0

那么这里需要dp处理

我们用两个dp数组处理,分别为$f[i][0/1],g[i][0/1]$,其中g是f的前缀和,f[i][0]表示长度为i的路径,0/1用来记录路径上是否存在距离为i的节点

之所以要记录这一点,是因为起点和终点不能作为休息站,所以当某个权值第一次出现时我们要记录在f[i][0]中,在另半条路径中找到g[-i][1]的点数才能保证休息站的存在

那么答案就由几部分组成:

第一:之前的几棵子树中某一个点到根的路径权值为0,当前子树中某一个点到根的路径权值也为0,那么方案数就是$f[0][0]*g[0][0]$

第二:之前的几棵子树中某一个点到根的路径权值为i,当前子树中某一个点权值为-i,那么答案即为$f[i][1]*g[-i][0]+f[i][0]*g[-i][1]+f[i][1]*g[i][-1]$(注意这里的i不一定是正值!!!)

还是比较好理解,因为不管这个权值之前是否出现过,只需互为相反数然后累计即可,因为左右一定能找到一个合法的位置

但是注意一点:

我们认为根节点的g[0][0]初值为1,这样在累计以根节点为端点的情况的时候是有效的,但对于根节点恰好是端点而且根节点*成为休息站的情况是有问题的,这种情况是不合法的,所以事实上累计两边权值都是0的时候应该用的是$f[0][0]*(g[0][0]-1)$!!

然后算完把f累计到g里做个前缀和就可以了

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <queue>
#include <stack>
#define ll long long
using namespace std;
const int inf=0x3f3f3f3f;
struct Edge
{
int next;
int to;
int val;
}edge[];
int head[];
int siz[];
int maxp[];
bool vis[];
int dis[];
int dep[];
int has[];
ll g[][];
ll f[][];
ll ans=;
int maxdep=;
int cnt=;
int s,rt;
int n;
void init()
{
memset(head,-,sizeof(head));
cnt=;
}
void add(int l,int r,int w)
{
edge[cnt].next=head[l];
edge[cnt].to=r;
edge[cnt].val=w;
head[l]=cnt++;
}
void get_rt(int x,int fa)
{
siz[x]=,maxp[x]=;
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(to==fa||vis[to])continue;
get_rt(to,x);
siz[x]+=siz[to];
maxp[x]=max(maxp[x],siz[to]);
}
maxp[x]=max(maxp[x],s-siz[x]);
if(maxp[x]<maxp[rt])rt=x;
}
void get_dis(int x,int fa)
{
maxdep=max(maxdep,dep[x]);
if(has[dis[x]])f[dis[x]][]++;
else f[dis[x]][]++;
has[dis[x]]++;
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(vis[to]||to==fa)continue;
dep[to]=dep[x]+;
dis[to]=dis[x]+edge[i].val;
get_dis(to,x);
}
has[dis[x]]--;
}
void solve(int x)
{
vis[x]=;
g[n][]=;
int maxx=;
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(vis[to])continue;
dis[to]=n+edge[i].val,maxdep=,dep[to]=;
get_dis(to,);
maxx=max(maxx,maxdep);
ans+=(g[n][]-)*f[n][];
for(int j=-maxdep;j<=maxdep;j++)ans+=g[n-j][]*f[n+j][]+g[n-j][]*f[n+j][]+g[n-j][]*f[n+j][];
for(int j=n-maxdep;j<=n+maxdep;j++)
{
g[j][]+=f[j][];
g[j][]+=f[j][];
f[j][]=f[j][]=;
}
}
for(int j=n-maxx;j<=n+maxx;j++)g[j][]=g[j][]=;
for(int i=head[x];i!=-;i=edge[i].next)
{
int to=edge[i].to;
if(vis[to])continue;
rt=,maxp[rt]=inf,s=siz[to];
get_rt(to,);
solve(rt);
}
}
int main()
{
scanf("%d",&n);
init();
for(int i=;i<n;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
if(!z)z--;
add(x,y,z),add(y,x,z);
}
maxp[rt]=s=n;
get_rt(,);
solve(rt);
printf("%lld\n",ans);
return ;
}