浅谈可持续化线段树

引言

可持续化线段树(又叫主席树、函数式线段树),顾名思义就是保存线段树的所有历史版本,并且利用他们共同的数据来减少时间和空间的消耗
相比普通的线段树维护当前节点对应的区间的信息,可持续化线段树能够记录每次修改后的线段树,可以解决区间第k个数大小的问题。

主席树能够保存线段树的所有历史版本,这肯定不会是每一个线段树都存储下来(这样一定会MLE的),而是在每次修改的时候只记录修改的结点,没有修改的结点还用原来的线段树里面的结点,这样在线段树中修改某一个值得时候,只需要新增logn个结点来记录这修改了的logn个结点,其他的结点都是不变的,充分利用其中的共有数据。主席树中的每一个结点保存的是一个线段树,维护的区间相同, 结构相同,只有保存的信息不同,这样主席树中的结点就具备了加减性。

实现

建树

原始的树可以看成是一课空树tree[0],树中的任何结点的左右结点都是这个空结点,载入原始数据的过程可以看成是第一个历史版本,其他的过程和普通的线段树相同,临界判断,向下更新。

更新

在原来的线段树上进行更改,所修改的每一个结点都是充分分配空间的新结点,同时将父节点指向当前的结点,再用root数组存储当前线段树的根节点。这里很巧妙地用到了引用操作(&),例如update(tree[id].l,l,r,v),同时update函数的变量表为update(int &id,int l,int r,int v),函数体的开头就是tree[++cnt]=tree[id];id=cnt; 这样巧妙地操作顺便将父节点的左右孩子结点指针也指向了该结点,这样就跟普通的线段树相差无几了

查询

查询的操作基本上跟普通的线段树一样,就是我们可以随便在哪一个历史版本中查询,比如query(root[k],l,r) 表示在第k个历史版本中查询

POJ 2104 区间第k大问题

题意

给出n个数,查询区间[l,r]中第k大的数是多少

题解

  • 首先数据范围比较大,考虑进行离散化,即对所给的数列进行排序,利用每个数在排序后的数组中的下标进行处理
  • 使用可持续化线段树,建立n个线段树,从第一个线段树开始,后面每一个线段树多一个数的信息
  • 实现查询,因为所建立的这n个线段树维护的区间相同,结构相同,将查询区间两个端点的线段树相减就可以得到[l,r]中的信息,再使用二分法递归找到查询的数在整个数组中的排位

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>

using namespace std;

const int MAX=100005;
int nums[MAX],sorted[MAX],root[MAX];
int cnt; //记录主席树节点编号

struct segment
{
int sum,l,r;
}tree[MAX<<5];

int creatnode(int sum,int l,int r)
{
int ro=++cnt;
tree[ro].sum=sum;
tree[ro].l=l;
tree[ro].r=r;
return ro;
}

void Insert(int &ro,int pre,int pos,int l,int r)
{
ro=creatnode(tree[pre].sum+1,tree[pre].l,tree[pre].r); //创建结点,更改区间所维护值
if(l==r)return;
int m=(l+r)>>1;
if(pos<=m)
Insert(tree[ro].l,tree[pre].l,pos,l,m);
else
Insert(tree[ro].r,tree[pre].r,pos,m+1,r);
}

int query(int S,int E,int l,int r,int k)
{
if(l==r)return l;
int sum_l=tree[tree[E].l].sum-tree[tree[S].l].sum; //表示区间[l,r]的左子区间的个数和
int m=(l+r)>>1;
if(k<=sum_l)
return query(tree[S].l,tree[E].l,l,m,k); //二分递归查找
else
return query(tree[S].r,tree[E].r,m+1,r,k-sum_l);
}

void solve()
{
int n,m,num,pos;
while(scanf("%d %d",&n,&m)!=EOF)
{
cnt=0;
root[0]=0; //一开始的线段树是一个空的线段树
for(int i=1;i<=n;i++)
{
scanf("%d",&nums[i]);
sorted[i]=nums[i];
}
sort(sorted+1,sorted+1+n);
num=unique(sorted+1,sorted+1+n)-(sorted+1); //离散化,数组中的数用排序后的下表表示
for(int i=1;i<=n;i++)
{
pos=lower_bound(sorted+1,sorted+1+num,nums[i])-sorted;
Insert(root[i],root[i-1],pos,1,num); //创建主席树中的一个又一个结点
}
int l,r,k;
for(int i=0;i<m;i++)
{
scanf("%d %d %d",&l,&r,&k);
pos=query(root[l-1],root[r],1,num,k);
printf("%d\n",sorted[pos]);
}

}
}

int main()
{
freopen("input.txt","r",stdin);
solve();
return 0;
}

HDU 4348 区间增加+可持续化

题意

给出一个序列,并为每次插入操作添加时间点,进行一下操作:查询区间[l,r]的和;查询时间点k下的[l,r]区间和;更改时间点;为区间[l,r]添加数

题解

  • 为每一个历史版本创建一颗线段树,同时充分利用共有数据,只增加有改变的结点
  • root数组记录每一个历史版本线段树的根节点信息
  • lazy思想,在查询的时候再进行向下的更改

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>

using namespace std;

const int maxn=100005;
typedef long long ll;

struct segnode
{
int lt,rt,lv,rv; //左右孩子结点的位置,维护的区间值
ll sum,lazy; //lazy思想,在查询的时候再进行修改
}tree[maxn<<4];

int cnt,now,a[maxn],root[maxn];

void build(int &id,int l,int r)
{
tree[++cnt]=tree[id];
id=cnt;
tree[id].lv=l;
tree[id].rv=r;
if(l==r)
{
tree[id].sum=a[l];
return;
}
int mid=(l+r)>>1;
build(tree[id].lt,l,mid); //引用的妙处
build(tree[id].rt,mid+1,r);
tree[id].sum=tree[tree[id].lt].sum+tree[tree[id].rt].sum;
return;
}

void update(int &id,int l,int r,int v)
{
tree[++cnt]=tree[id];
id=cnt;
tree[id].sum+=(min(r,tree[id].rv)-max(l,tree[id].lv)+1)*v;
if(l<=tree[id].lv && tree[id].rv<=r)
{
if(tree[id].lv!=tree[id].rv)
tree[id].lazy+=v;
return;
}
int mid=(tree[id].lv+tree[id].rv)>>1;
if(r<=mid)
update(tree[id].lt,l,r,v); //引用的妙用
else if(l>mid)
update(tree[id].rt,l,r,v);
else
{
update(tree[id].lt,l,r,v);
update(tree[id].rt,l,r,v);
}

}

ll query(int id,int l,int r)
{
if(l<=tree[id].lv && tree[id].rv<=r)
return tree[id].sum;
ll ret=(min(r,tree[id].rv)-max(l,tree[id].lv)+1)*tree[id].lazy;
int mid=(tree[id].lv+tree[id].rv)>>1;
if(r<=mid)
ret+=query(tree[id].lt,l,r);
else if(l>mid)
ret+=query(tree[id].rt,l,r);
else
ret+=query(tree[id].lt,l,r)+query(tree[id].rt,l,r);
return ret;
}

void solve()
{
int n,m,l,r,t,d;
char op;
bool hh=0;
while(scanf("%d %d",&n,&m)!=EOF)
{
if(hh)
printf("\n");
else
hh=1;
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
now=cnt=0;
build(root[0],1,n);
for(int i=0;i<m;i++)
{
scanf(" %c",&op); //注意前面要加一个空格
if(op=='C')
{
now++; //时间增加
scanf("%d %d %d",&l,&r,&d);
update(root[now]=root[now-1],l,r,d);
}
if(op=='Q')
{
scanf("%d %d",&l,&r);
printf("%lld\n",query(root[now],l,r));
}
if(op=='H')
{
scanf("%d %d %d",&l,&r,&t);
printf("%lld\n",query(root[t],l,r));
}
if(op=='B')
{
scanf("%d",&now);
cnt=root[now+1]-1; //释放结点
}
}
}
}

int main()
{
freopen("input.txt","r",stdin);
solve();
return 0;
}

参考