题意:求两点之间最短路的数目加上比最短路长度大1的路径数目
分析:可以转化为求最短路和次短路的问题,如果次短路比最短路大1,那么结果就是最短路数目加上次短路数目,否则就不加。
求解次短路的过程也是基于Dijkstra的思想。算法中用一个二维数组d[u][tag](tag=0代表最短路,1代表次短路)来记录最短路和次短路的长度,cnt[u][tag]记录二者的数目。所以每个点都有两个访问状态,一个是最短路已经确定,另一个是次短路已经确定,所以vis[u][tag]数组也是二维的。
每次维护邻接点的状态时,有四种可能情况:
1.最短路长度需要更新。此时还需判断次短路是否需要更新,若最短路不存在,则不用更新;若最短路存在,则用最短路覆盖次短路。
2.最短路数目需要更新。传递过来的路径长度与当前最短路相等,那么数量要加上去。
3.次短路长度需要更新。
4.次短路数目需要更新。
且算法可以用优先队列改善效率。
#include<iostream>
#include<stdio.h>
#include<cstring>
#include<queue>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;
typedef int LL;
const int maxn =1e3+,maxm = 2e4+;
const LL INF =0x3f3f3f3f;
LL dis[maxn][]; //最短路和次短路长度
struct Edge{
int to,next;
LL val;
};
struct Node{
int u,tag;
bool operator <(const Node &rhs) const {return dis[u][tag]>dis[rhs.u][rhs.tag];}
}; struct SPFA{
int n,m,tot;
Edge edges[maxm];
int head[maxn];
bool vis[maxn][];
int cnt[maxn][]; //最短路和次短路条数 void init(int n){
this->n = n;
this->tot=;
memset(head,-,sizeof(head));
}
void Addedge(int u,int v ,LL dist){
edges[tot].to = v;
edges[tot].val = dist;
edges[tot].next = head[u];
head[u] = tot++;
}
void spfa(int s){
for(int i=;i<=n;++i){
dis[i][]=INF,cnt[i][]=;
dis[i][]=INF,cnt[i][]=;
vis[i][]=vis[i][]=false;
}
dis[s][]=,cnt[s][]=;
priority_queue<Node> Q;
Q.push((Node){s,});
while(!Q.empty()){
Node x =Q.top();Q.pop();
int u = x.u,tag = x.tag;
if(vis[u][tag]) continue;
vis[u][tag] = true;
for(int i=head[u];~i;i=edges[i].next){
int v =edges[i].to,w =edges[i].val;
int tmp = dis[u][tag] + w;
if(dis[v][]>tmp) { //需要更新最短路
if(dis[v][]!=INF){ //将次短路覆盖
dis[v][] = dis[v][];
cnt[v][] = cnt[v][];
Q.push((Node){v,});
}
dis[v][]=tmp;
cnt[v][]=cnt[u][tag];
Q.push((Node){v,});
}
else if(dis[v][]==tmp){ //最短路长度不变,数量增加
cnt[v][]+=cnt[u][tag];
}
else if(dis[v][]>tmp){ //次短路长度改变
dis[v][] = tmp;
cnt[v][] = cnt[u][tag];
Q.push((Node){v,});
}
else if(dis[v][]==tmp){ //次短路长度不变,数量增加
cnt[v][]+=cnt[u][tag];
}
}
}
}
}G; #define LOCAL
int main()
{
#ifdef LOCAL
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
int T,N,M,u,v, s,t;
LL tmp;
scanf("%d",&T);
while(T--){
scanf("%d%d",&N,&M);
G.init(N);
for(int i=;i<=M;++i){
scanf("%d%d%d",&u,&v,&tmp);
G.Addedge(u,v,tmp);
}
scanf("%d%d",&s,&t);
G.spfa(s);
if(dis[t][]==dis[t][]+)
G.cnt[t][]+=G.cnt[t][];
printf("%d\n",G.cnt[t][]);
}
return ;
}