HDU 4812 D Tree

时间:2025-04-13 18:36:19

HDU 4812

思路:

点分治

先预处理好 1e6 + 3 以内所有数的逆元

然后用 map 把以分治点为起点的链的乘积 a 映射到该链另一端点的编号 u(乘积相同时保留最小的编号)

然后暴力跑出以分治点的儿子为起点的链的乘积 b,再在 map 里查找 inv[b]*k % MOD

代码:

#include<bits/stdc++.h>
using namespace std;

// Shorthand macros from a competitive-programming template.
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
// Left disabled: `mp` is used as an ordinary array name further down in this file.
//#define mp make_pair
#define pb push_back
// Argument packs for segment-tree recursion; they expand to expressions over
// `rt`, `l`, `m`, `r`, which must be in scope wherever the macro is used.
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pii pair<int, int>
#define mem(a, b) memset(a, b, sizeof(a))
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
// Redirects stdin/stdout to local files for debugging.
// NOTE: this macro reuses the identifier `fopen`, so the library function of
// that name cannot be called after this point.
// Fixed: the second stream was misspelled `stout`, which does not exist.
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stdout);
//head

// Modulus used for every chain product and for the inverse table (1e6 + 3).
const int MOD = 1e6 + 3;
// Sentinel meaning "no answer found yet" for ans1/ans2.
const int INF = 0x7f7f7f7f;
// Capacity of the per-vertex arrays: 1e5 plus a little slack.
// NOTE(review): the exact slack was lost from the original text; any small
// positive value works as long as N exceeds the largest vertex index.
const int N = 1e5 + 5;
// inv[x]  : modular inverse of x modulo MOD (filled by init()).
// mp[p]   : vertex index recorded for chain product p, or 0 when none is recorded.
// head[u] : index of the first edge in u's adjacency list, -1 when empty.
// mxsz[u] : size of the largest component left after removing u (set by get_rt()).
// sz[u]   : subtree size of u as computed by the latest get_rt() pass.
// v[u]    : value stored on vertex u.
// cnt     : number of edge slots in use; rt: current centroid candidate.
// n, k    : size of the component being processed and the target product.
// ans1/2  : best pair found so far, with ans1 <= ans2.
int inv[MOD + 5], mp[MOD + 5], head[N], mxsz[N], sz[N], v[N], cnt = 0, rt = 0, n, k, ans1, ans2;
// deep[1..top] / id[1..top]: chain products and their end vertices collected by get_d();
// dis[u]: product accumulated along the current chain down to u.
int deep[N], dis[N], id[N], top = 0;
// vis[u]: true once u has been used as a centroid and removed from the tree.
bool vis[N];
// One directed edge of the linked adjacency list: `to` is the target vertex and
// `nxt` is the index of the next edge leaving the same source (-1 ends the list).
// Every undirected tree edge is stored twice (once per direction), so the pool
// needs two slots per edge, hence N*2.
struct edge {
    int to, nxt;
}edge[N*2];
// Inserts the directed edge u -> v at the front of u's adjacency list,
// taking the next free slot from the global edge pool.
void add_edge(int u, int v) {
    const int slot = cnt++;
    edge[slot].nxt = head[u];
    edge[slot].to = v;
    head[u] = slot;
}
// Fills inv[1..MOD-1] with modular inverses modulo MOD in linear time.
// Uses the recurrence inv[i] = -(MOD / i) * inv[MOD % i] (mod MOD), which is
// valid because MOD % i < i has already been computed. inv[0] is left at its
// zero-initialised value.
void init() {
    inv[1] = 1;  // base case of the recurrence
    for (int i = 2; i < MOD; i++) {
        // 1LL forces 64-bit multiplication so the product cannot overflow int.
        inv[i] = (MOD - MOD/i) * 1LL * inv[MOD%i] % MOD;
    }
}
// Tries to pair vertex y, whose chain product is x, with a vertex already
// recorded in mp so that the two products multiply to k modulo MOD.
// On success the pair (smaller index first) replaces (ans1, ans2) when it is
// lexicographically smaller.
void update(int x, int y) {
    // Product the partner chain must have: k * x^{-1} (mod MOD).
    const int need = (1LL * inv[x] * k) % MOD;
    int partner = mp[need];
    if (partner == 0) return;  // 0 means no chain with that product is recorded
    if (partner > y) std::swap(partner, y);
    const bool better = partner < ans1 || (partner == ans1 && y < ans2);
    if (better) {
        ans1 = partner;
        ans2 = y;
    }
}
// Depth-first pass over the not-yet-removed component containing u (entered
// from parent o). Computes sz[] and mxsz[] and leaves in `rt` the vertex whose
// largest remaining component is smallest, i.e. the centroid.
// Callers in this file set rt = 0 and mxsz[0] = n (the component size) first,
// so that any real vertex can beat the sentinel.
void get_rt(int o, int u) {
    sz[u] = 1, mxsz[u] = 0;  // u counts itself; no child component seen yet
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        if(edge[i].to != o && !vis[edge[i].to]) {
            get_rt(u, edge[i].to);
            sz[u] += sz[edge[i].to];
            mxsz[u] = max(mxsz[u], sz[edge[i].to]);
        }
    }
    // The part of the component that lies "above" u also splits off when u is removed.
    mxsz[u] = max(mxsz[u], n - sz[u]);
    if(mxsz[u] < mxsz[rt]) rt = u;
}
// Walks the not-yet-removed subtree of u (entered from parent o) and appends,
// for every vertex reached, its chain product dis[] and its index to the
// deep[] / id[] lists. The caller must set dis[u] and reset `top` beforehand.
void get_d(int o, int u) {
    ++top;
    deep[top] = dis[u];
    id[top] = u;
    for (int e = head[u]; e != -1; e = edge[e].nxt) {
        const int child = edge[e].to;
        if (child == o || vis[child]) continue;
        // Extend the chain product by the child's value (64-bit intermediate).
        dis[child] = (1LL * dis[u] * v[child]) % MOD;
        get_d(u, child);
    }
}
void solve(int u) {
vis[u] = true;
mp[v[u]] = u;
for (int i = head[u]; ~i; i = edge[i].nxt) {
if(!vis[edge[i].to]) {
top = , dis[edge[i].to] = v[edge[i].to];
get_d(u, edge[i].to);
for (int j = ; j <= top; j++) update(deep[j], id[j]);
top = , dis[edge[i].to] = (1LL * v[u] * v[edge[i].to])%MOD;
get_d(u, edge[i].to);
for (int j = ; j <= top; j++) {
int t = deep[j];
if(!mp[t] || id[j] < mp[t]) mp[t] = id[j];
}
}
}
mp[v[u]] = ;
for (int i = head[u]; ~i; i = edge[i].nxt) {
if(!vis[edge[i].to]) {
top = , dis[edge[i].to] = (1LL * v[u] * v[edge[i].to])%MOD;
get_d(u, edge[i].to);
for (int j = ; j <= top; j++) mp[deep[j]] = ;
}
}
for (int i = head[u]; ~i; i = edge[i].nxt) {
if(!vis[edge[i].to]) {
mxsz[] = n = sz[edge[i].to];
get_rt(rt = , edge[i].to);
solve(rt);
}
}
}
// Reads test cases until input ends. Each case gives n and k, the n vertex
// values, and n-1 tree edges. Prints the lexicographically smallest pair of
// vertices whose path product is k modulo MOD, or "No solution".
int main() {
    init();
    int u, V;
    // Require both integers to be read, so EOF or malformed input ends the loop.
    while(scanf("%d%d", &n, &k) == 2) {
        mem(head, -1);   // -1 marks an empty adjacency list
        mem(vis, false);
        mem(mp, 0);
        cnt = 0;
        ans1 = ans2 = INF;
        for (int i = 1; i <= n; i++) scanf("%d", &v[i]);
        for (int i = 1; i < n; i++) scanf("%d%d", &u, &V), add_edge(u, V), add_edge(V, u);
        // Vertex 0 is a sentinel with the worst possible score for get_rt.
        mxsz[0] = n;
        get_rt(rt = 0, 1);
        solve(rt);
        if(ans1 == INF) printf("No solution\n");
        else printf("%d %d\n", ans1, ans2);
    }
    return 0;
}