hdu4670(树上点分治+状态压缩)

时间:2023-03-09 09:32:26
hdu4670(树上点分治+状态压缩)

树上路径的f(u,v)=路径上所有点的乘积。

树上每个点的权值都是由给定的k个素数组合而成的,如果f(u,v)是立方数,那么就说明f(u,v)是可行的方案。

问有多少种可行的方案。

f(u,v)可是用状态压缩来表示,因为最多只有30个素数, 第i位表示第i个素数的幂,那么每一位的状态只有0,1,2因为3和0是等价的,所以用3进制状态来表示就行了。

其他代码就是裸的树分。

另外要注意的是,因为counts函数没有统计只有一个点的情况,所以需要另外统计。

 #pragma warning(disable:4996)
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <math.h>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <vector>
#include <bitset>
#include <algorithm>
#include <iostream>
#include <string>
#include <functional>
#include <unordered_map>
const int INF = << ;
typedef __int64 LL;
/*
用三进制的每一位表示第i个素数的幂
如果幂都是0,那么说明是立方
*/
const int N = + ;
std::vector<int> g[N];
std::unordered_map<LL, int> mp;
struct Node
{
int sta[];
}node[N];
LL prime[];
std::vector<Node> dist;
int n, k;
int size[N], vis[N], total, root, mins;
LL _3bit[];
void init()
{
_3bit[] = ;
for (int i = ;i <= ;++i)
_3bit[i] = _3bit[i - ] * ;
}
void getRoot(int u, int fa)
{
int maxs = ;
size[u] = ;
for (int i = ;i < g[u].size();++i)
{
int v = g[u][i];
if (v == fa || vis[v]) continue;
getRoot(v, u);
size[u] += size[v];
maxs = std::max(maxs, size[v]);
}
maxs = std::max(maxs, total - size[u]);
if (mins > maxs)
{
mins = maxs;
root = u;
}
}
void getDis(int u, int fa, Node d)
{
dist.push_back(d);
for (int i = ;i < g[u].size();++i)
{
int v = g[u][i];
if (v == fa || vis[v]) continue;
Node tmp;
for (int j = ;j < k;++j)
tmp.sta[j] = (d.sta[j] + node[v].sta[j]) % ;
getDis(v, u, tmp);
}
}
LL counts(int u)//计算经过u点的路径
{
mp.clear();
mp[] = ;
LL ret = ;
for (int i = ;i < g[u].size();++i)
{
int v = g[u][i];
if (vis[v]) continue;
dist.clear();
getDis(v, u, node[v]);
for (int j = ;j < dist.size();++j)
{
LL sta = ;
for (int z = ;z < k;++z)
{
sta += ( - (node[u].sta[z] + dist[j].sta[z]) % ) % * _3bit[z];
}
ret += mp[sta];
}
for (int j = ;j < dist.size();++j)
{
LL sta = ;
for (int z = ;z < k;++z)
sta += dist[j].sta[z] * _3bit[z];
mp[sta]++;
}
}
return ret;
}
LL ans;
void go(int u)
{
vis[u] = true;
ans += counts(u);
for (int i = ;i < g[u].size(); ++i)
{
int v = g[u][i];
if (vis[v]) continue;
total = size[v];
mins = INF;
getRoot(v, u);
go(root);
} }
int main()
{
int u, v;
LL x;
init();
while (scanf("%d%d", &n, &k) != EOF)
{
for (int i = ;i < k;++i)
scanf("%I64d", &prime[i]);
ans = ;
for (int i = ;i <= n;++i)
{
g[i].clear();
vis[i] = ;
scanf("%I64d", &x);
memset(node[i].sta, , sizeof(node[i].sta));
int tmp = ;
for (int j = ;j <k;++j)
{ while (x%prime[j] == && x)
{
node[i].sta[j]++;
x /= prime[j];
}
node[i].sta[j] %= ;
if (node[i].sta[j] != )tmp++;
}
if (tmp == )//统计只有一个点的
ans++;
}
for (int i = ;i < n;++i)
{
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
total = n;
mins = INF;
getRoot(, -);
go(root);
printf("%I64d\n", ans);
}
return ;
}