「CodePlus 2017 11 月赛」大吉大利,晚上吃鸡!(dij+bitset)

时间:2023-06-26 09:20:38

  从S出发跑dij,从T出发跑dij,顺便最短路计数。

  令$F(x)$为$S$到$T$最短路经过$x$的方案数,显然这个是可以用$S$到$x$的方案数乘$T$到$x$的方案数来得到。

  然后第一个条件就变成了满足$F(A)+F(B)=F(T)$,这个只要用map存一下点的状态,每次查$F(T)-F(A)$就可以得到$B$的状态了。

  第二个条件实际上就是$A$无法到达$B$,怎么判断这个呢。

  按最短路正反拓扑排序两次,分别按两种拓扑序做$O(n*m/32)$的传递闭包,然后一个点两种(按拓扑序得到的能到达的点的状态的补集)的交集就是不能到达的点了。

  统计答案的时候找map里$F(T)-F(A)$的状态 & $A$两种(按拓扑序得到的能到达的点的状态的补集)的交集,用bitset::count求出有几个1就好了,记得判一下算重复的情况。

  第一次学会在map里开bitset...

  还有$S$不能到达$T$要输出$n*(n-1)/2$...= =

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<bitset>
#include<queue>
#include<map>
#define ll long long
using namespace std;
const int maxn=;
const ll inf=1e15;
struct tjm{int too, dis, pre;}e[maxn<<];
struct poi{int x; ll dis;};
priority_queue<poi>q;
bool operator<(poi a, poi b){return a.dis>b.dis;}
map<ll,bitset<maxn> >mp;
int n, m, s, t, x, y, z, tot, cnt, top;
int p[maxn], last[maxn], ru[maxn], pos[maxn], st[maxn];
ll ans, dist[][maxn], f[][maxn];
bitset<maxn>g[][maxn];
bool v[maxn];
inline void read(int &k)
{
int f=; k=; char c=getchar();
while(c<'' || c>'') c=='-' && (f=-), c=getchar();
while(c<='' && c>='') k=k*+c-'', c=getchar();
k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(tjm){y, z, last[x]}; last[x]=tot;}
inline void dijkstra(int x, int ty)
{
for(int i=;i<=n;i++) dist[ty][i]=inf;
dist[ty][x]=; f[ty][x]=; q.push((poi){x, });
while(!q.empty())
{
poi now=q.top(); q.pop();
if(now.dis!=dist[ty][now.x]) continue;
for(int i=last[now.x], too;i;i=e[i].pre)
if(dist[ty][too=e[i].too]>dist[ty][now.x]+e[i].dis)
{
f[ty][too]=f[ty][x];
dist[ty][too]=min(inf, dist[ty][now.x]+e[i].dis);
q.push((poi){too, dist[ty][too]});
}
else if(dist[ty][too]==dist[ty][now.x]+e[i].dis) f[ty][too]+=f[ty][x];
}
}
inline bool check(int x, int y, int dis, int ty){return v[y] && dist[ty][x]+dis==dist[ty][y];}
inline void topsort(int ty)
{
memset(ru, , sizeof(ru)); top=;
for(int i=;i<=cnt;i++) for(int j=last[p[i]], too;j;j=e[j].pre)
if(check(p[i], too=e[j].too, e[j].dis, ty)) ru[too]++;
for(int i=;i<=cnt;i++) if(!ru[p[i]]) st[++top]=p[i], pos[p[i]]=top;
for(int i=;i<=top;i++)
for(int j=last[st[i]], too;j;j=e[j].pre)
if(check(st[i], too=e[j].too, e[j].dis, ty))
{
ru[too]--;
if(!ru[too]) st[++top]=too, pos[too]=top;
}
for(int i=;i<=cnt;i++) g[ty][p[i]][p[i]-]=;
for(int i=top;i;i--)
for(int j=last[st[i]], too;j;j=e[j].pre)
if(check(st[i], too=e[j].too, e[j].dis, ty) && pos[st[i]]<pos[too]) g[ty][st[i]]|=g[ty][too];
}
int main()
{
read(n); read(m); read(s); read(t);
for(int i=;i<=m;i++) read(x), read(y), read(z), add(x, y, z), add(y, x, z);
dijkstra(s, ); if(dist[][t]==inf) return printf("%lld\n", 1ll*n*(n-)>>), ;
dijkstra(t, );
for(int i=;i<=n;i++) if(dist[][i]+dist[][i]==dist[][t]) p[++cnt]=i, v[i]=;
for(int i=;i<=cnt;i++) mp[f[][p[i]]*f[][p[i]]]|=<<(p[i]-);
topsort(); topsort();
for(int i=;i<=cnt;i++) ans+=(((mp[f[][t]-f[][p[i]]*f[][p[i]]])>>(i-))&(~g[][p[i]]>>(i-))&(~g[][p[i]]>>(i-))).count();
ll tmp=; for(int i=;i<=cnt;i++) if(f[][p[i]]*f[][p[i]]==f[][t]) tmp++; ans+=tmp*(n-cnt);
printf("%lld\n", ans);
}