NOIP 2017 逛公园 (最短路+dp)

时间:2022-12-31 14:57:29

题意:

求出 从 \(1\)\(n\) 的路径长度小于等于最短路 \(+k\) 的路径个数

分析:

首先观察数据特点

\(30\%\) 的数据 \(k=0\) 且没有 \(0\) 边,等同于最短路计数
\(70\%\) 的数据没有 \(0\) 边,那么就是相当于没有 \(0\)
\(100\%\) 的数据 \(k \le 50\)

首先考虑第一档:

显然就是一个最短路径计数,是一个非常经典的问题,就是一个在最短路图上 \(dp\) 的问题,非常基础

再考虑第二档:

没有 \(0\) 环,那么等同于满足 \(dp\) 没有后效性,我们可以把 \(dp\) 状态稍微改一下

\(dis\) 表示最短路

\(dp[i][j]\) 表示到达 \(i\) 当前长度为 \(dis[i]+j\) 的路径个数

\(dp[i][j] \rightarrow dp[to][j + e[i].w + dis[i] - dis[to]]\)

就完成了

那么对于第三档

我们需要处理零环

一种比较简单的方法可以见同学博客 链接

在这里讲一种比较巧妙的方法(代码也比较短

考虑反图处理

把图反过来连建,然后跑 \(dijkstra\) 就可以判断哪些点是不通的

然后处理完这个我们通过 \(dis\) 数组把每条边的权值更改成为冗余代价

然后再对新图跑一遍 \(dijkstra\) 就可以判断出一些点到达终点肯定是会大于 \(k\)

再对新图跑一边拓扑排序求出拓扑序顺便搞掉环

然后这些都搞完了我们就没有环了

之后就可以按照原来的方法 \(dp\)

大概算法就差不多了

实际写代码时为了方便统计答案我们在 \(dp\) 状态后多一维表示最后这一次更新用的是最短路图上的数最短路树上的边还是不是

因为最后一次没有统计,那么我们就要求都是树边

大功告成

洛谷上跑的飞快

#include <map>
#include <set>
#include <ctime>
#include <queue>
#include <stack>
#include <cmath>
#include <vector>
#include <bitset>
#include <cstdio>
#include <cctype>
#include <string>
#include <numeric>
#include <cstring>
#include <cassert>
#include <climits>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std ;
//#define int long long
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
#define loop(s, v, it) for (s::iterator it = v.begin(); it != v.end(); it++)
#define cont(i, x) for (int i = head[x]; i; i = e[i].nxt)
#define clr(a) memset(a, 0, sizeof(a))
#define ass(a, sum) memset(a, sum, sizeof(a))
#define lowbit(x) (x & -x)
#define all(x) x.begin(), x.end()
#define ub upper_bound
#define lb lower_bound
#define pq priority_queue
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define iv inline void
#define enter cout << endl
#define siz(x) ((int)x.size())
#define file(x) freopen(#x".in", "r", stdin),freopen(#x".out", "w", stdout)
typedef long long ll ;
typedef unsigned long long ull ;
typedef pair <int, int> pii ;
typedef vector <int> vi ;
typedef vector <pii> vii ;
typedef queue <int> qi ;
typedef queue <pii> qii ;
typedef set <int> si ;
typedef map <int, int> mii ;
typedef map <string, int> msi ;
const int N = 100010 ;
const int M = 200010 ;
const ll INF = 1ll << 60 ;
const int iinf = 1 << 30 ;
const ll linf = 2e18 ;
const double eps = 1e-7 ;
void print(int x) { cout << x << endl ; exit(0) ; }
void PRINT(string x) { cout << x << endl ; exit(0) ; }
void douout(double x){ printf("%lf\n", x + 0.0000000001) ; }
template <class T> void chmin(T &a, T b) { if (a > b) a = b ; }
template <class T> void chmax(T &a, T b) { if (a < b) a = b ; }

int n, m, k, MOD, top, num, ans, T ;
int head[N], pre[N], vis[N], inT[M], bad[N], t[N], deg[N], dp[53][N][2] ;
ll dis[N] ;

struct node {
    int x, y, z ;
} g[N << 1] ;

struct Edge {
    int to, nxt, w ;
} e[N << 1] ;

struct point {
    int x ; ll dis ;
    bool operator < (const point &a) const {
        return dis > a.dis ;
    }
} ;

void add(int a, int b, int w) {
    e[++top] = (Edge) {b, head[a], w} ;
    head[a] = top ;
}

void dij(int s) {
    priority_queue <point> q ;
    ass(pre, -1) ; clr(vis) ;
    rep(i, 1, n) dis[i] = INF ; dis[s] = 0 ;
    q.push((point) {s, 0}) ;
    while (!q.empty()) {
        int now = q.top().x ; q.pop() ;
        if (vis[now]) continue ;
        vis[now] = 1 ;
        cont(i, now) {
            int to = e[i].to ;
            if (!bad[to] && dis[to] > dis[now] + e[i].w) {
                pre[to] = i ;
                dis[to] = dis[now] + e[i].w ;
                q.push((point) {to, dis[to]}) ;
            }
        }
    }
}

bool topsort() {
    int l = 0, r = 0 ; num = 0 ;
    rep(i, 1, m) if (!bad[g[i].x] && !bad[g[i].y] && g[i].z == 0) deg[g[i].y]++ ;
    rep(i, 1, n)
    if (!bad[i]) {
        if (!deg[i]) t[++r] = i ;
        num++ ;
    }
    while (l < r) {
        int now = t[++l] ;
        cont(i, now) {
            int to = e[i].to ;
            if (e[i].w == 0 && !bad[to]) {
                deg[to]-- ;
                if (!deg[to]) t[++r] = to ;
            }
        }
    }
    return (num == r) ;
}

void init() {
    clr(inT) ; clr(deg) ; clr(bad) ; clr(head) ; top = 0 ;
}

void Main() {
    init() ;
    scanf("%d%d%d%d", &n, &m, &k, &MOD) ;
    rep(i, 1, m) {
        int a, b, c ; scanf("%d%d%d", &a, &b, &c) ;
        g[i] = (node) {a, b, c} ;
        add(b, a, c) ;
    }
    dij(n) ;
    rep(i, 1, n) if (dis[i] == INF) bad[i] = 1 ;
    rep(i, 1, n - 1) if (dis[i] != INF) inT[pre[i]] = 1 ;
    clr(head) ; top = 0 ;
    rep(i, 1, m) {
        g[i].z += dis[g[i].y] - dis[g[i].x] ;
        add(g[i].x, g[i].y, g[i].z) ;
    }
    dij(1) ;
    rep(i, 1, n) if (dis[i] > k) bad[i] = 1 ;
    if (!topsort()) {
        puts("-1") ;
        return ;
    }
    clr(dp) ;
    dp[0][1][0] = 1 ;
    rep(i, 0, k) {
        rep(j, 1, num) {
            int tmp = (dp[i][t[j]][0] + dp[i][t[j]][1]) % MOD ;
            if (tmp) {
                int now = t[j] ;
                cont(l, now) {
                    int to = e[l].to ;
                    if (!bad[to] && i + e[l].w <= k) {
                        (dp[i + e[l].w][to][inT[l]] += tmp) %= MOD ;
                    }
                }
            }
        }
    }
    ans = 0 ;
    rep(i, 0, k)
    rep(j, 1, num)
    (ans += dp[i][t[j]][0]) %= MOD ;
    printf("%d\n", ans) ;
}

signed main(){
    scanf("%d", &T) ;
    while (T--) Main() ;
    return 0 ;
}

/*
写代码时请注意:
    1.ll?数组大小,边界?数据范围?
    2.精度?
    3.特判?
    4.至少做一些
思考提醒:
    1.最大值最小->二分?
    2.可以贪心么?不行dp可以么
    3.可以优化么
    4.维护区间用什么数据结构?
    5.统计方案是用dp?模了么?
    6.逆向思维?
*/