Description
题目大意:给定序列 $\{a_n\}$ 和正整数 $k$,求前 $k$ 大区间异或和的和。$1 \le n \le 5 \times 10^5, k \le 2 \times 10^5$。
Solution 1
首先对序列做一个前缀和,问题就变为求前 $k$ 大两个数的异或和。
我们考虑维护一个大根堆,堆中的每个元素都储存了一段区间的信息,包括当前区间的端点 $l, r$,当前区间的异或和和当前区间左端点所在的区间范围 $L, R$。
然后我们进行 $k$ 次操作,每次操作都把堆顶区间的异或和加到答案上,然后把 $[L, R]$ 分裂成 $[L, l - 1]$ 和 $[l + 1, R]$,把这两段区间的信息加入堆,这个操作可以用可持久化 Trie 完成。
每次操作最多会把一个元素加入堆中,时间复杂度 $O(k \log (n + k) \log w)$。
Solution 2
很巧妙的做法,不需要可持久化 Trie,只需要普通的 01 Trie。
原问题等价于求前 $k$ 大的 $s_i \oplus s_j(i < j)$,我们可以求前 $2k$ 大的 $s_i \oplus s_j$,然后再将答案除以 $2$。
考虑一个 $(n + 1) \times (n + 1)$ 的表,要求前 $k$ 大的 $s_i \oplus s_j$,假设我们已经对每行从大到小排序,取出堆顶 $s_{i, j}$ 之后只需要将 $s_{i, j + 1}$ 放入堆中即可,而求异或第 $k$ 大是可以在 Trie 上跑一遍完成的。
时间复杂度 $O(k \log n \log w)$。
Code
#include <bits/stdc++.h>
template <class T>
inline void read(T &x) {
x = 0;
int f = 0;
char ch = getchar();
while (!isdigit(ch)) { f |= ch == '-'; ch = getchar(); }
while (isdigit(ch)) { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
x = f ? -x : x;
return ;
}
typedef unsigned long long uLL;
typedef long long LL;
constexpr int maxn = 5e5 + 10;
struct Trie {
int tr[maxn << 6][2];
int val[maxn << 6], rt[maxn], id[maxn << 6];
int tot;
Trie() {
tot = 0;
memset(id, 0, sizeof(id));
memset(rt, 0, sizeof(rt));
memset(val, 0, sizeof(val));
memset(tr, 0, sizeof(tr));
}
void insert(int x, int now, int lst, LL v) {
val[now] = val[lst] + 1;
for (int i = 32; i >= 0; --i) {
int w = (v >> i) & 1;
tr[now][0] = tr[lst][0], tr[now][1] = tr[lst][1];
tr[now][w] = ++tot;
now = tr[now][w], lst = tr[lst][w];
val[now] = val[lst] + 1;
}
id[now] = x;
return ;
}
int query(int l, int r, LL v) {
for (int i = 32; i >= 0; --i) {
int w = (v >> i) & 1;
if (val[tr[r][w ^ 1]] - val[tr[l][w ^ 1]]) {
l = tr[l][w ^ 1], r = tr[r][w ^ 1];
} else {
l = tr[l][w], r = tr[r][w];
}
}
return id[r];
}
} t;
struct Node {
int l, r, x, p;
LL v;
friend bool operator < (const Node &a, const Node &b) {
return a.v < b.v;
}
};
std::priority_queue<Node> q;
LL a[maxn];
LL ans;
int n, k;
int main() {
read(n), read(k);
t.rt[1] = ++t.tot;
t.insert(1, t.rt[1], t.rt[0], 0);
for (int i = 2; i <= n + 1; ++i) {
read(a[i]);
a[i] ^= a[i - 1];
t.rt[i] = ++t.tot;
t.insert(i, t.rt[i], t.rt[i - 1], a[i]);
int j = t.query(t.rt[0], t.rt[i - 1], a[i]);
q.push((Node){0, i - 1, i, j, a[i] ^ a[j]});
}
while (k--) {
Node now = q.top();
q.pop();
ans += now.v;
if (now.l + 1 < now.p) {
int p = t.query(t.rt[now.l], t.rt[now.p - 1], a[now.x]);
q.push((Node){now.l, now.p - 1, now.x, p, a[now.x] ^ a[p]});
}
if (now.r > now.p) {
int p = t.query(t.rt[now.p], t.rt[now.r], a[now.x]);
q.push((Node){now.p, now.r, now.x, p, a[now.x] ^ a[p]});
}
}
printf("%lld\n", ans);
return 0;
}