题目描述
给出一张 $n$ 个点 $m$ 条边的有向图,边权为非负整数。求满足路径长度小于等于 $1$ 到 $n$ 最短路 $+k$ 的 $1$ 到 $n$ 的路径条数模 $p$ ,如果有无数条则输出 $-1$ 。
输入
第一行包含一个整数 $T$ , 代表数据组数。
接下来 $T$ 组数据,对于每组数据: 第一行包含四个整数 $N,M,K,P$ ,每两个整数之间用一个空格隔开。
接下来 $M$ 行,每行三个整数 $a_i,b_i,c_i$ ,代表编号为 $a_i,b_i$ 的点之间有一条权值为 $c_i$ 的有向边,每两个整数之间用一个空格隔开。
输出
输出文件包含 $T$ 行,每行一个整数表示答案。
样例输入
2
5 7 2 10
1 2 1
2 4 0
4 5 2
2 3 2
3 4 1
3 5 2
1 5 3
2 2 0 10
1 2 0
2 1 0
样例输出
3
-1
题解
最短路+拓扑排序+dp
首先使用堆优化Dijkstra或Spfa(不知道本题是否会卡)求出1到所有点的最短路。
由于对于所有边 $(x,y,z)$ 满足 $dis[x]+z\ge dis[y]$ ,因此超过最短路的部分不会减少。
那么我们设 $f[i][j]$ 表示到达点 $i$ 时经过的路径总长度为 $dis[i]+j\ (j \le k)$ 的方案数。那么这相当于一个新的分层图,只会在同层或向上层转移,不会像下层转移。
这就转化为图上求路径条数。首先初始化 $f[1][0]=0 $ ,跑拓扑排序的同时进行转移。
如果一个点被排到了,那么 $f$ 值即为路径条数。
如果一个点没有被排到,则说明有环连接到它,即路径条数为 $\infty$。
因此把所有 $f[n][0...k]$ 统计一下即可。
时间复杂度 $O(T(m\log n+mk))$
考场上一眼看出题解,然而卡了两个小时的常数才勉强卡进去...
考场原代码(去掉了文件操作):
#include <queue> #include <cctype> #include <cstdio> #include <cstring> #include <algorithm> #define N 100010 #define M 200010 #define R register #define pos(x , y) (x + (y) * n) using namespace std; typedef pair<int , int> pr; priority_queue<pr> q; int head[N] , to[M] , len[M] , next[M] , cnt , dis[N] , vis[N]; int ll[N * 51] , rr[N * 51] , tt[M * 51] , que[N * 51] , ind[N * 51] , f[N * 51]; inline void add(int x , int y , int z) { to[++cnt] = y , len[cnt] = z , next[cnt] = head[x] , head[x] = cnt; } inline char nc() { static char buf[100000] , *p1 , *p2; return p1 == p2 && (p2 = (p1 = buf) + fread(buf , 1 , 100000 , stdin) , p1 == p2) ? EOF : *p1 ++ ; } inline int read() { int ret = 0; char ch = nc(); while(!isdigit(ch)) ch = nc(); while(isdigit(ch)) ret = ((ret + (ret << 2)) << 1) + (ch ^ '0') , ch = nc(); return ret; } int main() { int T = read(); while(T -- ) { memset(head , 0 , sizeof(head)); memset(vis , 0 , sizeof(vis)); memset(ind , 0 , sizeof(ind)); memset(f , 0 , sizeof(f)); cnt = 0; int n = read() , m = read() , k = read() , z , ans = 0 , flag = 1; R int p = read() , i , j , x , y , l = 1 , r = 0; for(i = 1 ; i <= m ; ++i) x = read() , y = read() , z = read() , add(x , y , z); memset(dis , 0x3f , sizeof(dis)); dis[1] = 0 , q.push(pr(0 , 1)); while(!q.empty()) { x = q.top().second , q.pop(); if(vis[x]) continue; vis[x] = 1; for(i = head[x] ; i ; i = next[i]) if(dis[to[i]] > dis[x] + len[i]) dis[to[i]] = dis[x] + len[i] , q.push(pr(-dis[to[i]] , to[i])); } cnt = 0; for(x = 1 ; x <= n ; ++x) { for(j = 0 ; j <= k ; ++j) { ll[pos(x , j)] = cnt + 1; for(i = head[x] ; i ; i = next[i]) if(j + dis[x] + len[i] - dis[to[i]] <= k) ++ind[tt[++cnt] = pos(to[i] , j + dis[x] + len[i] - dis[to[i]])]; rr[pos(x , j)] = cnt; } } f[1] = 1; for(x = 1 ; x <= pos(n , k) ; ++x) if(!ind[x]) que[++r] = x; while(l <= r) { x = que[l ++ ]; for(i = ll[x] ; i <= rr[x] ; ++i) { y = tt[i]; f[y] += f[x] , ind[y] -- ; if(f[y] >= p) f[y] -= p; if(!ind[y]) que[++r] = y; } } for(i = 0 ; i <= k ; ++i) { if(ind[pos(n , i)]) flag = 0; ans = (ans + f[pos(n , i)]) % p; } if(flag) printf("%d\n" , ans); else puts("-1"); } return 0; }