Description
题目大意:给定一棵树,有两种询问:
- 给定 $x, z$,求 $x$ 子树内的点与 $z$ 异或的最大值。
- 给定 $x, y, z$,求 $x$ 到 $y$ 的路径上的点与 $z$ 异或的最大值。
Solution
第一个操作是子树操作,考虑按照 dfs 序处理。
第二个操作看起来要用树剖,实际上拆成两条链处理就行了。
维护两棵可持久化 Trie,一棵按照 dfs 序建,一棵按照树上的父子关系建,同时要写一个东西求 LCA。
Code
#include <bits/stdc++.h>
template <class T>
inline void read(T &x) {
x = 0;
int f = 0;
char ch = getchar();
while (!isdigit(ch)) { f |= ch == '-'; ch = getchar(); }
while (isdigit(ch)) { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
x = f ? -x : x;
return ;
}
typedef unsigned long long uLL;
typedef long long LL;
constexpr int maxn = 1e5 + 10;
struct Trie {
int tr[maxn << 5][2];
int val[maxn << 5];
int rt[maxn];
int tot;
Trie() {
tot = 0;
memset(rt, 0, sizeof(rt));
memset(tr, 0, sizeof(tr));
memset(val, 0, sizeof(val));
}
void insert(int now, int lst, int v) {
val[now] = val[lst] + 1;
for (int i = 30; i >= 0; --i) {
int w = (v >> i) & 1;
tr[now][0] = tr[lst][0], tr[now][1] = tr[lst][1];
tr[now][w] = ++tot;
now = tr[now][w], lst = tr[lst][w];
val[now] = val[lst] + 1;
}
return ;
}
int query(int l, int r, int v) {
int res = 0;
for (int i = 30; i >= 0; --i) {
int w = (v >> i) & 1;
if (val[tr[r][w ^ 1]] - val[tr[l][w ^ 1]]) {
res |= (1 << i);
l = tr[l][w ^ 1], r = tr[r][w ^ 1];
} else {
l = tr[l][w], r = tr[r][w];
}
}
return res;
}
} t1, t2;
std::vector<int> g[maxn];
int a[maxn], in[maxn], out[maxn], w[maxn], son[maxn], dep[maxn], fa[maxn], top[maxn], siz[maxn];
int n, q, cnt;
void dfs(int x, int p) {
fa[x] = p;
dep[x] = dep[p] + 1;
siz[x] = 1;
in[x] = ++cnt;
w[cnt] = a[x];
t2.rt[x] = ++t2.tot;
t2.insert(t2.rt[x], t2.rt[p], a[x]);
for (auto i : g[x]) {
if (i != p) {
dfs(i, x);
siz[x] += siz[i];
if (siz[i] > siz[son[x]]) son[x] = i;
}
}
out[x] = cnt;
return ;
}
void dfs2(int x, int p) {
top[x] = p;
if (son[x]) {
dfs2(son[x], p);
for (auto i : g[x]) {
if (i != fa[x] && i != son[x]) {
dfs2(i, i);
}
}
}
}
int lca(int u, int v) {
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
u = fa[top[u]];
}
return dep[u] > dep[v] ? v : u;
}
int main() {
read(n), read(q);
for (int i = 1; i <= n; ++i) read(a[i]);
for (int i = 1, u, v; i < n; ++i) {
read(u), read(v);
g[u].push_back(v), g[v].push_back(u);
}
dfs(1, 0);
dfs2(1, 1);
for (int i = 1; i <= n; ++i) {
t1.rt[i] = ++t1.tot;
t1.insert(t1.rt[i], t1.rt[i - 1], w[i]);
}
while (q--){
int op, x, y, z;
read(op), read(x), read(y);
if (op == 1) {
printf("%d\n", t1.query(t1.rt[in[x] - 1], t1.rt[out[x]], y));
} else {
read(z);
int p = fa[lca(x, y)];
printf("%d\n", std::max(t2.query(t2.rt[p], t2.rt[x], z), t2.query(t2.rt[p], t2.rt[y], z)));
}
}
return 0;
}