poj 2449 模板题 A*+spfa
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#define mt(a,b) memset(a,b,sizeof(a))
using namespace std;
const int inf=0x3f3f3f3f;
class AStar { ///A*+spfa求第k短路
typedef int typec;///边权的类型
static const int ME=1e5+;///边的个数
static const int MV=1e3+;///点的个数
struct G {
struct E {
int v,next;
typec w;
} e[ME];
int le,head[MV];
void init(int n) {
le=;
for(int i=; i<=n; i++) head[i]=-;
}
void add(int u,int v,typec w) {
e[le].v=v;
e[le].w=w;
e[le].next=head[u];
head[u]=le++;
}
};
class Spfa { ///单源最短路o(k*ME)k~=2
G g;
int n,inque[MV],i,u,v;
typec dist[MV];
bool used[MV];
queue<int> q;
public:
void init(int tn) { ///传入点的个数
n=tn;
g.init(n);
}
void add(int u,int v,typec w) {
g.add(u,v,w);
}
bool solve(int s) { ///传入起点,存在负环返回false
for(i=; i<=n; i++) {
dist[i]=inf;
used[i]=true;
inque[i]=;
}
used[s]=false;
dist[s]=;
inque[s]++;
while(!q.empty()) q.pop();
q.push(s);
while(!q.empty()) {
u=q.front();
q.pop();
used[u]=true;
for(i=g.head[u]; ~i; i=g.e[i].next) {
v=g.e[i].v;
if(dist[v]>dist[u]+g.e[i].w) {
dist[v]=dist[u]+g.e[i].w;
if(used[v]) {
used[v]=false;
q.push(v);
inque[v]++;
if(inque[v]>n) return false;
}
}
}
}
return true;
}
typec getdist(int id) {
return dist[id];
}
} spfa;
struct Q {
int p;
typec g,h;
friend bool operator <(const Q &a,const Q &b) {
return a.g+a.h>b.g+b.h;
}
} now,pre;
priority_queue<Q> q;
int n,cnt[MV];
G g;
typec ans;
public:
void init(int tn) {
n=tn;
g.init(n);
spfa.init(n);
}
void add(int u,int v,typec w) {
g.add(u,v,w);
spfa.add(v,u,w);
}
bool solve(int s,int t,int k) {
if(s==t) k++;
spfa.solve(t);
while (!q.empty()) q.pop();
for(int i=; i<=n; i++) cnt[i]=;
now.p=s;
now.g=;
now.h=;
q.push(now);
while(!q.empty()) {
pre=q.top();
q.pop();
int u=pre.p;
cnt[u]++;
if(cnt[u]==k&&u==t) {
ans=pre.h+pre.g;
return true;
}
if(cnt[u]>k) continue;
for(int i=g.head[u]; ~i; i=g.e[i].next) {
now.h=pre.h+g.e[i].w;
int v=g.e[i].v;
now.g=spfa.getdist(v);
now.p=v;
q.push(now);
}
}
return false;
}
typec getans() {
return ans;
}
} gg;
int main() {
int n,m,u,v,w,s,t,k;
while(~scanf("%d%d",&n,&m)) {
gg.init(n);
while(m--) {
scanf("%d%d%d",&u,&v,&w);
gg.add(u,v,w);
}
scanf("%d%d%d",&s,&t,&k);
if(!gg.solve(s,t,k)) {
puts("-1");
} else {
printf("%d\n",gg.getans());
}
}
return ;
}