poj 3463/hdu 1688 求次短路和最短路个数

时间:2022-12-18 14:55:07

http://poj.org/problem?id=3463

http://acm.hdu.edu.cn/showproblem.php?pid=1688

求出最短路的条数比最短路大1的次短路的条数和,基本和上题一样,最后需判断是否满足dist[t][0]+1==dist[t][1];

cnt[i][0]表示到达点i最短的路有多少条,cnt[i][1]表示次短的条数

dist[i][0]表示到达点i最短路的长度,dist[i][1]表示次短路的长度



用v去松驰u时有四种情况 (设当前dist[v][cas])

情况1:dist[u][fag]+w(v,u)<dist[v][0],找到一个更短的距离,则把原来最短的距离作为次短的距离,同时更新最短的.把(v,0)和(v,1)放入队列


情况2:dist[u][flag]+w(v,u)==dist[v][0],找到一条新的相同距离的最短路,则cnt[v][0]+=cnt[u][flag],不入队

情况3:dist[u][flag]+w(v,u)<dist[v][1],情况4:dist[u][flag]+w(v,u)==dist[v][1] 对次短边的操作参照以上的完成即可。

注意:注意特判dist[v][0] != inf,放置插入了不该插入的(v,1)

优先队列cmp结构体的书写

#pragma comment(linker, "/STACK:36777216")
#pragma GCC optimize ("O2")
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <string>
#include <queue>
#include <map>
#include <iostream>
#include <algorithm>
using namespace std;
#define RD(x) scanf("%d",&x)
#define RD2(x,y) scanf("%d%d",&x,&y)
#define RD3(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define clr0(x) memset(x,0,sizeof(x))
#define clr1(x) memset(x,-1,sizeof(x))
#define eps 1e-9
const double pi = acos(-1.0);
typedef long long LL;
typedef unsigned long long ULL;
const int modo = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int inf = 0x3fffffff;
const LL _inf = 1e18;
const int maxn = 1005,maxm = 10005;
struct edge{
int v,w,next;
edge(){};
edge(int vv,int ww,int nnext):v(vv),w(ww),next(nnext){};
}e[maxm<<1];
int head[maxn],inq[maxn][2],dist[maxn][2],cnt[maxn][2];//0最短1次短
int n,m,ecnt;
void init()
{
clr1(head);
ecnt = 0;
for(int i = 1;i <= n;++i)
dist[i][0] = dist[i][1] = inf;
//fill(dist,dist+maxn*2,inf);
clr0(inq),clr0(cnt);
}
void add(int u,int v,int w)
{
e[ecnt] = edge(v,w,head[u]);
head[u] = ecnt++;
// e[ecnt] = edge(u,w,head[v]);
// head[v] = ecnt++;
}
typedef pair<int,int> p2;
struct cmp {
bool operator() (const p2 &a, const p2 &b)
{
return dist[a.first][a.second] > dist[b.first][b.second];
}
};
void spfa(int src,int dst)
{
priority_queue<p2 , vector<p2> , cmp> q;
q.push(make_pair(src,0));
dist[src][0] = 0,cnt[src][0] = 1;
while(!q.empty()){
int u = q.top().first,flag = q.top().second;
q.pop();
if(inq[u][flag]) continue;
inq[u][flag] = 1;
for(int i = head[u];i != -1;i = e[i].next){
int v = e[i].v,w = e[i].w;
if(!inq[v][0] && dist[v][0] > dist[u][flag] + e[i].w){
if(dist[v][0] != inf){
dist[v][1] = dist[v][0];
cnt[v][1] = cnt[v][0];
q.push(make_pair(v,1));
}
dist[v][0] = dist[u][flag] + e[i].w;
cnt[v][0] = cnt[u][flag]; q.push(make_pair(v,0));
}else if(!inq[v][0] && dist[v][0] == dist[u][flag] + e[i].w){
cnt[v][0] += cnt[u][flag];
}else if(!inq[v][1] && dist[v][1] > dist[u][flag] + e[i].w){
dist[v][1] = dist[u][flag] + e[i].w;
cnt[v][1] = cnt[u][flag]; q.push(make_pair(v,1));
}else if(!inq[v][1] && dist[v][1] == dist[u][flag] + e[i].w){
cnt[v][1] += cnt[u][flag];
}
}
}
//printf("%d %d\n",dist[dst][1],dist[dst][0]);
if(dist[dst][1] == dist[dst][0] + 1)
printf("%d\n",cnt[dst][1] + cnt[dst][0]);
else printf("%d\n",cnt[dst][0]);
} int main(){
int u,v,w,_,s,t;
RD(_);
while(_--){
RD2(n,m);
init();
while(m--){
RD3(u,v,w);
add(u,v,w);
}
RD2(s,t);
spfa(s,t);
}
return 0;
}