2019牛客多校第四场 I题 后缀自动机_后缀数组_求两个串de公共子串的种类数

时间:2021-10-15 18:45:49

@

求若干个串的公共子串个数相关变形题

  • 牛客这题题意大概是求一个长度为\(2e5\)的字符串有多少个不同子串,若\(s==t\)或\(s==rev(t)\)则认为子串\(s,t\)相同。我们知道回文串肯定和他的反串相同。
  • 链接:传送门

做法1:

  • \(yx\)大佬秒出思路%%,对\(s\)串建后缀自动机,可以得到串\(s\)本质不同的子串的个数\(all\),然后只要能减去有多少个串\(x\)的\(rev(x)\)同时也出现了即可。
  • 考虑先求出\(s\)和\(rev(s)\)的本质不同的公共子串的数量\(res\),串\(s\)本质不同的回文串数量为\(q\),显然\(res-q\)肯定是\(2\)的倍数。求回文串数量是个板子题:here
  • 因为\(s\)和\(rev(s)\)本质不同的公共子串除了回文串,就只有非回文串且\(x=rev(x)\)的串了。又因为\(x\)和\(rev(x)\)只能算一次贡献,所以最后答案就是\(all-\frac {res-q} 2\)。
  • 所以我们现在只要能求出串\(t\)与串\(s\)的公共子串种类数量即可。(还有一种题是求长度至少为k的公共子串数量

做法2:

广义后缀自动机直接求即可。

用普通后缀自动机也有更简单做法,我在第一个做法下面有讲解。

做法3:

后缀数组


对一个串建后缀自动机,另一个串在上面跑同时计数

  • 构建好\(s\)串的后缀自动机后,从根节点开始用\(t\)串在上面匹配,记录一下已经匹配的\(lcs\)长度\(LEN\)。若\(u\)节点有\(t[i]\)这个后继,则\(u\)跳到\(nex[u][t[i]-'a'],LEN++\);如果没有这个后继,就从\(u\)开始沿着后缀连接树向上走直到碰到一个节点\(x\)有\(t[i]\)这个后继或者到了根节点\(x\),则\(u = nex[x][t[i]-'a'],LEN=len[x]+1\)。
  • 算贡献就是我当前在\(u\)节点,\(lcs\)长度为\(len\),那么\(LEN-len[link[u]]\)就是符合条件的子串。但是这不完全,就是如果\(len[link[u]]\)也大于\(0\)的话,那么他的父亲状态\(link[u]\)是有符合条件的子串,而且符合条件的子串的数量是固定的:\(len[u]-len[link[u]]\)。
  • 听说如果你每次走后缀连接树算完所有贡献的话会\(tle\),一个优化就是匹配结束后,逆拓扑排序更新父亲结点的出现次数。像线段树一样用一个\(lazy\)标记记录它是否需要更新,要记得把\(lazy\)标记向父亲上传。
  • 但是这样不够,因为还有一部分贡献没有计算,你可能多次匹配到自动机上的一个节点,我们需要记录一下匹配到每个节点的最长\(lcs\)长度即\(vis[u]\),若\(vis[u]\)等于\(0\),则贡献如上,反之贡献为\(LEN-vis[u]\),最后更新\(vis[u]\)为\(LEN\)。
  • 本题结束。

其实还有一个更简单的方法,把串\(s\)和串\(rev(s)\)用一个没有出现过得字符拼接起来,求出新字符串的本质不同的子串个数\(x\),我们知道包含那个未出现过字符的子串数量为\(y = (Len+1)\times (Len+1)\),(注意串\((ba)\)和串\((ab)\)只能计一个贡献)然后在求出\(s\)本质不同的回文串个数\(p\),答案就是\(\frac{x-y+p}2\)

#pragma comment(linker, "/STACK:102400000,102400000")
#include<bits/stdc++.h>
#define fi first
#define se second
#define endl '\n'
#define o2(x) (x)*(x)
#define BASE_MAX 30
#define mk make_pair
#define eb emplace_back
#define all(x) (x).begin(), (x).end()
#define clr(a, b) memset((a),(b),sizeof((a)))
#define iis std::ios::sync_with_stdio(false); cin.tie(0)
#define my_unique(x) sort(all(x)),x.erase(unique(all(x)),x.end())
using namespace std;
#pragma optimize("-O3")
typedef long long LL;
typedef pair<int, int> pii;
inline LL read() {
LL x = 0;int f = 0;
char ch = getchar();
while (ch < '0' || ch > '9') f |= (ch == '-'), ch = getchar();
while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x = f ? -x : x;
}
inline void write(LL x) {
if (x == 0) {putchar('0'), putchar('\n');return;}
if (x < 0) {putchar('-');x = -x;}
static char s[23];
int l = 0;
while (x != 0)s[l++] = x % 10 + 48, x /= 10;
while (l)putchar(s[--l]);
putchar('\n');
}
int lowbit(int x) { return x & (-x); }
template<class T>T big(const T &a1, const T &a2) { return a1 > a2 ? a1 : a2; }
template<typename T, typename ...R>T big(const T &f, const R &...r) { return big(f, big(r...)); }
template<class T>T sml(const T &a1, const T &a2) { return a1 < a2 ? a1 : a2; }
template<typename T, typename ...R>T sml(const T &f, const R &...r) { return sml(f, sml(r...)); }
void debug_out() { cerr << '\n'; }
template<typename T, typename ...R>void debug_out(const T &f, const R &...r) {cerr << f << " ";debug_out(r...);}
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]: ", debug_out(__VA_ARGS__); #define print(x) write(x); const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
const int HMOD[] = {1000000009, 1004535809};
const LL BASE[] = {1572872831, 1971536491};
const int mod = 998244353;
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int MXN = 1e6 + 7; int n;
char s[MXN], t[MXN];
LL all, ANS;
int vis[MXN], lazy[MXN];
struct Palindromic_Tree {
static const int MAXN = 600005 ;
static const int CHAR_N = 26 ;
int next[MAXN][CHAR_N];//next指针,next指针和字典树类似,指向的串为当前串两端加上同一个字符构成
int fail[MAXN];//fail指针,失配后跳转到fail指针指向的节点
int cnt[MAXN];
int num[MAXN];
int len[MAXN];//len[i]表示节点i表示的回文串的长度
int S[MAXN];//存放添加的字符
int last;//指向上一个字符所在的节点,方便下一次add
int n;//字符数组指针
int p;//节点指针
int pos[MAXN];
int newnode(int l) {//新建节点
for (int i = 0; i < CHAR_N; ++i) next[p][i] = 0;
cnt[p] = 0;
num[p] = 0;
len[p] = l;
return p++;
}
void init() {//初始化
p = 0;
newnode(0);
newnode(-1);
last = 0;
n = 0;
S[n] = -1;//开头放一个字符集中没有的字符,减少特判
fail[0] = 1;
}
int get_fail(int x) {//和KMP一样,失配后找一个尽量最长的
while (S[n - len[x] - 1] != S[n]) x = fail[x];
return x;
}
void add(int c, int id) {
c -= 'a';
S[++n] = c;
int cur = get_fail(last);//通过上一个回文串找这个回文串的匹配位置
if (!next[cur][c]) {//如果这个回文串没有出现过,说明出现了一个新的本质不同的回文串
int now = newnode(len[cur] + 2);//新建节点
fail[now] = next[get_fail(fail[cur])][c];//和AC自动机一样建立fail指针,以便失配后跳转
next[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = next[cur][c];
cnt[last] ++;
pos[last] = id;
}
void count() {
for (int i = p - 1; i >= 0; --i) cnt[fail[i]] += cnt[i];
//父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串!
}
} pt;
struct Suffix_Automaton {
static const int maxn = 1e6 + 105;
static const int MAXN = 1e6 + 5;
//basic
// map<char,int> nex[maxn * 2];
int nex[maxn*2][26];
int link[maxn * 2], len[maxn * 2];
int last, cnt;
LL tot_c;//不同串的个数
//extension
int cntA[MAXN * 2], A[MAXN * 2];/*辅助拓扑更新*/
int nums[MAXN * 2];/*每个节点代表的所有串的出现次数*/
void clear() {
tot_c = 0;
last = cnt = 1;
link[1] = len[1] = 0;
memset(nex[1], 0, sizeof(nex[1]));
}
void init_str(char *s) {
while (*s) {
add(*s - 'a');
++ s;
}
}
void add(int c) {
int p = last;
int np = ++cnt;
// nex[cnt].clear();
memset(nex[cnt], 0, sizeof(nex[cnt]));
len[np] = len[p] + 1;
last = np;
while (p && !nex[p][c])nex[p][c] = np, p = link[p];
if (!p)link[np] = 1, tot_c += len[np] - len[link[np]];
else {
int q = nex[p][c];
if (len[q] == len[p] + 1)link[np] = q, tot_c += len[np] - len[link[np]];
else {
int nq = ++cnt;
len[nq] = len[p] + 1;
// nex[nq] = nex[q];
memcpy(nex[nq], nex[q], sizeof(nex[q]));
link[nq] = link[q];
link[np] = link[q] = nq;
tot_c += len[np] - len[link[np]];
while (nex[p][c] == q)nex[p][c] = nq, p = link[p];
}
}
}
void build(int n) {
memset(cntA, 0, sizeof cntA);
memset(nums, 0, sizeof nums);
for (int i = 1; i <= cnt; i++)cntA[len[i]]++;
for (int i = 1; i <= n; i++)cntA[i] += cntA[i - 1];
for (int i = cnt; i >= 1; i--)A[cntA[len[i]]--] = i;
/*更行主串节点*/
int temps = 1;
for (int i = 0; i < n; i++) {
nums[temps = nex[temps][s[i] - 'a']] = 1;
}
for (int i = cnt, x; i >= 1; i--) {
x = A[i];
nums[link[x]] += nums[x];
}
}
void query() {
int u = 1, LEN = 0;
for(int i = 0; i < n; ++i) {
if(nex[u][t[i]-'a']) {
u = nex[u][t[i]-'a'];
++ LEN;
}else {
while (u && nex[u][t[i] - 'a'] == 0) u = link[u];
if (u == 0) u = 1, LEN = 0;
else {
LEN = len[u] + 1;
u = nex[u][t[i] - 'a'];
}
}
if(vis[u] == 0) {
ANS += 1 * (LEN - len[link[u]]);
// debug(i, t[i], LEN - len[link[u]])
if (len[link[u]]) lazy[link[u]] = 1;
vis[u] = LEN;
}else if(LEN > vis[u]) {
ANS += 1 * (LEN - vis[u]);
// debug(i, t[i], LEN - vis[u])
vis[u] = LEN;
}
}
for(int i = cnt, x; i >= 1; --i) {
x = A[i];
if(vis[x] == 0 && len[x] && lazy[x]) {
ANS += len[x] - len[link[x]];
vis[x] = len[x];
if(len[link[x]]) lazy[link[x]] = 1;
}else if(lazy[x] && vis[x] < len[x]) {
ANS += len[x] - vis[x];
vis[x] = len[x];
if(len[link[x]]) lazy[link[x]] = 1;
}
if(len[link[x]]) lazy[link[x]] = 1;
}
}
void DEBUG() {
for (int i = cnt; i >= 1; i--) {
printf("nums[%d]=%d numt[%d]=%d len[%d]=%d link[%d]=%d\n", i, nums[i], i, nums[i], i, len[i], i, link[i]);
}
}
} sam; int main() {
#ifndef ONLINE_JUDGE
freopen("/home/cwolf9/CLionProjects/ccc/in.txt", "r", stdin);
//freopen("/home/cwolf9/CLionProjects/ccc/out.txt", "w", stdout);
#endif
// int tim = read();
scanf("%s", s);
memcpy(t, s, sizeof(s));
n = strlen(s);
reverse(t, t + n);
sam.clear();
sam.init_str(s);
all = sam.tot_c;
sam.build(n);
sam.query();
pt.init();
for(int i = 0; i < n; ++i) pt.add(s[i], i);
int hui = pt.p - 2;
debug(n, hui, all, ANS)
printf("%lld\n", all - (ANS - hui) / 2);
#ifndef ONLINE_JUDGE
cout << "time cost:" << clock() << "ms" << endl;
#endif
return 0;
}

广义后缀自动机

  • 直接离线构建广义后缀自动机(插入函数和普通后缀自动机一模一样),先插入\(s\)串,置\(last=1\),再插入\(rev(s)\),然后对这个后缀自动机求出本质不同的子串个数\(all\)(回文串只计算一次贡献,其他串计算了两次,因为\(x=rev(x)\)),设\(p\)表示\(s\)串本质不同的回文串个数,最后答案即为\(\frac{all+p}2\)

后缀数组


其他:POJ 3415 求两个串长度至少为k的公共子串数量

本题不需要去重。可后缀数组也可后缀自动机写。

后缀自动机

解法和牛客那题基本一样,甚至更简单,因为本题不需要去重,是算总数。

不需要记录每个节点被匹配到的\(lcs\)长度,因此当前节点每次被匹配到的贡献都是\(LEN-max(len[link[u]],k-1)\)。

因为是算所有子串的数量,只需要用\(lazy[]\)标记表示这个节点被匹配到的次数,最后逆拓扑序向上传\(lazy[]\)标记即可。

后缀数组

按套路,把\(s,t\)拼成一个串,两遍单调栈,分别算\(t\)串对\(s\)串的贡献和\(s\)串对\(t\)串的贡献

#pragma comment(linker, "/STACK:102400000,102400000")
//#include<bits/stdc++.h>
#include<cstdio>
#include<cstring>
#include<string>
#include<vector>
#include<stack>
#include<map>
#include<iostream>
#include<assert.h>
#define fi first
#define se second
#define endl '\n'
#define o2(x) (x)*(x)
#define BASE_MAX 30
#define mk make_pair
#define eb emplace_back
#define all(x) (x).begin(), (x).end()
#define clr(a, b) memset((a),(b),sizeof((a)))
#define iis std::ios::sync_with_stdio(false); cin.tie(0)
#define my_unique(x) sort(all(x)),x.erase(unique(all(x)),x.end())
using namespace std;
#pragma optimize("-O3")
typedef long long LL;
typedef pair<int, int> pii;
inline LL read() {
LL x = 0;int f = 0;
char ch = getchar();
while (ch < '0' || ch > '9') f |= (ch == '-'), ch = getchar();
while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x = f ? -x : x;
}
inline void write(LL x) {
if (x == 0) {putchar('0'), putchar('\n');return;}
if (x < 0) {putchar('-');x = -x;}
static char s[23];
int l = 0;
while (x != 0)s[l++] = x % 10 + 48, x /= 10;
while (l)putchar(s[--l]);
putchar('\n');
}
int lowbit(int x) { return x & (-x); }
template<class T>T big(const T &a1, const T &a2) { return a1 > a2 ? a1 : a2; }
//template<typename T, typename ...R>T big(const T &f, const R &...r) { return big(f, big(r...)); }
//template<class T>T sml(const T &a1, const T &a2) { return a1 < a2 ? a1 : a2; }
//template<typename T, typename ...R>T sml(const T &f, const R &...r) { return sml(f, sml(r...)); }
//void debug_out() { cerr << '\n'; }
//template<typename T, typename ...R>void debug_out(const T &f, const R &...r) {cerr << f << " ";debug_out(r...);}
//#define debug(...) cerr << "[" << #__VA_ARGS__ << "]: ", debug_out(__VA_ARGS__); #define print(x) write(x); const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
const int HMOD[] = {1000000009, 1004535809};
const LL BASE[] = {1572872831, 1971536491};
const int mod = 998244353;
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int MXN = 2e5 + 7; int n, m, k;
LL ANS;
char s[MXN], t[MXN];
LL lazy[MXN];
struct Suffix_Automaton {
static const int maxn = 2e5 + 105;
static const int MAXN = 2e5 + 5;
//basic
// map<char,int> nex[maxn * 2];
int nex[maxn][58];
int link[maxn * 2], len[maxn * 2];
int last, cnt;
LL tot_c;//不同串的个数
//extension
int cntA[MAXN * 2], A[MAXN * 2];/*辅助拓扑更新*/
int nums[MAXN * 2];/*每个节点代表的所有串的出现次数*/
void clear() {
tot_c = 0;
last = cnt = 1;
link[1] = len[1] = 0;
// nex[1].clear();
memset(nex[1], 0, sizeof(nex[1]));
}
void init_str(char *s) {
while (*s) {
add(*s - 'A');
++ s;
}
}
void add(int c) {
int p = last;
int np = ++cnt;
// nex[cnt].clear();
memset(nex[cnt], 0, sizeof(nex[cnt]));
len[np] = len[p] + 1;
last = np;
while (p && !nex[p][c])nex[p][c] = np, p = link[p];
if (!p)link[np] = 1, tot_c += len[np] - len[link[np]];
else {
int q = nex[p][c];
if (len[q] == len[p] + 1)link[np] = q, tot_c += len[np] - len[link[np]];
else {
int nq = ++cnt;
len[nq] = len[p] + 1;
// nex[nq] = nex[q];
memcpy(nex[nq], nex[q], sizeof(nex[q]));
link[nq] = link[q];
link[np] = link[q] = nq;
tot_c += len[np] - len[link[np]];
while (nex[p][c] == q)nex[p][c] = nq, p = link[p];
}
}
}
void build(int n) {
for(int i = 0; i <= cnt; ++i) nums[i] = cntA[i] = 0;
for (int i = 1; i <= cnt; i++) cntA[len[i]]++;
for (int i = 1; i <= n; i++)cntA[i] += cntA[i - 1];
for (int i = cnt; i >= 1; i--)A[cntA[len[i]]--] = i;
/*更行主串节点*/
int temps = 1;
for (int i = 0; i < n; i++) {
nums[temps = nex[temps][s[i] - 'A']] = 1;
}
for (int i = cnt, x; i >= 1; i--) {
x = A[i];
nums[link[x]] += nums[x];
}
}
void query() {
int u = 1, LEN = 0;
for(int i = 0; i < m; ++i) {
if(nex[u][t[i]-'A']) {
u = nex[u][t[i]-'A'];
++ LEN;
}else {
while (u && nex[u][t[i] - 'A'] == 0) u = link[u];
if (u == 0) u = 1, LEN = 0;
else {
LEN = len[u] + 1;
u = nex[u][t[i] - 'A'];
}
}
if(LEN >= k) {
ANS += (LL)nums[u] * (LEN - big(len[link[u]], k - 1));
if (len[link[u]]) lazy[link[u]] ++;
}
}
for(int i = cnt, x; i >= 1; --i) {
x = A[i];
if(len[x] >= k && lazy[x]) {
ANS += lazy[x] * nums[x] * (len[x] - big(len[link[x]], k - 1));
if(len[link[x]]) lazy[link[x]] += lazy[x];
}
}
}
void DEBUG() {
for (int i = cnt; i >= 1; i--) {
printf("nums[%d]=%d numt[%d]=%d len[%d]=%d link[%d]=%d\n", i, nums[i], i, nums[i], i, len[i], i, link[i]);
}
}
} sam; int main() {
#ifndef ONLINE_JUDGE
freopen("/home/cwolf9/CLionProjects/ccc/in.txt", "r", stdin);
//freopen("/home/cwolf9/CLionProjects/ccc/out.txt", "w", stdout);
#endif
while(~scanf("%d", &k) && k) {
scanf("%s%s", s, t);
n = strlen(s), m = strlen(t);
sam.clear();
sam.init_str(s);
sam.build(n);
ANS = 0;
sam.query();
for(int i = 0; i <= 2 * n + 5; ++i) lazy[i] = 0;
printf("%lld\n", ANS);
}
#ifndef ONLINE_JUDGE
cout << "time cost:" << clock() << "ms" << endl;
#endif
return 0;
}