[NOIP2017]逛公园 最短路+拓扑排序+dp

时间:2021-07-30 22:09:56

题目描述

给出一张 $n$ 个点 $m$ 条边的有向图,边权为非负整数。求满足路径长度小于等于 $1$ 到 $n$ 最短路 $+k$ 的 $1$ 到 $n$ 的路径条数模 $p$ ,如果有无数条则输出 $-1$ 。

输入

第一行包含一个整数 $T$ , 代表数据组数。

接下来 $T$ 组数据,对于每组数据: 第一行包含四个整数 $N,M,K,P$ ,每两个整数之间用一个空格隔开。

接下来 $M$ 行,每行三个整数 $a_i,b_i,c_i$ ,代表编号为 $a_i,b_i$ 的点之间有一条权值为 $c_i$ 的有向边,每两个整数之间用一个空格隔开。

输出

输出文件包含 $T$ 行,每行一个整数表示答案。

样例输入

2
5 7 2 10
1 2 1
2 4 0
4 5 2
2 3 2
3 4 1
3 5 2
1 5 3
2 2 0 10
1 2 0
2 1 0

样例输出

3
-1


题解

最短路+拓扑排序+dp

首先使用堆优化Dijkstra或Spfa(不知道本题是否会卡)求出1到所有点的最短路。

由于对于所有边 $(x,y,z)$ 满足 $dis[x]+z\ge dis[y]$ ,因此超过最短路的部分不会减少。

那么我们设 $f[i][j]$ 表示到达点 $i$ 时经过的路径总长度为 $dis[i]+j\ (j \le k)$ 的方案数。那么这相当于一个新的分层图,只会在同层或向上层转移,不会像下层转移。

这就转化为图上求路径条数。首先初始化 $f[1][0]=0 $ ,跑拓扑排序的同时进行转移。

如果一个点被排到了,那么 $f$ 值即为路径条数。

如果一个点没有被排到,则说明有环连接到它,即路径条数为 $\infty$。

因此把所有 $f[n][0...k]$ 统计一下即可。

时间复杂度 $O(T(m\log n+mk))$

考场上一眼看出题解,然而卡了两个小时的常数才勉强卡进去...

考场原代码(去掉了文件操作):

#include <queue>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 100010
#define M 200010
#define R register
#define pos(x , y) (x + (y) * n)
using namespace std;
typedef pair<int , int> pr;
priority_queue<pr> q;
int head[N] , to[M] , len[M] , next[M] , cnt , dis[N] , vis[N];
int ll[N * 51] , rr[N * 51] , tt[M * 51] , que[N * 51] , ind[N * 51] , f[N * 51];
inline void add(int x , int y , int z)
{
	to[++cnt] = y , len[cnt] = z , next[cnt] = head[x] , head[x] = cnt;
}
inline char nc()
{
	static char buf[100000] , *p1 , *p2;
	return p1 == p2 && (p2 = (p1 = buf) + fread(buf , 1 , 100000 , stdin) , p1 == p2) ? EOF : *p1 ++ ;
}
inline int read()
{
	int ret = 0; char ch = nc();
	while(!isdigit(ch)) ch = nc();
	while(isdigit(ch)) ret = ((ret + (ret << 2)) << 1) + (ch ^ '0') , ch = nc();
	return ret;
}
int main()
{
	int T = read();
	while(T -- )
	{
		memset(head , 0 , sizeof(head));
		memset(vis , 0 , sizeof(vis));
		memset(ind , 0 , sizeof(ind));
		memset(f , 0 , sizeof(f));
		cnt = 0;
		int n = read() , m = read() , k = read() , z , ans = 0 , flag = 1;
		R int p = read() , i , j , x , y , l = 1 , r = 0;
		for(i = 1 ; i <= m ; ++i) x = read() , y = read() , z = read() , add(x , y , z);
		memset(dis , 0x3f , sizeof(dis));
		dis[1] = 0 , q.push(pr(0 , 1));
		while(!q.empty())
		{
			x = q.top().second , q.pop();
			if(vis[x]) continue;
			vis[x] = 1;
			for(i = head[x] ; i ; i = next[i])
				if(dis[to[i]] > dis[x] + len[i])
					dis[to[i]] = dis[x] + len[i] , q.push(pr(-dis[to[i]] , to[i]));
		}
		cnt = 0;
		for(x = 1 ; x <= n ; ++x)
		{
			for(j = 0 ; j <= k ; ++j)
			{
				ll[pos(x , j)] = cnt + 1;
				for(i = head[x] ; i ; i = next[i])
					if(j + dis[x] + len[i] - dis[to[i]] <= k)
						++ind[tt[++cnt] = pos(to[i] , j + dis[x] + len[i] - dis[to[i]])];
				rr[pos(x , j)] = cnt;
			}
		}
		f[1] = 1;
		for(x = 1 ; x <= pos(n , k) ; ++x)
			if(!ind[x])
				que[++r] = x;
		while(l <= r)
		{
			x = que[l ++ ];
			for(i = ll[x] ; i <= rr[x] ; ++i)
			{
				y = tt[i];
				f[y] += f[x] , ind[y] -- ;
				if(f[y] >= p) f[y] -= p;
				if(!ind[y]) que[++r] = y;
			}
		}
		for(i = 0 ; i <= k ; ++i)
		{
			if(ind[pos(n , i)]) flag = 0;
			ans = (ans + f[pos(n , i)]) % p;
		}
		if(flag) printf("%d\n" , ans);
		else puts("-1");
	}
	return 0;
}