题意
有一个打字机,支持三种操作:
- 字符串末尾加一个小写字母
- 字符串末尾减一个字符
- 输出这个字符串
经过不超过\(n\)次操作后有\(m\)组询问:\((x,y)\),表示第\(x\)次输出第字符串在第\(y\)次输出第字符串里出现几次
\(n,m \leq 10^5\)
题解
每次加减字符就在trie树上走,输出的话记录一下在哪个结点
然后考虑询问\((x,y)\)暴力怎么做:\(x\)应该是\(y\)一个前缀的后缀,于是我们对于从根到\(y\)路径上每个结点(这相当于枚举\(y\)的后缀),从这个结点跳\(fail\),如果跳到\(x\)就\(ans++\),然后考虑下一个结点
实际上我们要求的就是根到\(y\)这条链上的结点中,在\(fail\)树中是\(x\)儿子的个数
我们可以按\(\text{trie}\)树的\(\text{dfs}\)序枚举\(y\),这样枚举所有的链和信息是\(O(n)\)的,每个点只会被加入一次和删除一次。然后考虑回答所有\((i,y)\)的询问,直接询问当前在\(x\)的\(fail\)树子树的结点个数。可以使用树状数组维护。具体说就是把询问按\(y\)在\(trie\)上的\(\text{dfs}\)序排序,然后每个点必须插入到它\(fail\)树\(\text{dfs}\)序的位置,查询就找到\(x\)的\(fail\)子树的\(\text{dfs}\)区间进行查询。
实现的话注意\(trie\)和\(fail\)不要搞混了,另外这题可以用主席树在线做
#include <algorithm>
#include <cstdio>
#include <vector>
using namespace std;
const int N = 2e5 + 10;
int ch[N][26], fa[N], fail[N];
int dfn[N], dl[N], dr[N], dn[N];
int pos = 1, id = 1, n, pt[N], ans[N];
vector<int> fs[N];
struct qs {
int x, y, id;
bool operator < (const qs &b) const {
return dfn[y] < dfn[b.y];
}
} q[N];
void work(char c) {
if(c == 'B') pos = fa[pos];
else if(c == 'P') pt[++ pt[0]] = pos;
else {
int &v = ch[pos][c - 'a'];
if(!v) {
v = ++ id;
fa[v] = pos;
}
pos = v;
}
}
void dfs(int u) { //on trie
dfn[u] = ++ dfn[0]; dn[dfn[0]] = u;
for(int i = 0; i < 26; i ++)
if(ch[u][i]) dfs(ch[u][i]);
}
void buildac() {
static int q[N], l, r, v;
for(int i = 0; i < 26; i ++) if(v = ch[1][i]) {
q[r ++] = v; fail[v] = 1;
} else ch[1][i] = 1;
while(l < r) {
int u = q[l ++];
for(int i = 0; i < 26; i ++) if(v = ch[u][i]) {
q[r ++] = v; fail[v] = ch[fail[u]][i];
} else ch[u][i] = ch[fail[u]][i];
}
for(int i = 2; i <= id; i ++)
fs[fail[i]].push_back(i);
}
void dfs2(int u) { //on fail tree
dl[u] = ++ dl[0];
for(int i = 0; i < fs[u].size(); i ++) dfs2(fs[u][i]);
dr[u] = dl[0];
}
int bit[N];
void add(int x, int y) {
for(; x <= id; x += x & (-x)) bit[x] += y;
}
int qry(int x) {
int ans = 0;
for(; x >= 1; x &= x - 1) ans += bit[x];
return ans;
}
int main() {
static char s[N]; scanf("%s", s);
for(char *c = s; *c; c ++) work(*c);
dfs(1); buildac(); dfs2(1);
scanf("%d", &n);
for(int i = 1; i <= n; i ++) {
scanf("%d%d", &q[i].x, &q[i].y);
q[i].x = pt[q[i].x];
q[i].y = pt[q[i].y]; //id -> node
q[i].id = i;
}
sort(q + 1, q + n + 1);
for(int i = 1, j = 1; i <= id; i ++) {
int u = dn[i];
if(i > 1) {
int la = dn[i - 1];
while(la != fa[u]) {
add(dl[la], -1);
la = fa[la];
}
}
add(dl[u], 1);
for(; j <= n && dfn[q[j].y] == i; j ++) {
ans[q[j].id] = qry(dr[q[j].x]) - qry(dl[q[j].x] - 1);
}
}
for(int i = 1; i <= n; i ++)
printf("%d\n", ans[i]);
return 0;
}