[LOJ3048] [十二省联考2019] 异或粽子

时间:2022-02-27 21:29:20

题目链接

LOJ:https://loj.ac/problem/3048

洛谷:https://www.luogu.org/problemnew/show/P5283

Solution

考虑每个子串都是一个前缀的后缀,我们可以用堆维护四元组\((l,r,ed,pos)\)表示当前右端点为\(ed\),左端点范围是\([l,r]\),其中\([pos+1,ed]\)这个区间的异或和最大。

这其实就是固定了前缀来找后缀。那么我们每次可以从堆顶拿一个四元组出来,然后分裂成两个:\((l,pos-1,ed,\cdots),(pos+1,r,ed,\cdots)\),顺便更新答案。其中省略号的部分可以通过可持久化\(\rm 01trie\)树算出来。

复杂度\(O(n\log n+k\log n)\)。

#include<bits/stdc++.h>
using namespace std; #define int unsigned int template <class T> void read(T &x) {
x=0;T f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
} template <class T> void print(T x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
template <class T> void write(T x) {if(!x) putchar('0');else print(x);putchar('\n');} #define lf double
#define ll long long #define pii pair<int,int >
#define vec vector<int > #define pb push_back
#define mp make_pair
#define fr first
#define sc second #define FOR(i,l,r) for(int i=l,i##_r=r;i<=i##_r;i++) const int maxn = 5e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 1e9+7;
const int maxm = 2e7+10; int a[maxn],n,k; struct trie {
int son[maxm][2],tot,cnt[maxm],id[maxm],rt[maxn]; void ins(int r,int x) {
rt[r]=++tot;int now=rt[r],pre;if(r) pre=rt[r-1];else pre=0;
for(int i=31;~i;i--) {
int t=x>>i&1;
son[now][t^1]=son[pre][t^1];
now=(son[now][t]=++tot);
pre=son[pre][t];cnt[now]=cnt[pre]+1;
}id[now]=r;
} int query(int l,int r,int x) {
int now=rt[r],pre=0;if(l) pre=rt[l-1];
for(int i=31;~i;i--) {
int t=!(x>>i&1);
if(cnt[son[now][t]]-cnt[son[pre][t]]==0) t^=1;
now=son[now][t],pre=son[pre][t];
}return id[now];
}
}T; struct data {
int l,r,ed,pos; data () {} data (int _l,int _r,int _ed) {
l=_l,r=_r,ed=_ed;
pos=T.query(l,r,a[ed]);
} bool operator < (const data &rhs) const {
return (a[ed]^a[pos])<(a[rhs.ed]^a[rhs.pos]);
}
}; priority_queue<data > s; signed main() {
read(n),read(k);T.ins(0,0);
for(int i=1;i<=n;i++) read(a[i]),a[i]^=a[i-1],T.ins(i,a[i]);
for(int i=1;i<=n;i++) s.push(data(0,i-1,i));
ll ans=0;
for(int i=1;i<=k;i++) {
if(s.empty()) break;
data x=s.top();s.pop();
ans+=a[x.ed]^a[x.pos];
if(1ll*x.pos-1>=1ll*x.l) s.push(data(x.l,x.pos-1,x.ed));
if(x.r>=x.pos+1) s.push(data(x.pos+1,x.r,x.ed));
}write(ans);
return 0;
}