BZOJ 4326 NOIP2015 运输计划(二分答案 + 树上差分思想)

时间:2021-05-28 21:02:20

题目链接  BZOJ4326

这个程序在洛谷上TLE了……惨遭卡常

在NOIP赛场上估计只能拿到95分吧= =

把边权转化成点权

首先求出每一条路径的长度

考虑二分答案,$check(now)$

对于当前那些长度大于$now$的路径,用差分求出这些路径经过的点的次数

设这些路径条数为l, 长度最大的路径减去$now$的值为$mx$

(点$x1$经过$y1$次,点$x2$经过$y2$次,..., 点$xm$经过$ym$次)

如果我们能找到一条边(一个点),满足$xk = l$ 且 $yk >= mx$ 则$check$成功,否则失败。

#include <bits/stdc++.h>

using namespace std;

#define rep(i, a, b)    for (int i(a); i <= (b); ++i)
#define dec(i, a, b)    for (int i(a); i >= (b); --i)
#define MP		make_pair
#define fi		first
#define se		second

typedef long long LL;
typedef pair <int, int> PII;

const int N = 3e5 + 10;

struct node{
	int x, y, lca, w;
	void scan(){ scanf("%d%d", &x, &y);}
	void print(){ printf("%d %d %d %d\n", x, y, lca, w);}
} path[N];

int father[N], deep[N], sz[N], son[N], top[N];
int a[N], b[N], f[N], g[N];
int n, m, x, y, z, l, r, ans;
vector <PII> v[N];

void dfs(int x, int fa, int dep, int now){
	sz[x] = 1;
	deep[x] = dep;
	father[x] = fa;
	a[x] = now;
	b[x] = b[fa] + a[x];
	int ct = (int)v[x].size();
	rep(i, 0, ct - 1){
		int u = v[x][i].fi;
		if (u == fa) continue;
		dfs(u, x, dep + 1, v[x][i].se);
		sz[x] += sz[u];
		if (sz[son[x]] < sz[u]) son[x] = u;
	}
}

void dfs2(int x, int fa, int tp){
	top[x] = tp;
	if (son[x]) dfs2(son[x], x, tp);
	int ct = (int)v[x].size();
	rep(i, 0, ct - 1){
		int u = v[x][i].fi;
		if (u == son[x] || u == fa) continue;
		dfs2(u, x, u);
	}
}

void calc(int x, int fa){
	int ct = (int)v[x].size();
	rep(i, 0, ct - 1){
		int u = v[x][i].fi;
		if (u == fa) continue;
		calc(u, x);
		f[x] += f[u];
	}
}    

int LCA(int x, int y){
	for (; top[x] ^ top[y]; ){
		if (deep[top[x]] < deep[top[y]]) swap(x, y);
		x = father[top[x]];
	}

	return deep[x] > deep[y] ? y : x;
}

bool check(int now){
	int cnt = 0;
	int mx = 0;
	memset(f, 0, sizeof f);
	rep(i, 1, m) if (path[i].w > now){
		int x = path[i].x, y = path[i].y, w = path[i].w, lca = path[i].lca;
		++cnt;
		if (lca == y) ++f[x], --f[y];
		else ++f[x], ++f[y], f[lca] -= 2;
		mx = max(mx, w - now);
	}

	calc(1, 0);
	rep(i, 1, n) if (a[i] >= mx && f[i] == cnt) return true;
	return false;
}



int main(){

	scanf("%d%d", &n, &m);
	rep(i, 2, n){
		scanf("%d%d%d", &x, &y, &z);
		v[x].push_back(MP(y, z));
		v[y].push_back(MP(x, z));
	}

	dfs(1, 0, 0, 0);
	dfs2(1, 0, 1);

	rep(i, 1, m){
		path[i].scan();
		if (deep[path[i].x] < deep[path[i].y]) swap(path[i].x, path[i].y);
		path[i].lca = LCA(path[i].x, path[i].y);
		path[i].w = b[path[i].x] + b[path[i].y] - 2 * b[path[i].lca];
	}

	l = 0, r = 3e8;
	while (l + 1 < r){
		int mid = l + r >> 1;
		if (check(mid)) r = mid; else l = mid + 1;
	}

	if (check(l)) ans = l; else ans = r;
	printf("%d\n", ans);
	return 0;
}