HDU 4749 & POJ 3167 kmp变形

时间:2021-10-07 10:49:24

HDU4749

题意

给定一个主串 T 和模式串 P ,问 T 有多少个不重合的子串与 P 匹配

在这里,串 a 与串 b 匹配的含义是,

i,j,1i,jn,a[i]<a[j]b[i]<b[j]a[i]==a[j]b[i]==b[j]a[i]>a[j]b[i]>b[j]

参考

HDU 4749 && POJ 3167 KMP ——九野的博客

思路

最原始的匹配含义是两字符相同,因此比较时即比较 T[i]==P[j] ,现今改了匹配的含义,只需要对应的去修改比较方式即可。

那么怎么 check 当前的 T[i] P[j] 是否匹配呢?只需检查前面已经匹配成功的一段中分别比 T[i] 小的和比 P[j] 小的数的数量是否相同,以及相等的是否相同。

求失配数组也是同理。

因为 K 的范围只有 25 ,所以可以直接暴力预处理记录下 eq[i][j] lt[i][j] ,分别意为在第 i 个位置及以前与 j 相等的或者比 j 小的有多少个。

那么不重合的子串这个要求呢?我们可以贪心地去匹配,如果当前这段可行,就一定匹配上这段。否则,开始的位置更迟只会消耗掉之后其他子串进行匹配的可能性。一旦匹配上了,就再从模式串的开始位置进行匹配。

这道题要注意的问题是下标的问题,稍微有点绕; kmp 也不是最常规的那种写法。

是道好题。

Code

#include <bits/stdc++.h>
#define K 25
#define maxn 100010
using namespace std;
typedef long long LL;
int n, m, k;
int eq1[maxn][K+10], eq2[maxn][K+10], lt1[maxn][K+10], lt2[maxn][K+10], f[maxn], x[maxn], a[maxn];
bool cmp(int i, int j) { return x[i] < x[j]; }
void init(int* x, int eq[maxn][K+10], int lt[maxn][K+10], int n) {
    for (int i = 1; i <= n; ++i) {
        memcpy(eq[i], eq[i-1], sizeof(eq[0]));
        ++eq[i][x[i]];
    }
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= K; ++j) lt[i][j] = lt[i][j-1] + eq[i][j];
    }
}
bool match1(int i, int j, int x, int y) {
    return eq2[i][x] - eq2[i-j][x] == eq2[j][y] && lt2[i][x-1] - lt2[i-j][x-1] == lt2[j][y-1];
}
bool match2(int i, int j, int x, int y) {
    return eq1[i][x] - eq1[i-j][x] == eq2[j][y] && lt1[i][x-1] - lt1[i-j][x-1] == lt2[j][y-1];
}
void getfail(int* P) {
    int j = 0; f[1] = 0;
    for (int i = 1; i <= m; ) {
        if (!j || match1(i, j, P[i], P[j])) f[++i] = ++j;
        else j = f[j];
    }
}
int cnt;
void kmp(int* T, int* P) {
    int j = 1;
    for (int i = 0; i <= n; ) {
        if (!j || match2(i, j, T[i], P[j])) ++j, ++i;
        else j = f[j];
        if (j == m + 1) ++cnt, j = 1;
    }
}
void init() {
    memset(x, 0, sizeof(x)); memset(a, 0, sizeof(a)); memset(f, 0, sizeof(f));
    memset(eq1, 0, sizeof eq1); memset(eq2, 0, sizeof eq2);
    memset(lt1, 0, sizeof lt1); memset(lt2, 0, sizeof lt2);
}
void work() {
    init();
    for (int i = 1; i <= n; ++i) scanf("%d", &x[i]); init(x, eq1, lt1, n);
    for (int i = 1; i <= m; ++i) scanf("%d", &a[i]); init(a, eq2, lt2, m);
    getfail(a);
    int cur = 1, p; cnt = 0;
    kmp(x, a);
    printf("%d\n", cnt);
}
int main() {
    while (scanf("%d%d%d", &n, &m, &k) != EOF) work();
    return 0;
}

POJ 3167

题意

基本同上,子串可以重复

思路

基本同上。

匹配成功后的跳转仍然借助失配数组进行跳转。

Code

#include <cstdio>
#include <iostream>
#include <cstring>
#include <vector>
#define K 25
#define maxn 100010
using namespace std;
typedef long long LL;
vector<int> ans;
int n, m, k, cnt;
int eq1[maxn][K+10], eq2[maxn][K+10], lt1[maxn][K+10], lt2[maxn][K+10], f[maxn], x[maxn], a[maxn];
bool cmp(int i, int j) { return x[i] < x[j]; }
void init(int* x, int eq[maxn][K+10], int lt[maxn][K+10], int n) {
    for (int i = 1; i <= n; ++i) {
        memcpy(eq[i], eq[i-1], sizeof(eq[0]));
        ++eq[i][x[i]];
    }
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= K; ++j) lt[i][j] = lt[i][j-1] + eq[i][j];
    }
}
bool match1(int i, int j, int x, int y) {
    return eq2[i][x] - eq2[i-j][x] == eq2[j][y] && lt2[i][x-1] - lt2[i-j][x-1] == lt2[j][y-1];
}
bool match2(int i, int j, int x, int y) {
    return eq1[i][x] - eq1[i-j][x] == eq2[j][y] && lt1[i][x-1] - lt1[i-j][x-1] == lt2[j][y-1];
}
void getfail(int* P) {
    int j = 0; f[1] = 0;
    for (int i = 1; i <= m; ) {
        if (!j || match1(i, j, P[i], P[j])) f[++i] = ++j;
        else j = f[j];
    }
}
void kmp(int* T, int* P) {
    int j = 1;
    for (int i = 1; i <= n; ) {
        if (!j || match2(i, j, T[i], P[j])) ++j, ++i;
        else j = f[j];
        if (j == m + 1) {
            ans.push_back(i-m);
            ++cnt;
            j = f[j];
        }
    }
}
void init() {
    memset(x, 0, sizeof(x)); memset(a, 0, sizeof(a)); memset(f, 0, sizeof(f));
    memset(eq1, 0, sizeof eq1); memset(eq2, 0, sizeof eq2);
    memset(lt1, 0, sizeof lt1); memset(lt2, 0, sizeof lt2);
    ans.clear();
}
void work() {
    init();
    for (int i = 1; i <= n; ++i) scanf("%d", &x[i]); init(x, eq1, lt1, n);
    for (int i = 1; i <= m; ++i) scanf("%d", &a[i]); init(a, eq2, lt2, m);
    getfail(a);
    cnt = 0;
    kmp(x, a);
    printf("%d\n", cnt);
    for (vector<int>::iterator it = ans.begin(); it != ans.end(); ++it) printf("%d\n", *it);
}
int main() {
    while (scanf("%d%d%d", &n, &m, &k) != EOF) work();
    return 0;
}