CF914E Palindromes in a Tree(点分治)

时间:2022-09-01 05:56:46

link

题目大意:给定一个n个点的树,每个点都有一个字符(a-t,20个字符)

我们称一个路径是神犇的,当这个路径上所有点的字母的某个排列是回文

求出对于每个点,求出经过他的神犇路径的数量

题解:

对于回文串,我们发现最多允许1个字母出现了奇数次,和%2有关

并且由于只有20个字母,说到20我就想起了二进制状压,我们对于一条链状压成20维的01向量,表示某个字符出现的次数是奇数还是偶数

说到树上静态问题我就想起淀粉质

我们考虑静态淀粉质,对于当前的树我们找出他的重心rt,然后对于每个子树DFS一下,求出某个子树到rt的路径上的所有权值(开个桶)

然后对于某个子树,在桶内减去这个子树对应的权值之后,xjb统计一下有多少个点到他路径是合法的,打个标记,然后在rt为根的树搞个树上差分就行了

然后点分下去就行了

#include <cstdio>
#include <vector>
using namespace std;

int n, sz[200010], mxsz[200010], col[200010], sum, rt;
int bucket[1050000];
long long ans[200010], tmp[200010];
char str[200010];
bool vis[200010];
vector<int> out[200010];

void chkmax(int &a, int b) { if (a < b) a = b; }

void getrt(int x, int fa)
{
    sz[x] = 1, mxsz[x] = 0;
    for (int i : out[x]) if (vis[i] == false && i != fa)
        getrt(i, x), sz[x] += sz[i], chkmax(mxsz[x], sz[i]);
    chkmax(mxsz[x], sum - sz[x]);
    if (mxsz[x] < mxsz[rt]) rt = x;
}

void getval(int x, int fa, int flag, int dis)
{
    bucket[dis ^ col[x]] += flag;
    for (int i : out[x]) if (vis[i] == false && fa != i) getval(i, x, flag, dis ^ col[x]);
}

void qsum(int x, int fa, int dis)
{
    int cur = dis ^ col[x]; tmp[x] = bucket[cur];
    for (int i = 0; i < 20; i++) tmp[x] += bucket[cur ^ (1 << i)];
    for (int i : out[x]) if (vis[i] == false && fa != i) qsum(i, x, cur);
}

void qdis(int x, int fa)
{
    for (int i : out[x]) if (vis[i] == false && fa != i) qdis(i, x), tmp[x] += tmp[i];
    ans[x] += tmp[x] / ((x == rt) + 1);
}

void solve(int x)
{
    vis[x] = true;
    for (int i : out[x]) if (vis[i] == false) getval(i, x, 1, 0);
    bucket[0]++;
    for (int i : out[x]) if (vis[i] == false)
        getval(i, x, -1, 0), qsum(i, x, col[x]), getval(i, x, 1, 0);
    bucket[0]--;
    tmp[x] = bucket[col[x]];
    for (int i = 0; i < 20; i++) tmp[x] += bucket[col[x] ^ (1 << i)];
    qdis(x, 0);
    for (int i : out[x]) if (vis[i] == false) getval(i, x, -1, 0);
    for (int i : out[x]) if (vis[i] == false) rt = 0, sum = sz[i], getrt(i, 0), solve(rt);
}

int main()
{
    scanf("%d", &n);
    for (int x, y, i = 1; i < n; i++)
        scanf("%d%d", &x, &y), out[x].push_back(y), out[y].push_back(x);
    scanf("%s", str + 1);
    for (int i = 1; i <= n; i++) col[i] = (1 << (str[i] - 97));
    mxsz[0] = sum = n, getrt(1, 0), solve(rt);
    for (int i = 1; i <= n; i++) printf("%lld ", ans[i] + 1);
    return 0;
}