2016vijos 1-1 兔子的字符串(后缀数组 + 二分 + 哈希)

时间:2023-08-28 09:18:20

题意:

给出一个字符串,至多将其划分为n部分,每一部分取出字典序最大的子串ci,最小化 最大的ci

先看一个简化版的问题:

给一个串s,再给一个s的子串t,问能否通过将串划分为k个部分,使t成为划分后的s的字典序最大子串

对于这个问题,从串s的最后面开始,一个字符一个字符的向前推
如果当前[l,r]字典序比t大,那么[l+1,r]就要单独成为一段
比较子串字典序大小用二分+哈希
因为我们是一个字符一个字符的向前推的,所以一定是新的l使当前[l,r]字典序比t大
所以如果此时l==r,那么这个t不可能成为字典序最大子串
如果最后部分数<=k,则这个t可以
那么本题只需要二分子串t就好了
所以现在问题变成如何获取字典序 排名第k的子串
这个可以通过后缀数组的height求出
[sa[1],sa[1]] 是字典序第1小
[sa[1],sa[1]+1]是字典序第2小
……
[sa[1],n]是字典序第n-sa[1]+1 小
[sa[2],sa[2]+height[2]] 是下一个
再下一个是 [sa[2],sa[2]+height[2]+1]
…… 
#include<cstdio>
#include<cstring>
#include<algorithm> using namespace std; #define N 100001 typedef long long LL; const int base=; int n,m;
char s[N]; int a[N];
int v[N];
int p,q=,k;
int sa[][N],rk[][N];
int h[N]; unsigned long long Pow[N],has[N]; pair<int,int>interval[N]; void mul(int *sa,int*rk,int *SA,int *RK)
{
for(int i=;i<=n;++i) v[rk[sa[i]]]=i;
for(int i=n;i;--i) if(sa[i]>k) SA[v[rk[sa[i]-k]]--]=sa[i]-k;
for(int i=n-k+;i<=n;++i) SA[v[rk[i]]--]=i;
for(int i=;i<=n;++i) RK[SA[i]]=RK[SA[i-]]+(rk[SA[i]]!=rk[SA[i-]] || rk[SA[i]+k]!=rk[SA[i-]+k]);
} void presa()
{
for(int i=;i<=n;++i) v[a[i]]++;
for(int i=;i<=;++i) v[i]+=v[i-];
for(int i=;i<=n;++i) sa[p][v[a[i]]--]=i;
for(int i=;i<=n;++i) rk[p][sa[p][i]]=rk[p][sa[p][i-]]+(a[sa[p][i]]!=a[sa[p][i-]]);
for(k=;k<n;k<<=,swap(p,q)) mul(sa[p],rk[p],sa[q],rk[q]);
} void get_height()
{
int j,k=;
for(int i=;i<=n;++i)
{
j=sa[p][rk[p][i]-];
while(a[i+k]==a[j+k]) k++;
h[rk[p][i]]=k;
if(k) k--;
}
} void prehash()
{
Pow[]=;
for(int i=;i<=n;++i) Pow[i]=Pow[i-]*base;
for(int i=;i<=n;++i) has[i]=has[i-]*base+a[i];
} pair<int,int>select(LL k)
{
int now;
LL sum=;
int l,r;
for(int i=;i<=n;++i)
{
now=interval[i].second-interval[i].first+;
if(sum+now>=k)
{
l=sa[p][i];
r=interval[i].first+k-sum-;
return make_pair(l,r);
}
sum+=now;
}
} unsigned long long get_hash(int l,int r)
{
return has[r]-has[l-]*Pow[r-l+];
} int cmp(pair<int,int>x,pair<int,int>y)
{
if(get_hash(x.first,x.second)==get_hash(y.first,y.second)) return ;
int Lx=x.second-x.first+,Ly=y.second-y.first+;
int l=,r=min(Lx,Ly),mid,tmp=;
while(l<=r)
{
mid=l+r>>;
if(get_hash(x.first,x.first+mid-)==get_hash(y.first,y.first+mid-)) tmp=mid,l=mid+;
else r=mid-;
}
if(tmp<min(Lx,Ly)) return s[x.first+tmp]<s[y.first+tmp] ? - : ;
return Lx<Ly ? - : ;
} bool check(pair<int,int>now)
{
int l=n,r=n,sum=;
while(l>=)
if(cmp(make_pair(l,r),now)==)
{
if(l==r) return false;
r=l;
sum++;
if(sum>m) return false;
}
else l--;
return true;
} void solve()
{
LL l=,r=;
for(int i=;i<=n;++i)
{
interval[i].first=sa[p][i]+h[i];
interval[i].second=n;
r+=interval[i].second-interval[i].first+;
}
LL mid,tmp;
pair<int,int>now;
while(l<=r)
{
mid=l+r>>;
now=select(mid);
if(check(now)) tmp=mid,r=mid-;
else l=mid+;
}
now=select(tmp);
l=now.first; r=now.second;
for(int i=l;i<=r;++i) putchar(s[i]);
} int main()
{
freopen("string.in","r",stdin);
freopen("string.out","w",stdout);
scanf("%d",&m);
scanf("%s",s+);
n=strlen(s+);
for(int i=;i<=n;++i) a[i]=s[i]-'a'+;
presa();
get_height();
prehash();
solve();
}