Cyclical Quest CodeForces - 235C (后缀自动机)

时间:2024-08-15 10:06:56

Cyclical Quest

\[Time Limit: 3000 ms\quad Memory Limit: 524288 kB
\]

题意

给出一个字符串为 \(s\) 串,接下来 \(T\) 个查询,枚举给出一个 \(t\) 串,求出 \(t\) 的所有循环子串在 \(s\) 串中出现的次数。

思路

对于每个查询串,因为要所有循环的子串,所以可以先把 \(t\) 串在复制一份到末尾,然后去求 \(LCS\)。

  • 如果在 \(p\) 节点查询到的 \(LCS==tlen\),那么 \(p\) 表示的子串就包含一种满足条件的循环 \(t\) 串。
  • 如果 \(LCS>tlen\),那就说明 \(p\) 包含了一种循环 \(t\) 串,但是匹配长度超过了循环 \(t\) 串,那么这个串的贡献一定在其 \(father\) 上,所以我们令一个临时变量 \(tmp\) 往 \(father\) 跳去寻找这个串的贡献。
  • 最后只要计算出每个节点包含的子串的出现次数,然后把这些满足条件的值加起来就可以了。

这里我们求 \(LCS\) 时直接用中间变量 \(res\)而不用往\(father\) 更新的。比如我现在在节点 \(u\),且 \(father[p]=u\),那么 \(p\) 在 \(u\) 的 \(LCS\) 就算更大,\(u\) 往其 \(father\) 的更新过程其实和 \(p\) 往 \(father\) 的更新过程是一样的,最后会停留在同一个节点,而一个节点的贡献只要算一次,所以 \(u\) 其实可以不往上更新。

还有这题的查询有 \(1e5\),而总的查询长度才 \(1e6\),所以可能出现查询多而每个串短的情况,所以对 \(vis\) 每次去 \(memset\) 是会 \(TLE\) 的...怀疑人生

#include <map>
#include <set>
#include <list>
#include <ctime>
#include <cmath>
#include <stack>
#include <queue>
#include <cfloat>
#include <string>
#include <vector>
#include <cstdio>
#include <bitset>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define lowbit(x) x & (-x)
#define mes(a, b) memset(a, b, sizeof a)
#define fi first
#define se second
#define pii pair<int, int>
#define INOPEN freopen("in.txt", "r", stdin)
#define OUTOPEN freopen("out.txt", "w", stdout) typedef unsigned long long int ull;
typedef long long int ll;
const int maxn = 1e6 + 10;
const int maxm = 1e5 + 10;
const ll mod = 1e9 + 7;
const ll INF = 1e18 + 100;
const int inf = 0x3f3f3f3f;
const double pi = acos(-1.0);
const double eps = 1e-8;
using namespace std; int n, m;
int cas, tol, T; struct Sam {
int node[maxn<<1][27], step[maxn<<1], fa[maxn<<1];
int dp[maxn<<1], tax[maxn<<1], gid[maxn<<1];
int vis[maxn<<1];
int last, sz;
int newnode() {
mes(node[++sz], 0);
dp[sz] = step[sz] = fa[sz] = 0;
return sz;
}
void init() {
sz = 0;
last = newnode();
}
void insert(int k) {
int p = last, np = last = newnode();
dp[np] = 1;
step[np] = step[p]+1;
for(; p&&!node[p][k]; p=fa[p])
node[p][k] = np;
if(p==0) {
fa[np] = 1;
} else {
int q = node[p][k];
if(step[q] == step[p]+1) {
fa[np] = q;
} else {
int nq = newnode();
memcpy(node[nq], node[q], sizeof(node[q]));
fa[nq] = fa[q];
step[nq] = step[p]+1;
fa[np] = fa[q] = nq;
for(; p&&node[p][k]==q; p=fa[p])
node[p][k] = nq;
}
}
}
void handle() {
for(int i=0; i<=sz; i++) tax[i] = 0;
for(int i=1; i<=sz; i++) tax[step[i]]++;
for(int i=1; i<=sz; i++) tax[i] += tax[i-1];
for(int i=1; i<=sz; i++) gid[tax[step[i]]--] = i;
for(int i=sz; i>=1; i--) {
int u = gid[i];
dp[fa[u]] += dp[u];
}
}
int solve(char *s, int len, int id) {
int p = 1, res = 0;
int ans = 0;
for(int i=1; i<=len+len; i++) {
int k = s[i]-'a'+1;
while(p && !node[p][k]) {
p = fa[p];
res = step[p];
}
if(p == 0) {
p = 1;
res = 0;
} else {
p = node[p][k];
res++;
if(res >= len) {
int tmp = p;
while(vis[tmp]!=id &&!(step[fa[tmp]]+1<=len && len<=step[tmp])) {
vis[tmp] = id;
tmp = fa[tmp];
}
if(vis[tmp] != id) {
vis[tmp] = id;
ans += dp[tmp];
}
}
}
}
return ans;
}
} sam;
char s[maxn], t[maxn<<1]; int main() {
scanf("%s", s+1);
int slen = strlen(s+1);
sam.init();
for(int i=1; i<=slen; i++) {
sam.insert(s[i]-'a'+1);
}
sam.handle();
scanf("%d", &T);
for(int tt=1; tt<=T; tt++) {
scanf("%s", t+1);
int tlen = strlen(t+1);
for(int i=1; i<=tlen; i++) {
t[i+tlen] = t[i];
}
t[tlen+tlen+1] = '\0';
int ans = sam.solve(t, tlen, tt);
printf("%d\n", ans);
}
return 0;
}