CF894E Ralph and Mushrooms(tarjan缩点+拓扑序dp+数学)

时间:2021-01-23 20:53:06

一个强连通内的边显然可以把它的价值完全压榨,其他的边只能过一次,所以我们tarjan求scc,缩成DAG,然后拓扑序dp求最长路。至于怎么算一条边的所有价值,数学搞吧。首先求出 n(n+1)<=w 的最大的n,然后价值就是 nwni=1i(i+1)/2+w ,也就是 nwn(n+1)(n+2)/6+w

#include <bits/stdc++.h>
using namespace std;
#define ll long long
ll const inf=1LL<<60;
#define N 1000010
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return x*f;
}
int n,m,h[N],num=0,dfn[N],low[N],dfnum=0,bel[N],scc=0,h1[N],du[N];
ll sum[N],dp[N],ans=0;
bool inq[N];
struct edge{
int to,next;ll val;
}data[N],data1[N];
stack<int>qq;
void tarjan(int x){
dfn[x]=low[x]=++dfnum;qq.push(x);inq[x]=1;
for(int i=h[x];i;i=data[i].next){
int y=data[i].to;
if(!dfn[y]) tarjan(y),low[x]=min(low[x],low[y]);
else if(inq[y]) low[x]=min(low[x],dfn[y]);
}if(low[x]==dfn[x]){
scc++;while(1){
int y=qq.top();qq.pop();inq[y]=0;bel[y]=scc;
if(x==y) break;
}
}
}
inline ll calc(ll x){
ll k=sqrt(2*x+0.25)-0.5;
return x*k-k*(k+1)*(k+2)/6+x;
}
inline void add(int x,int y,ll val){
data1[++num].to=y;data1[num].next=h1[x];h1[x]=num;data1[num].val=val;
}
int main(){
// freopen("a.in","r",stdin);
n=read();m=read();
while(m--){
int x=read(),y=read(),val=read();
data[++num].to=y;data[num].next=h[x];h[x]=num;data[num].val=val;
}for(int i=1;i<=n;++i) if(!dfn[i]) tarjan(i);int s=read();
for(int x=1;x<=n;++x)
for(int i=h[x];i;i=data[i].next){
int y=data[i].to;if(bel[x]!=bel[y]) continue;
sum[bel[x]]+=calc(data[i].val);
}num=0;
for(int x=1;x<=n;++x)
for(int i=h[x];i;i=data[i].next){
int y=data[i].to;if(bel[x]==bel[y]) continue;
add(bel[x],bel[y],data[i].val+sum[bel[y]]);
}
for(int i=1;i<=scc;++i) dp[i]=-inf;dp[bel[s]]=sum[bel[s]];
for(int x=1;x<=scc;++x)
for(int i=h1[x];i;i=data1[i].next){
int y=data1[i].to;du[y]++;
}queue<int>q;
for(int i=1;i<=scc;++i) if(!du[i]) q.push(i);
while(!q.empty()){
int x=q.front();q.pop();
for(int i=h1[x];i;i=data1[i].next){
int y=data1[i].to;if(--du[y]==0) q.push(y);
if(dp[x]!=-inf) dp[y]=max(dp[y],dp[x]+data1[i].val);
}
}for(int i=1;i<=scc;++i) ans=max(ans,dp[i]);
printf("%lld\n",ans);
return 0;
}