题目链接:http://poj.org/problem?id=3376
题意:给你n个字符串m1、m2、m3...mn 求S = mimj(1=<i,j<=n)是回文串的数量
思路:我们考虑第i个字符串和第j个字符串能构成组合回文串要满足的条件:
1、i的长度小于j,那么i一定是j的反串的前缀,且j的反串剩下的后缀是回文串
2、i的长度等于j,那么i等于j的反串
3、i的长度大于j,那么j的反串一定是i的前缀,且i串剩下的后缀是回文串
我们可以将这n个字符串插入trie,每个节点要维护两个值:value1. 到当前节点的字符串个数;value2. 当前节点后面的回文子串个数
我们用每个字符串的反串去trie上查找,要构成回文串有以下情况:
1、 此反串是其他串的前缀,那么组合回文串的数量就要加上value2
2、此反串的前缀是某些字符串,且反串剩下的后缀是回文串,那么组合回文串的数量要加上value1
3、2的特例:此反串的前缀是某些字符串,且反串剩下的后缀为空,同样要加上value1,这种情况可以和2一起处理
关键:
1、判断字符串的哪些后缀是回文串(用于更新value2),以及对应反串的哪些后缀是回文串(当面临第二种情况时,可直接判断后缀否为回文串)
2、如何更新value1和value2(借助1的结果)
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int MAXN = ;
const int KIND = ; struct TrieNode
{
int num; // 到当前节点的字符串个数
int cnt; // 当前节点后面回文子串个数
TrieNode* nxt[];
}; TrieNode node[MAXN]; // 避免动态申请空间的时间消耗
TrieNode* root; // trie树的根节点
int bg[MAXN]; // bg[i]第i+1个字符串开始的位置
int ed[MAXN]; // ed[i]第i+1个字符串结束的位置
bool flag[][MAXN]; // flag[0][i]为true表示原串后面为回文串 flag[1][i]表示反串
char S[MAXN]; // 存放原串
char T[MAXN]; // 存放反串
int nxt[MAXN]; // 存放next数组
int extend[MAXN]; // 用于判断是否为回文子串
LL ans; // 保存结果
int tot; // node数组的下标 void GetNext(char* T, int lhs, int rhs)
{
int j = ;
while (lhs + j + <= rhs && T[lhs + j] == T[lhs + j + ]) ++j;
nxt[lhs + ] = j;
int k = lhs + ;
for (int i = lhs + ; i <= rhs; ++i)
{
int p = nxt[k] + k - ;
int L = nxt[lhs + i - k];
if (L + i < p + ) nxt[i] = L;
else
{
j = max(, p - i + );
while (i + j <= rhs && T[lhs + j] == T[i + j]) ++j;
nxt[i] = j;
k = i;
}
}
} void ExtendKMP(char* S, char* T, int lhs, int rhs, bool sign)
{
GetNext(T, lhs, rhs);
int j = ;
while (j + lhs <= rhs && S[j + lhs] == T[j + lhs]) ++j;
extend[lhs] = j;
int k = lhs;
for (int i = lhs + ; i <= rhs; ++i)
{
int p = extend[k] + k - ;
int L = nxt[lhs + i - k];
if (L + i < p + ) extend[i] = L;
else
{
j = max(, p - i + );
while (i + j <= rhs && S[i + j] == T[lhs + j]) ++j;
extend[i] = j;
k = i;
}
}
for (int i = lhs; i <= rhs; ++i)
{
if (extend[i] == rhs - i + )
flag[sign][i] = true;
}
} void Insert(char S[], int lhs, int rhs)
{
TrieNode* temp = root;
for (int i = lhs; i <= rhs; ++i)
{
int ch = S[i] - 'a';
temp->cnt += flag[][i]; // 更新当前节点后面回文子串的数目
if (temp->nxt[ch] == NULL) temp->nxt[ch] = &node[tot++];
temp = temp->nxt[ch];
}
++temp->num; // 更新到当前节点的字符串数目
} void Search(char S[], int lhs, int rhs)
{
TrieNode* temp = root;
for (int i = lhs; i <= rhs; ++i)
{
int ch = S[i] - 'a';
temp = temp->nxt[ch];
if (temp == NULL) break;
if ((i < rhs && flag[][i + ]) || i == rhs)
ans += temp->num;
}
if (temp) ans += temp->cnt;
} int main()
{
int n;
while (scanf("%d", &n) != EOF)
{
// 初始化
tot = ;
ans = ;
memset(node, , sizeof(node));
memset(flag, , sizeof(flag));
root = &node[tot++]; int l = ;
int L = ;
for (int i = ; i < n; ++i)
{
// 输入一组数据
scanf("%d", &l);
scanf("%s", S + L); // 生成反串
for (int j = ; j < l; ++j)
T[L + j] = S[L + l - - j]; bg[i] = L;
ed[i] = L + l - ; ExtendKMP(S, T , bg[i], ed[i], );
ExtendKMP(T, S , bg[i], ed[i], );
Insert(S, bg[i], ed[i]); L += l;
} for (int i = ; i < n; ++i)
Search(T, bg[i], ed[i]); printf("%lld\n", ans);
}
return ;
}