“二逼平衡树”

二逼平衡树

题目要求

题目要求就是写一种数据结构,对于一个长度为$N(N≤50000)$的序列快速对以下的$M(M≤50000)$个操作快速的询问。

1.查询k在区间内的排名

2.查询区间内排名为k的值

3.修改某一位值上的数值

4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)

5.查询k在区间内的后继(后继定义为大于x,且最小的数)


题解

单纯的平衡树对于区间操作肯定会超时。所以我们可以用线段树套平衡树来实现。但是如果对于每个线段树结点都存一颗平衡树的所有信息,由于每个结点最坏需要$n$个结点来维护全序列,所以空间接受不了。于是我们对于平衡树和线段树分开实现,线段树每个节点记录一下对于平衡树的根,注意在平衡树中涉及旋转操作,所以在线段树中递归查询到某一层的时候,在平衡树中查询要把线段树编号也传进去。

下面给出每个操作实现的思路(Splay)

1.查询k在区间内的排名

对于跟普通线段树中一样的递归查询,当且仅当当前区间完全被包含在查询区间中到Splay中查询,$O(nlog^2n)$

2.查询区间内排名为k的值

这个操作不好实现,所以我们采用二分答案,二分这个第k的值,然后在区间中查询这个数的rank,然后对中点进行调整,$O(nlog^3n)$

3.修改某一位值上的数值

对于线段树递归修改,每一层的Splay都考虑这个点的修改,可以删除原来的点再插入现在的点。同时对于修改位置与当前线段树结点维护区间中点的关系,递归修改左儿子或右儿子。$O(nlog^2n)$

4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)

对于跟普通线段树中一样的递归查询,当且仅当当前区间完全被包含在查询区间中到Splay中查询,$O(nlog^2n)$

5.查询k在区间内的后继(后继定义为大于x,且最小的数)

对于跟普通线段树中一样的递归查询,当且仅当当前区间完全被包含在查询区间中到Splay中查询,$O(nlog^2n)$

所以还是很好实现的么。但是我原来写的Splay完美应用了Splay插入后该节点旋转到根的性质,查询操作通过插入删除来实现,导致了空间和时间的大量浪费。

改用另外一种直接在Splay中遍历查找的方法,时空复杂度都基本降低至了原来的$\frac 1 2$。

Code:

version1(slow)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
const int N=5e4+1e3+7,M=N*123,INF=1e8+1e3+7;
struct T{
int ls,rs,l,r;
}t[N*2+1007];
int son[M][2],fa[M],size[M],key[M],num[M],segcnt,splcnt,root[N*2+1],segroot;
int n,m,a[N];
void pushup(int x)
{
size[x]=size[son[x][0]]+size[son[x][1]]+num[x];
}
void rotate(int x)
{
int y=fa[x],z=fa[y],t=son[y][0]==x;
if(z)
son[z][son[z][1]==y]=x;
fa[x]=z;
son[y][!t]=son[x][t];
fa[son[y][!t]]=y;
son[x][t]=y;
fa[y]=x;
pushup(y);
pushup(x);
}
void splay(int now,int x,int s)
{
while(fa[x]!=s)
{
int y=fa[x],z=fa[y];
if(z!=s)
{
if(son[y][0]==x^son[z][0]==y)
rotate(x);
else
rotate(y);
}
rotate(x);
}
if(s==0)
root[now]=x;
}
void splinsert(int now,int &x,int val,int s)
{
if(!x)
{
x=++splcnt;
size[x]=num[x]=1;
fa[x]=s;
son[x][0]=son[x][1]=0;
key[x]=val;
splay(now,x,0);
return;
}
if(key[x]==val)
{
num[x]++;
size[x]++;
splay(now,x,0);
return;
}
splinsert(now,son[x][val>key[x]],val,x);
pushup(x);
}
int get(int now,int val)
{
int x=root[now];
while(x&&val!=key[x])
x=son[x][val>key[x]];
return x;
}
void spldel(int now,int x)
{
x=get(now,x);
if(!x)
return;
splay(now,x,0);
if(num[x]>1)
{
num[x]--;
size[x]--;
return;
}
if(!son[x][0]||!son[x][1])
root[now]=son[x][0]+son[x][1];
else
{
int y=son[x][1];
while(son[y][0])
y=son[y][0];
splay(root[now],y,x);
son[y][0]=son[x][0];
fa[son[y][0]]=y;
root[now]=y;
}
fa[root[now]]=0;
pushup(root[now]);
}
void splchange(int now,int pos,int val)
{
int k=a[pos];
spldel(now,k);
splinsert(now,root[now],val,0);
}
int splrank(int now,int val)
{
splinsert(now,root[now],val,0);
int ret=size[son[root[now]][0]];
spldel(now,val);
return ret;
}
int splpre(int now,int val)
{
splinsert(now,root[now],val,0);
int x=son[root[now]][0];
if(x==0)
{
spldel(now,val);
return -INF;
}
while(x&&son[x][1])
x=son[x][1];
int ret=key[x];
spldel(now,val);
return ret;
}
int splsuc(int now,int val)
{
splinsert(now,root[now],val,0);
int x=son[root[now]][1];
if(x==0)
{
spldel(now,val);
return INF;
}
while(x&&son[x][0])
x=son[x][0];
int ret=key[x];
spldel(now,val);
return ret;
}
int splbuild(int now,int l,int r)
{
for(int i=l;i<=r;i++)
splinsert(now,root[now],a[i],0);
return root[now];
}
int segbuild(int l,int r)
{
int now=++segcnt;
root[now]=splbuild(now,l,r);
t[now].l=l,t[now].r=r;
if(l==r)
return now;
int mid=(l+r)>>1;
t[now].ls=segbuild(l,mid);
t[now].rs=segbuild(mid+1,r);
return now;
}
int segrank(int now,int l,int r,int val)
{
int mid=(t[now].l+t[now].r)>>1;
int ret=0;
if(t[now].l>=l&&t[now].r<=r)
return splrank(now,val);
if(l<=mid)
ret+=segrank(t[now].ls,l,r,val);
if(r>mid)
ret+=segrank(t[now].rs,l,r,val);
return ret;
}
int segpre(int now,int l,int r,int val)
{
int mid=(t[now].l+t[now].r)>>1;
int ret=-INF;
if(t[now].l>=l&&t[now].r<=r)
return max(ret,splpre(now,val));
if(l<=mid)
ret=max(ret,segpre(t[now].ls,l,r,val));
if(r>mid)
ret=max(ret,segpre(t[now].rs,l,r,val));
return ret;
}
int segsuc(int now,int l,int r,int val)
{
int mid=(t[now].l+t[now].r)>>1;
int ret=INF;
if(t[now].l>=l&&t[now].r<=r)
return min(ret,splsuc(now,val));
if(l<=mid)
ret=min(ret,segsuc(t[now].ls,l,r,val));
if(r>mid)
ret=min(ret,segsuc(t[now].rs,l,r,val));
return ret;
}
void segchange(int now,int pos,int val)
{
int mid=(t[now].l+t[now].r)>>1;
splchange(now,pos,val);
if(t[now].l==t[now].r)
return;
if(pos<=mid)
segchange(t[now].ls,pos,val);
if(pos>mid)
segchange(t[now].rs,pos,val);
}
int segquery(int l,int r,int k)
{
int x=0,y=INF;
while(y-x>1)
{
int mid=(x+y)>>1;
if(segrank(segroot,l,r,mid)>=k)
y=mid;
else
x=mid;
}
return x;
}
int main()
{
freopen("psh.in","r",stdin);
freopen("psh.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
segroot=segbuild(1,n);
for(int i=1;i<=m;i++)
{
int op,l,r,val,x;
scanf("%d",&op);
if(op!=3)
scanf("%d%d%d",&l,&r,&val);
else
scanf("%d%d",&x,&val);
if(op==1)
printf("%d\n",segrank(segroot,l,r,val)+1);
if(op==2)
printf("%d\n",segquery(l,r,val));
if(op==3)
segchange(segroot,x,val),a[x]=val;
if(op==4)
printf("%d\n",segpre(segroot,l,r,val));
if(op==5)
printf("%d\n",segsuc(segroot,l,r,val));
}
}

version2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
#include<cstdio>
#include<iostream>
#include<algorithm>
using namespace std;
const int N=5e4+1e3+7,M=N*61,INF=1e8+1e3+7;
struct T{
int ls,rs,l,r;
}t[N*2+1007];
int son[M][2],fa[M],size[M],key[M],num[M],segcnt,splcnt,root[N*2+1],segroot;
int n,m,a[N];
void pushup(int x)
{
size[x]=size[son[x][0]]+size[son[x][1]]+num[x];
}
void rotate(int x)
{
int y=fa[x],z=fa[y],t=son[y][0]==x;
if(z)
son[z][son[z][1]==y]=x;
fa[x]=z;
son[y][!t]=son[x][t];
fa[son[y][!t]]=y;
son[x][t]=y;
fa[y]=x;
pushup(y);
pushup(x);
}
void splay(int now,int x,int s)
{
while(fa[x]!=s)
{
int y=fa[x],z=fa[y];
if(z!=s)
{
if(son[y][0]==x^son[z][0]==y)
rotate(x);
else
rotate(y);
}
rotate(x);
}
if(s==0)
root[now]=x;
}
void splinsert(int now,int &x,int val,int s)
{
if(!x)
{
x=++splcnt;
size[x]=num[x]=1;
fa[x]=s;
son[x][0]=son[x][1]=0;
key[x]=val;
splay(now,x,0);
return;
}
if(key[x]==val)
{
num[x]++;
size[x]++;
splay(now,x,0);
return;
}
splinsert(now,son[x][val>key[x]],val,x);
pushup(x);
}
int get(int now,int val)
{
int x=root[now];
while(x&&val!=key[x])
x=son[x][val>key[x]];
return x;
}
void spldel(int now,int x)
{
x=get(now,x);
if(!x)
return;
splay(now,x,0);
if(num[x]>1)
{
num[x]--;
size[x]--;
return;
}
if(!son[x][0]||!son[x][1])
root[now]=son[x][0]+son[x][1];
else
{
int y=son[x][1];
while(son[y][0])
y=son[y][0];
splay(root[now],y,x);
son[y][0]=son[x][0];
fa[son[y][0]]=y;
root[now]=y;
}
fa[root[now]]=0;
pushup(root[now]);
}
void splchange(int now,int pos,int val)
{
int k=a[pos];
spldel(now,k);
splinsert(now,root[now],val,0);
}
int splrank(int now,int val)
{
int x=root[now],y,ret=0;
while(x)
{
if(key[x]<val)
ret+=size[son[x][0]]+num[x],x=son[x][1];
else
x=son[x][0];
}
return ret;
}
int splpre(int now,int val)
{
int x=root[now],y,ret=-INF;
while(x)
{
if(key[x]>=val)
y=x,x=son[x][0];
else
ret=max(ret,key[x]),y=x,x=son[x][1];
}
return ret;
}
int splsuc(int now,int val)
{
int x=root[now],y,ret=INF;
while(x)
{
if(key[x]>val)
ret=min(ret,key[x]),y=x,x=son[x][0];
else
y=x,x=son[x][1];
}
return ret;
}
int splbuild(int now,int l,int r)
{
for(int i=l;i<=r;i++)
splinsert(now,root[now],a[i],0);
return root[now];
}
int segbuild(int l,int r)
{
int now=++segcnt;
root[now]=splbuild(now,l,r);
t[now].l=l,t[now].r=r;
if(l==r)
return now;
int mid=(l+r)>>1;
t[now].ls=segbuild(l,mid);
t[now].rs=segbuild(mid+1,r);
return now;
}
int segrank(int now,int l,int r,int val)
{
int mid=(t[now].l+t[now].r)>>1;
int ret=0;
if(t[now].l>=l&&t[now].r<=r)
return splrank(now,val);
if(l<=mid)
ret+=segrank(t[now].ls,l,r,val);
if(r>mid)
ret+=segrank(t[now].rs,l,r,val);
return ret;
}
int segpre(int now,int l,int r,int val)
{
int mid=(t[now].l+t[now].r)>>1;
int ret=-INF;
if(t[now].l>=l&&t[now].r<=r)
return max(ret,splpre(now,val));
if(l<=mid)
ret=max(ret,segpre(t[now].ls,l,r,val));
if(r>mid)
ret=max(ret,segpre(t[now].rs,l,r,val));
return ret;
}
int segsuc(int now,int l,int r,int val)
{
int mid=(t[now].l+t[now].r)>>1;
int ret=INF;
if(t[now].l>=l&&t[now].r<=r)
return min(ret,splsuc(now,val));
if(l<=mid)
ret=min(ret,segsuc(t[now].ls,l,r,val));
if(r>mid)
ret=min(ret,segsuc(t[now].rs,l,r,val));
return ret;
}
void segchange(int now,int pos,int val)
{
int mid=(t[now].l+t[now].r)>>1;
splchange(now,pos,val);
if(t[now].l==t[now].r)
return;
if(pos<=mid)
segchange(t[now].ls,pos,val);
if(pos>mid)
segchange(t[now].rs,pos,val);
}
int segquery(int l,int r,int k)
{
int x=0,y=INF;
while(y-x>1)
{
int mid=(x+y)>>1;
if(segrank(segroot,l,r,mid)>=k)
y=mid;
else
x=mid;
}
return x;
}
int main()
{
// freopen("psh.in","r",stdin);
// freopen("psh.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
segroot=segbuild(1,n);
for(int i=1;i<=m;i++)
{
int op,l,r,val,x;
scanf("%d",&op);
if(op!=3)
scanf("%d%d%d",&l,&r,&val);
else
scanf("%d%d",&x,&val);
if(op==1)
printf("%d\n",segrank(segroot,l,r,val)+1);
if(op==2)
printf("%d\n",segquery(l,r,val));
if(op==3)
segchange(segroot,x,val),a[x]=val;
if(op==4)
printf("%d\n",segpre(segroot,l,r,val));
if(op==5)
printf("%d\n",segsuc(segroot,l,r,val));
}
}

(这代码270行,占了这篇博客的大部分篇幅呢。。。)