【bzoj数据下载地址】不要谢我
先讲一下窝是怎么错的。。。
\(MLE\)是因为数组开小了。。
看到异或和最大,那么就会想到用线性基。
如果不会线性基的可以参考一下我的学习笔记:「线性基」学习笔记and乱口*结
但是这一道题目需要合并线性基。
如何合并线性基?
不需要什么花里胡哨的操作,直接暴力插入就可以了。
void merge(xxj &x, xxj y) {
for (int i = BIT; ~i; i --)
if (y.p[i]) x.ins(y.p[i]);
}
代码中的\(x\)和\(y\)是两个线性基。
原理就是把\(y\)中的每一个元素插入到\(x\)中。
然后再套一个倍增求\(LCA\)就可以了。
代码
#include <bits/stdc++.h>
#define gc getchar
using namespace std;
typedef long long ll;
const int BIT = 62;
const int LOG = 21;
const int N = 40005;
template <typename T> void read(T &x) {
x = 0; T fl = 1; char c = 0;
for (; c < '0' || c > '9'; c = gc())
if (c == '-') fl = -1;
for (; c >= '0' && c <= '9'; c = gc())
x = (x << 1) + (x << 3) + (c ^ 48);
x *= fl;
}
struct xxj {
ll p[BIT + 2];
void clear() { memset(p, 0, sizeof(p)); }
void ins(ll x) {
for (int i = BIT; ~i; i --) {
if ((x >> i) == 0) continue;
if (!p[i]) { p[i] = x; break; }
x ^= p[i];
}
}
} g[N][LOG + 2], ans;
struct edge {
int to, nt;
} E[N];
int fa[N][LOG + 2];
int n, ecnt, Q;
int dep[N], H[N];
void add_edge(int u, int v) {
E[++ ecnt] = (edge){v, H[u]};
H[u] = ecnt;
}
void merge(xxj &x, xxj y) {
for (int i = BIT; ~i; i --)
if (y.p[i]) x.ins(y.p[i]);
}
void dfs(int u, int ft) {
fa[u][0] = ft; dep[u] = dep[ft] + 1;
for (int i = 1; i <= LOG; i ++) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
g[u][i] = g[u][i - 1];
merge(g[u][i], g[fa[u][i - 1]][i - 1]);
}
for (int e = H[u]; e; e = E[e].nt) {
int v = E[e].to;
if (v == fa[u][0]) continue;
dfs(v, u);
}
}
void Lca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
for (int i = LOG; ~i; i --)
if (dep[fa[u][i]] >= dep[v])
merge(ans, g[u][i]), u = fa[u][i];
if (u == v) {
merge(ans, g[u][0]);
return;
}
for (int i = LOG; ~i; i --) {
if (fa[u][i] != fa[v][i]) {
merge(ans, g[u][i]);
merge(ans, g[v][i]);
u = fa[u][i]; v = fa[v][i];
}
}
merge(ans, g[u][0]);
merge(ans, g[v][0]);
merge(ans, g[fa[u][0]][0]);
}
int main() {
read(n); read(Q);
for (int i = 1; i <= n; i ++) {
ll x; read(x);
g[i][0].ins(x);
}
for (int i = 1, u, v; i < n; i ++) {
read(u); read(v);
add_edge(u, v);
add_edge(v, u);
}
dfs(1, 0);
while (Q --) {
int u, v; read(u); read(v);
ans.clear();
Lca(u, v);
ll res = 0ll;
for (int i = BIT; ~i; i --)
if ((res ^ ans.p[i]) > res) res ^= ans.p[i];
cout << res << endl;
}
return 0;
}