设一个串\(s\)在\(A\)中出现\(cnt[s][1]\)次,在\(B\)中出现\(cnt[s][2]\)次,我们要求的就是:
$$\sum cnt[s][1]*cnt[s][2]$$
在\(SAM\)里,一个节点代表了若干个出现次数相同的串,所以按节点\(s\)求和时,还要乘上这个节点所代表的串的个数,答案就变成了这个
$$\sum cnt[s][1] * cnt[s][2] * (len[s]-len[fa[s]])$$
其中的\(cnt\)求法,听说好像可以两个串隔开求?但是我不太会。学了一下用广义\(SAM\)的写法,似乎是第一个串建完之后把\(las\)指针指回根节点,再建第二个就好。因为网上对于这种写法各种声音都有,所以我打算这周末认真学习\(SA\)和\(SAM\)后再详细进行解释说明或者算法更正。
\(p.s.\) 这种写法下似乎不能按\(len\)桶排序来求\(cnt\):插入第二个串时新建的节点,可能与它在\(Parent\) \(Tree\)上的父亲\(len\)相等,按\(len\)排序就无法保证儿子先于父亲被累加。所以我们要在\(Parent\) \(Tree\)上\(DP\)(\(DFS\))来求。
最后提醒:别忘\(long\) \(long\)
#include <bits/stdc++.h>
using namespace std;
// Capacity of every node-indexed array below.
constexpr int N = 800010;

using ll = long long;

// tot[id][v]: occurrence counter of node v for input string `id` (0 or 1).
// extend() sets it to 1 on each freshly appended node; dfs() later adds every
// child's counters into its parent.
ll tot[2][N];

// Index of the most recently allocated node. Node 1 is the root and index 0
// serves as the "no node" sentinel.
int node = 1;

// Node the next extend() call appends to.
int las = 1;

// fa[v]: parent of v in the parent tree (0 for the root).
int fa[N];

// len[v]: length recorded for v when it was created (len of its source + 1).
int len[N];

// ch[v][c]: node reached from v by letter index c, or 0 when there is none.
int ch[N][26];
// Appends one letter to the string that currently ends at node `las`.
//
//   c  - letter index into the 26-wide transition table
//        (assumes 0 <= c < 26; the value is not checked here)
//   id - which input string the new prefix belongs to (row of tot, 0 or 1)
//
// Always allocates one node for the extended prefix and makes it the new
// `las`; allocates a second (clone) node when an existing transition target
// has to be split.
void extend (int c, int id) {
    int cur = las;
    const int fresh = ++node;
    len[fresh] = len[cur] + 1;
    tot[id][fresh] = 1;
    las = fresh;

    // Walk up the fa-chain; every node still lacking a c-transition gets one
    // that leads to the new node.
    for (; cur != 0 && ch[cur][c] == 0; cur = fa[cur]) {
        ch[cur][c] = fresh;
    }

    // Ran off the top of the chain: the new node hangs directly under the root.
    if (cur == 0) {
        fa[fresh] = 1;
        return;
    }

    const int target = ch[cur][c];

    // The existing target is exactly one letter longer than cur, so it can
    // serve as the parent unchanged.
    if (len[target] == len[cur] + 1) {
        fa[fresh] = target;
        return;
    }

    // Otherwise split the target: the clone takes over its transitions and its
    // parent, and becomes the parent of both the target and the new node.
    // The clone's tot entries are deliberately left untouched.
    const int clone = ++node;
    len[clone] = len[cur] + 1;
    std::memcpy (ch[clone], ch[target], sizeof (ch[target]));
    fa[clone] = fa[target];
    fa[target] = clone;
    fa[fresh] = clone;

    // Redirect the c-transitions along the chain that pointed at the old target.
    for (; cur != 0 && ch[cur][c] == target; cur = fa[cur]) {
        ch[cur][c] = clone;
    }
}
// Linked-list adjacency storage filled by add_edge().

// Number of edges stored so far; edge ids start at 1, so 0 means "no edge".
int cnt;

// head[u]: id of the most recently added edge leaving u, or 0 if u has none.
int head[N];

struct edge {
    int nxt;  // id of the next edge leaving the same source, 0 ends the list
    int to;   // destination of this edge
};

// Edge pool indexed by edge id.
edge e[N];
void add_edge (int from, int to) {
e[++cnt].nxt = head[from];
e[cnt].to = to;
head[from] = cnt;
}
// Folds the tot counters of every node in the subtree rooted at u into their
// ancestors, so that afterwards tot[k][x] holds the sum over x's whole subtree
// (for every x reachable from u through head/e).
//
// Implemented without recursion: the tree walked here can degenerate into a
// chain about as long as the input (e.g. a run of one repeated letter), and
// one call frame per level can exhaust the call stack.
//
// Like the recursive form, this assumes the edges below u form a tree.
void dfs (int u) {
    // Breadth-first order: every node is listed after its parent.
    std::vector<int> order (1, u);
    std::vector<int> parent (1, 0);  // parent[k] is the parent of order[k]
    for (std::size_t k = 0; k < order.size (); ++k) {
        const int x = order[k];  // copy: push_back below may reallocate
        for (int i = head[x]; i; i = e[i].nxt) {
            order.push_back (e[i].to);
            parent.push_back (x);
        }
    }

    // Sweep backwards so each node is complete before it is added to its
    // parent. Index 0 is u itself and has no parent inside this subtree.
    for (std::size_t k = order.size () - 1; k > 0; --k) {
        const int child = order[k];
        const int par = parent[k];
        tot[0][par] += tot[0][child];
        tot[1][par] += tot[1][child];
    }
}
// Builds the fa-tree over all allocated nodes, aggregates the per-string
// counters up that tree, and returns
//   sum over nodes v of (len[v] - len[fa[v]]) * tot[0][v] * tot[1][v].
//
// Mutates the adjacency list and tot, so it is meant to be called once after
// both strings have been inserted.
ll get_ans () {
    // One tree edge fa[v] -> v per node. The root contributes an edge from the
    // sentinel 0, which the traversal starting at node 1 never visits.
    int v = 1;
    while (v <= node) {
        add_edge (fa[v], v);
        ++v;
    }

    dfs (1);

    ll total = 0;
    for (v = 1; v <= node; ++v) {
        const ll width = len[v] - len[fa[v]];
        total += width * tot[0][v] * tot[1][v];
    }
    return total;
}
// Lengths of the two input strings.
int n1, n2;
// Input buffers; each holds at most N - 1 characters plus the terminator.
char s1[N], s2[N];

// Reads two whitespace-separated strings from standard input, inserts both
// into the automaton (restarting from the root for the second one) and prints
// the result of get_ans().
//
// NOTE(review): the letters are mapped with `- 'a'` into a 26-wide table, so
// the input is presumably lowercase a-z only; nothing here validates that.
int main () {
    // An unbounded "%s" would write past the buffers on an over-long token.
    // setw(N) caps each extraction at N - 1 characters; the width resets after
    // every extraction, so it is set again for the second string.
    std::cin >> std::setw (N) >> s1 >> std::setw (N) >> s2;
    n1 = static_cast<int> (std::strlen (s1));
    n2 = static_cast<int> (std::strlen (s2));

    for (int i = 0; i < n1; ++i) extend (s1[i] - 'a', 0);
    las = 1;  // start the second string from the root again
    for (int i = 0; i < n2; ++i) extend (s2[i] - 'a', 1);

    std::cout << get_ans () << '\n';
}