描述:
给出一个单词,在单词中插入若干字符使其为回文串,求回文串的个数(|s|<=200,n<=10^9)
这道题超神奇,不可多得的一道好题
首先可以搞出一个dp[l][r][i]表示回文串左边i位匹配到第l位,右边i位匹配到第r位的状态数,可以发现可以用矩阵乘法优化(某人说看到n这么大就一定是矩阵乘法了= =)
但这样一共有|s|^2个节点,时间复杂度无法承受
我们先把状态树画出来:例如add
可以发现是个DAG
我们考虑把单独的每条链拿出来求解,那么最多会有|s|条不同的链,链长最多为|s|,时间复杂度为O(|s|^4log n)还是得跪
好像没什么思路了对吧= =(我第一步转化就没想到了= =)
我们考虑记有24个自环的为n24,25个自环的为n25,可以发现n24+n25*2=|s|或|s|+1也就是说对于一个确定的n24,一定有一个确定的n25
那么这样构图:
可以发现所有状况都被包括进来了!!!
那么一共有2|s|个节点,时间复杂度降了一个|s|,看上去好像还是不行
压常数= =
可以发现这个是棵树,也就是说如果按拓扑序编号的话,到时的矩阵左下角将是什么都没有的
那么就直接for i = 1 to n j = i to n k=i to j 就行了 = =
总结下吧
这道题为何神奇呢
首先它把一个DAG的图拆成了若干条相似的链
然后它又把这些链和成了一个更和谐的图
最后再观察题目性质得到一个比较神奇的优化方法
这给了我们什么启迪呢= =
首先遇到某些DAG我们可以考虑拆成若干条相似的链
遇到某些链我们可以考虑把他们合成一个图
最重要的是,还是得参透题目的性质
这道题基本都是依靠题目的性质到达下一步的,只有真正读懂读透这道题,我们才能想出更好的解法
CODE:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
#define maxn 410
#define mod 10007
typedef int ll;
struct marix{
int r,c;ll a[maxn][maxn];
inline void init(int x){r=c=x;for (int i=;i<=x;i++) a[i][i]=;}
}x,y;
inline void muti(marix &ans,const marix x,const marix y){
ans.r=ans.c=x.r;
for (int i=;i<=x.r;i++)
for (int j=i;j<=y.c;j++) {
int tmp=;
for (int k=i;k<=j;k++)
(tmp+=x.a[i][k]*y.a[k][j])%=mod;
ans.a[i][j]=tmp;
}
}
inline void power(marix &ans,marix x,int y) {
ans.init(x.r);
for (;y;y>>=) {
if (y&) muti(ans,ans,x);
muti(x,x,x);
}
}
ll f[][][];
char s[maxn];
inline ll calc(int l,int r,int x) {
ll &u=f[x][l][r];
if (u!=-) return u;
u=;
if (l==r) return u=x==;
if (s[l]==s[r]) {
if (l+==r) return u=x==;
return u=calc(l+,r-,x);
}
if (x>) return u=(calc(l+,r,x-)+calc(l,r-,x-))%mod;
return u;
}
int main(){
int n,m;
memset(f,-,sizeof(f));
scanf("%s",s+);
scanf("%d",&n);
m=strlen(s+);
n+=m;
int l=(n+)/,n24=m-,n25=(m+)/,n26=n25;
x.r=x.c=n24+n25+n26;
for (int i=;i<=n24;i++) x.a[i][i]=,x.a[i][i+]=;
for (int i=n24+;i<=n25+n24;i++) x.a[i][i]=,x.a[i][i+n25]=;
for (int i=n24+n25+;i<=n25+n24+n26;i++) x.a[i][i]=;
for (int i=n24+;i<n25+n24;i++) x.a[i][i+]=;
marix y;
power(y,x,l-);
muti(x,y,x);
ll ans;
for (int i=;i<=n24;i++) {
int j=(m-i+)/,k=l-i-j;
if (k<) continue;
ll sum=calc(,m,i);
(ans+=sum*x.a[n24-i+][n24+j+n25]%mod)%=mod;
if ((n&)&&(m-i&^))
(ans=ans-sum*y.a[n24-i+][n24+j]%mod+mod)%=mod;
}
printf("%d\n",ans);
return ;
}