bzoj 4016 [FJOI2014]最短路径树问题(最短路径树+树分治)

时间:2022-08-07 17:20:34

4016: [FJOI2014]最短路径树问题

Time Limit: 5 Sec  Memory Limit: 512 MB
Submit: 426  Solved: 147
[Submit][Status][Discuss]

Description

给一个包含n个点,m条边的无向连通图。从顶点1出发,往其余所有点分别走一次并返回。
往某一个点走时,选择总长度最短的路径走。若有多条长度最短的路径,则选择经过的顶点序列字典序最小的那条路径(如路径A为1,32,11,路
径B为1,3,2,11,路径B字典序较小。注意是序列的字典序的最小,而非路径中节点编号相连的字符串字典序最小)。到达该点后按原路返回,然后往其他
点走,直到所有点都走过。
可以知道,经过的边会构成一棵最短路径树。请问,在这棵最短路径树上,最长的包含K个点的简单路径长度为多长?长度为该最长长度的不同路径有多少条?
这里的简单路径是指:对于一个点最多只经过一次的路径。不同路径是指路径两端端点至少有一个不同,点A到点B的路径和点B到点A视为同一条路径。

Input

第一行输入三个正整数n,m,K,表示有
n个点m条边,要求的路径需要经过K个点。接下来输入m行,每行三个正整数Ai,Bi,Ci(1<=Ai,Bi<=n,1<=Ci&
lt;=10000),表示Ai和Bi间有一条长度为Ci的边。数据保证输入的是连通的无向图。
 
 

Output

输出一行两个整数,以一个空格隔开,第一个整数表示包含K个点的路径最长为多长,第二个整数表示这样的不同的最长路径有多少条。
 

Sample Input

6 6 4
1 2 1
2 3 1
3 4 1
2 5 1
3 6 1
5 6 1

Sample Output

3 4

HINT

对于所有数据n<=30000,m<=60000,2<=K<=n。数据保证最短路径树上至少存在一条长度为K的路径。

【思路】

先求出满足要求的最短路树来。

分治:求出过根节点的点对数,其它递归处理。

求过根节点的点对:假设现在处理根节点的S子树,用f[i][0]表示前S-1棵子树中与根相距i个节点(不含根)的最长路径,f[i][1]表示方案数,类似的定义tmp为当前S子树的统计结果。一遍bfs构造出tmp,枚举该子树上的结点数更新答案,然后用tmp更新f。

需要注意的有:

累计答案的诸多小细节。

f[][],tmp[][],queue的清零。

【代码】

 #include<cstdio>
#include<vector>
#include<queue>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std; const int N = +;
const int INF = 1e9+1e9; struct Edge {
int v,w;
Edge(int v=,int w=):v(v),w(w) {}
bool operator < (const Edge& rhs) const {
return v<rhs.v;
}
};
queue<int> q;
int n,m,K,root,size,d[N],inq[N];
int ans1,ans2,siz[N],mx[N],vis[N],dep[N],dis[N],fa[N];
vector<Edge> g[N],G[N]; void spfa() {
for(int i=;i<=n;i++) d[i]=INF;
memset(inq,,sizeof(inq));
inq[]=; q.push();
while(!q.empty()) {
int u=q.front(); q.pop(); inq[u]=;
for(int i=;i<g[u].size();i++) {
int v=g[u][i].v,w=g[u][i].w;
if(d[v]>d[u]+w) {
d[v]=d[u]+w;
if(!inq[v]) inq[v]=,q.push(v);
}
}
}
}
void dfs(int u) {
vis[u]=;
for(int i=;i<g[u].size();i++) {
int v=g[u][i].v,w=g[u][i].w;
if(!vis[v] && d[v]==d[u]+w) {
G[u].push_back(Edge(v,w));
G[v].push_back(Edge(u,w));
dfs(v);
}
}
}
void getroot(int u) {
siz[u]=; mx[u]=;
for(int i=;i<G[u].size();i++) {
int v=G[u][i].v;
if(v!=fa[u] && !vis[v]) {
fa[v]=u;
getroot(v);
siz[u]+=siz[v];
if(siz[v]>mx[u]) mx[u]=siz[v];
}
}
mx[u]=max(mx[u],size-siz[u]);
if(mx[u]<mx[root]) root=u;
}
int f[N][],tmp[N][];
void solve(int u,int S){
vis[u]=; f[][]=;
int m=G[u].size();
for(int i=;i<m;i++) {
int v=G[u][i].v;
if(!vis[v]) {
while(!q.empty()) q.pop(); //clear
q.push(v),dep[v]=,dis[v]=G[u][i].w,fa[v]=u;
while(!q.empty()) {
int now=q.front(); q.pop();
int k=dep[now];
if(k>K) break;
if(dis[now]>tmp[k][])
tmp[k][]=dis[now],tmp[k][]=;
if(dis[now]==tmp[k][]) tmp[k][]++;
for(int j=;j<G[now].size();j++) {
int to=G[now][j].v;
if(!vis[to] && to!=fa[now]) {
fa[to]=now;
dep[to]=dep[now]+;
dis[to]=dis[now]+G[now][j].w;
q.push(to);
}
}
}
//for(int j=1;j<=K;j++) printf("(%d,%d) ",tmp[j][0],tmp[j][1]);cout<<endl;
for(int j=;j<=K;j++) { //tmp位于前可以取[1..K]
if(tmp[j][]+f[K-j][]>ans1)
ans1=tmp[j][]+f[K-j][],ans2=;
if(tmp[j][]+f[K-j][]==ans1)
ans2+=tmp[j][]*f[K-j][];
}
for(int j=;j<=K;j++) {
if(tmp[j][]>f[j][]) f[j][]=tmp[j][],f[j][]=;
if(tmp[j][]==f[j][]) f[j][]+=tmp[j][];
tmp[j][]=tmp[j][]=;
}
}
}
//cout<<u<<": "<<ans1<<" "<<ans2<<endl;
for(int j=;j<=K;j++) f[j][]=f[j][]=;
m=G[u].size();
for(int i=;i<m;i++) {
int v=G[u][i].v;
if(!vis[v]) {
size=siz[v];
if(siz[v]>siz[u]) siz[v]=S-siz[v];
root=;
if(size>=K) getroot(v);
solve(root,siz[v]);
}
}
}
void read(int& x) {
char c=getchar(); int f=; x=;
while(!isdigit(c)){if(c=='-') f=-;c=getchar();}
while(isdigit(c)) x=x*+c-'',c=getchar();
x*=f;
}
int main() {
//freopen("in.in","r",stdin);
//freopen("out.out","w",stdout);
read(n),read(m),read(K); K--;
int u,v,w;
for(int i=;i<m;i++) {
read(u),read(v),read(w);
g[u].push_back(Edge(v,w));
g[v].push_back(Edge(u,w));
}
for(int i=;i<=n;i++)
sort(g[i].begin(),g[i].end());
spfa();
memset(vis,,sizeof(vis));
dfs();
size=n; mx[]=INF; root=;
memset(vis,,sizeof(vis));
getroot() , solve(root,size);
printf("%d %d",ans1,ans2);
return ;
}