HDU4871 Shortest-path tree(树分治)

时间:2021-10-09 21:26:54

好久没做过树分治的题了,对上一次做是在南京赛里跪了一道很裸的树分治题后学的一道,多校的时候没有看这道题,哪怕看了感觉也看不出来是树分治,看出题人给了解题报告里写了树分治就做一下好了。

题意其实就是给你一个图,然后让你转换成一棵树,这棵树满足的是根节点1到其余各点的距离都是图里的最短距离,而且为了保证这棵树的唯一性,路径也必须是最小的。转化成树的方法其实就是跑一次spfa。spfa的时候记下所有到这个的前驱的边,然后这些边集反向的边补上就是构成所有最短路的边。然后在这些边上跑一次dfs,跑前将边按照到达点的序号由小到大排序,注意dfs搜的下一个点的距离必须是最短的才搜,不然的话搜出来的图就是不对的,比划一下题目给的样例就知道了。

至此图的部分转化完了,剩下的就是求一个图里包含了k个点的路径的最长距离,以及有多少条,相似的问题还有有多少条路径的乘积=k,有多少条路径的和>k,有多少条路径的乘积是完全立方数。。。做法就是典型的树分治。

树分治在《挑战程序设计竞赛》这本书上有一个很好的框架可以直接抄,我就直接拿来用了。具体的做法是找出重心,对重心外的部分递归求解,合并的时候枚举到重心的所有路径,枚举的时候可以用一个全局的map ds记录当前到达这个点的所有情况,然后用一个tds去枚举新的部分的路径,然后通过ds和tds更新答案,更新完后将tds的内容加进去ds。下面贴一记代码好了

#pragma warning(disable:4996)
#include <iostream>
#include <cstring>
#include <string>
#include <vector>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <queue>
#include <map>
using namespace std; #define ll long long
#define maxn 31000
#define maxm 61000
#define MP make_pair struct Edge{
int v, w;
Edge(int vi, int wi) :v(vi), w(wi){}
Edge(){}
bool operator < (const Edge & b) const{
return v < b.v;
}
}; vector<Edge> G[maxn];
vector<Edge> E[maxn];
vector<Edge> EE[maxn];
vector<Edge> T[maxn]; int n, m, k; int d[maxn];
int dx[maxn];
bool in[maxn]; void dfs(int u,int dis)
{
in[u] = true; dx[u] = dis;
if (dx[u] != d[u]) puts("fuck");
for (int i = 0; i < EE[u].size(); i++){
int v = EE[u][i].v, w = EE[u][i].w;
if (!in[v]&&w+dis==d[v]) {
T[u].push_back(Edge(v, w));
T[v].push_back(Edge(u, w));
dfs(v, w + dis);
}
}
} void spfa()
{
queue<int> que;
memset(in, 0, sizeof(in));
memset(d, 0x3f, sizeof(d));
d[1] = 0; in[1] = true; que.push(1);
while (!que.empty()){
int u = que.front(); que.pop(); in[u] = false;
for (int i = 0; i < G[u].size(); i++){
int v = G[u][i].v, w = G[u][i].w;
if (d[u] + w < d[v]){
d[v] = d[u] + w;
if (!in[v]) {
in[v] = true; que.push(v);
}
E[v].clear(); E[v].push_back(Edge(u, w));
}
else if (d[u] + w == d[v]){
E[v].push_back(Edge(u, w));
}
}
}
for (int i = 1; i <= n; i++){
for (int j = 0; j < E[i].size(); j++){
EE[E[i][j].v].push_back(Edge(i, E[i][j].w));
EE[i].push_back(E[i][j]);
}
}
for (int i = 1; i <= n; i++) sort(EE[i].begin(), EE[i].end());
memset(in, 0, sizeof(in));
memset(dx, 0x3f, sizeof(dx));
dfs(1,0);
} bool centroid[maxn];
int ssize[maxn]; int compute_subtree_size(int v, int p){
int c = 1;
for (int i = 0; i < T[v].size(); i++){
int w = T[v][i].v;
if (w == p || centroid[w]) continue;
c += compute_subtree_size(w, v);
}
ssize[v] = c;
return c;
} pair<int, int> search_centroid(int v, int p, int t){
pair<int, int> res = MP(INT_MAX, -1);
int s = 1, m = 0;
for (int i = 0; i < T[v].size(); i++){
int w = T[v][i].v;
if (w == p || centroid[w]) continue; res = min(res, search_centroid(w, v, t)); m = max(m, ssize[w]);
s += ssize[w];
}
m = max(m, t - s);
res = min(res, MP(m, v));
return res;
} map<int, pair<int, int> > ds;
map<int, pair<int, int> > tds;
map<int, pair<int, int> >::iterator it;
map<int, pair<int, int> >::iterator itt;
// pass kk points, distant is dis
void enumerate(int v, int p, int kk, int dis, map<int, pair<int, int> > &tds)
{
if (kk > k) return;
it = tds.find(kk);
if (it!=tds.end()){
if (it->second.first == dis) {
it->second.second += 1;
}
else if(it->second.first<dis){
tds.erase(it);
tds.insert(MP(kk, MP(dis, 1)));
}
}
else{
tds.insert(MP(kk, MP(dis, 1)));
}
for (int i = 0; i < T[v].size(); i++){
int w = T[v][i].v;
if (w == p || centroid[w]) continue;
enumerate(w, v, kk + 1, dis + T[v][i].w, tds);
}
} ll ans, num; void solve(int v)
{
compute_subtree_size(v, -1);
int s = search_centroid(v, -1, ssize[v]).second;
centroid[s] = true;
for (int i = 0; i < T[s].size(); i++){
if (centroid[T[s][i].v]) continue;
solve(T[s][i].v);
}
ds.clear();
ds.insert(MP(1, MP(0, 1)));
for (int i = 0; i < T[s].size(); i++){
if (centroid[T[s][i].v]) continue;
tds.clear();
enumerate(T[s][i].v, s, 1, T[s][i].w, tds);
it = tds.begin();
while (it != tds.end()){
int kk = it->first;
if (ds.count(k - kk)){
itt = ds.find(k - kk);
int ldis = it->second.first + itt->second.first;
if (ldis>ans) {
ans = ldis; num = it->second.second*itt->second.second;
}
else if (ldis == ans){
num += it->second.second*itt->second.second;
}
}
++it;
}
it = tds.begin();
while (it != tds.end()){
int kk = it->first + 1;
if (ds.count(kk)){
itt = ds.find(kk);
if (it->second.first > itt->second.first){
ds.erase(itt);
ds.insert(MP(kk, it->second));
}
else if (it->second.first == itt->second.first) itt->second.second += it->second.second;
}
else{
ds.insert(MP(kk, it->second));
}
++it;
}
}
centroid[s] = false;
} int main()
{
int TE; cin >> TE;
while (TE--){
scanf("%d%d%d", &n, &m, &k);
for (int i = 0; i <= n; i++) {
G[i].clear(); E[i].clear(); EE[i].clear(); T[i].clear();
}
int ui, vi, wi;
for (int i = 0; i < m; i++){
scanf("%d%d%d", &ui, &vi, &wi);
G[ui].push_back(Edge(vi, wi));
G[vi].push_back(Edge(ui, wi));
}
spfa();
ans = 0, num = 0;
solve(1);
cout << ans << " " << num << endl;
}
return 0;
}