Luogu3763 TJOI2017 DNA NTT/SA

时间:2023-12-14 16:57:14

传送门


两种做法:

①SA

将两个串拼在一起建立后缀数组,求出\(height\)数组并用ST表支持\(O(1)\)查询任意两个后缀的\(LCP\)。然后对于\(S\)中每一个长度为\(|T|\)的子串与\(T\)暴力匹配:每次用\(LCP\)跳过最长的相同部分,再跳过一个失配字符,如果失配次数\(>3\)就直接退出。总复杂度\(O(T(N\log N+4N))\),其中\(T\)为数据组数,\(N=|S|+|T|\)。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
//This code is written by Itst
using namespace std;

// Text buffer for the concatenation S+T, 1-indexed: s[1..L]. s[L+1] must be
// '\0' because the LCP scan in init() uses it as a sentinel.
const int MAXN = 2e5 + 7;
char s[MAXN];
// sa[i]: start position of the i-th smallest suffix; rk[p]: rank of suffix p;
// tp: second-key order / scratch copy of rk inside sort() (twice as long
// because sort() reads tp[sa[i] + p]); pot: radix-sort bucket counters;
// h[i]: LCP of suffixes sa[i] and sa[i-1]; ST: sparse table over h.
int sa[MAXN] , rk[MAXN << 1] , tp[MAXN << 1] , pot[MAXN] , h[MAXN] , ST[19][MAXN];
// ls: |S| (set by main); L: |S| + |T|; maxN: number of distinct ranks so far.
int ls , L , maxN;

// One pass of the prefix-doubling radix sort.
// On entry rk holds the ranks of the length-p prefixes and tp lists suffix
// starts ordered by their second key. Rebuilds sa, recomputes rk for the
// length-2p prefixes, and stores the new number of distinct ranks in maxN.
// sa[0] and rk[0] are never written and stay 0, acting as a sentinel.
void sort(int p){
    std::memset(pot , 0 , sizeof(int) * (maxN + 1));
    for(int i = 1 ; i <= L ; ++i)
        ++pot[rk[i]];
    for(int i = 1 ; i <= maxN ; ++i)
        pot[i] += pot[i - 1];
    // Bucket r occupies slots pot[r-1]+1 .. pot[r]; ++pot[r-1] hands out its
    // next free slot, and scanning tp in order keeps the sort stable.
    for(int i = 1 ; i <= L ; ++i)
        sa[++pot[rk[tp[i]] - 1]] = tp[i];
    std::memcpy(tp , rk , sizeof(int) * (L + 1));
    // The second comparison is only reached when the first halves are equal.
    // Its reads past L (at most tp[L + 1] for a suffix of length exactly p)
    // must see 0, meaning "empty second half"; init() clears that tail.
    for(int i = 1 ; i <= L ; ++i)
        rk[sa[i]] = rk[sa[i - 1]] + (tp[sa[i]] != tp[sa[i - 1]] || tp[sa[i] + p] != tp[sa[i - 1] + p]);
    maxN = rk[sa[L]];
}

// Builds sa/rk for s[1..L] by prefix doubling, then the height array h with
// Kasai's algorithm (h[rk[i]] >= h[rk[i-1]] - 1).
void init(){
    // Clear tp beyond L. sort() reads tp[L + 1] (kept cleared up to 2L+1 for
    // safety) expecting 0, but a previous, longer test case leaves stale ranks
    // there. Those could merge distinct suffixes and corrupt the suffix array.
    std::memset(tp + L + 1 , 0 , sizeof(int) * (L + 1));
    maxN = 26;   // initial keys: s[i] - 'A' + 1, assumes uppercase letters
    for(int i = 1 ; i <= L ; ++i)
        rk[tp[i] = i] = s[i] - 'A' + 1;
    sort(0);
    for(int i = 1 ; maxN != L ; i <<= 1){
        int cnt = 0;
        // Suffixes without a second half come first in second-key order...
        for(int j = 1 ; j <= i ; ++j)
            tp[++cnt] = L - i + j;
        // ...then the rest, ordered by the rank of the suffix i positions later.
        for(int j = 1 ; j <= L ; ++j)
            if(sa[j] > i)
                tp[++cnt] = sa[j] - i;
        sort(i);
    }
    for(int i = 1 ; i <= L ; ++i){
        if(rk[i] == 1)
            continue;   // the smallest suffix has no predecessor; h[1] stays 0
        const int t = rk[i];
        h[t] = std::max(0 , h[rk[i - 1]] - 1);
        while(s[sa[t] + h[t]] == s[sa[t - 1] + h[t]])
            ++h[t];
    }
}

// Sparse table for range minimum over h[2..L]: ST[k][j] = min(h[j .. j + 2^k - 1]).
void init_ST(){
    for(int i = 2 ; i <= L ; ++i)
        ST[0][i] = h[i];
    for(int i = 1 ; (1 << i) + 1 <= L ; ++i)
        for(int j = 2 ; j + (1 << i) - 1 <= L ; ++j)
            ST[i][j] = std::min(ST[i - 1][j] , ST[i - 1][j + (1 << (i - 1))]);
}

// LCP of the suffixes whose ranks are x and y (x != y), computed as
// min(h[min(x,y)+1 .. max(x,y)]) with two overlapping sparse-table lookups.
inline int qST(int x , int y){
    if(x > y)
        std::swap(x , y);
    const int t = static_cast<int>(std::log2(y - x));
    return std::min(ST[t][x + 1] , ST[t][y - (1 << t) + 1]);
}

// Decides whether the |T|-length window of S starting at `start` differs from
// T in at most 3 characters. Each step jumps over the longest common prefix
// (via qST on the ranks) and then over one mismatching character.
static bool windowMatches(int start){
    int a = start , b = ls + 1;   // a walks through S, b walks through T
    for(int misses = 0 ; misses <= 3 ; ++misses){
        if(b > L)
            return true;
        const int lcp = qST(rk[a] , rk[b]);
        a += lcp;
        b += lcp;
        if(b > L)
            return true;
        ++a;   // step over the mismatching character
        ++b;
    }
    return false;
}

// Reads T test cases. For each one, S and T are stored back to back in
// s[1..L], the suffix array and LCP sparse table are built, and the number of
// windows of S that match T with at most 3 mismatches is printed.
int main(){
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    //freopen("out","w",stdout);
#endif
    int T;
    scanf("%d" , &T);
    while(T--){
        scanf("%s" , s + 1);
        ls = strlen(s + 1);
        scanf("%s" , s + ls + 1);
        L = strlen(s + 1);
        init();
        init_ST();
        const int lastStart = 2 * ls - L + 1;   // last start with a full |T|-length window
        int ans = 0;
        for(int start = 1 ; start <= lastStart ; ++start)
            ans += windowMatches(start) ? 1 : 0;
        cout << ans << endl;
    }
    return 0;
}

②NTT

将模板串翻转,对\(A,G,C,T\)每一种字符分别做一次\(NTT\)卷积:如果匹配串第\(i\)位为当前字符则\(a_i=1\),否则\(a_i = 0\),翻转后的模板串同理得到\(b_i\)。两个数组卷积后,对应位置的值就是匹配串中以该位置开头的子串与模板串在当前字符(如\(A\))上相同的位置数。把四种字符的结果相加,得到每个子串与模板串相同的位置总数,总数\(\ge |T|-3\)的位置即为答案。复杂度\(O(4TN\log N)\)

#include<iostream>
#include<cstdio>
#include<cctype>
#include<algorithm>
#include<cstring>
#include<utility>
//This code is written by Itst
using namespace std;

const int G = 3 , MOD = 998244353 , INV = 332748118 , MAXN = (1 << 18) + 7;
const char exp[] = "AGCT";
int num[MAXN] , dir[MAXN] , sum[MAXN] , A[MAXN] , B[MAXN];
int need , inv_need , lS , lT;
char s[MAXN] , t[MAXN];

inline int poww(long long a , int b){
    int times = 1;
    while(b){
        if(b & 1)
            times = times * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return times;
}

void init(int x){
    need = 1;
    while(need < x)
        need <<= 1;
    inv_need = poww(need , MOD - 2);
    for(int i = 1 ; i <= need ; ++i)
        dir[i] = (dir[i >> 1] >> 1) | (i & 1 ? need >> 1 : 0);
}

void NTT(int *arr , int tp){
    for(int i = 1 ; i < need ; ++i)
        if(i < dir[i])
            arr[i] ^= arr[dir[i]] ^= arr[i] ^= arr[dir[i]];
    for(int i = 1 ; i < need ; i <<= 1){
        int wn = poww(tp == 1 ? G : INV , (MOD - 1) / i / 2);
        for(int j = 0 ; j < need ; j += i << 1){
            long long w = 1;
            for(int k = 0 ; k < i ; ++k , w = w * wn % MOD){
                int x = arr[j + k] , y = arr[i + j + k] * w % MOD;
                arr[j + k] = x + y >= MOD ? x + y - MOD : x + y;
                arr[i + j + k] = x < y ? x - y + MOD : x - y;
            }
        }
    }
}

// Reads T test cases. For each one, counts the starting positions in s whose
// |t|-length window differs from t in at most 3 characters, and prints it.
// For every base c, indicator arrays of s and reversed t are convolved; summing
// over the four bases gives the number of equal positions per window.
int main(){
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    //freopen("out","w",stdout);
#endif
    int T;
    for(scanf("%d" , &T) ; T ; --T){
        scanf("%s %s" , s + 1 , t + 1);
        lS = strlen(s + 1);
        lT = strlen(t + 1);
        // Product indices go up to lS + lT. The +1 keeps all of them below
        // need, so nothing wraps cyclically and the last window index lS + 1
        // is always inside the transform. Assumes lS, lT <= 1e5 so that need
        // stays <= 2^18 < MAXN.
        init(lS + lT + 1);
        memset(sum , 0 , sizeof(int) * need);
        // Reversing t turns the correlation into a convolution: the window
        // starting at p in s lands at index p + lT.
        std::reverse(t + 1 , t + lT + 1);
        for(int j = 0 ; j < 4 ; ++j){
            memset(A , 0 , sizeof(int) * need);
            memset(B , 0 , sizeof(int) * need);
            const char c = exp[j];
            for(int i = 1 ; i <= lS ; ++i)
                A[i] = s[i] == c;
            for(int i = 1 ; i <= lT ; ++i)
                B[i] = t[i] == c;
            NTT(A , 1); NTT(B , 1);
            // The transform is linear, so the four point-wise products can be
            // summed in frequency space and inverted once afterwards.
            for(int i = 0 ; i < need ; ++i)
                sum[i] = (sum[i] + 1ll * A[i] * B[i]) % MOD;
        }
        NTT(sum , -1);   // unnormalized: scaled by inv_need below
        int cnt = 0;
        for(int i = lT + 1 ; i <= lS + 1 ; ++i)
            cnt += 1ll * sum[i] * inv_need % MOD >= lT - 3;
        std::cout << cnt << std::endl;
    }
    return 0;
}