线段树

线段树是一种二叉树,也就是对于一个线段,我们会用一个二叉树来表示。

性质:节点 i 的权值 = 她的左儿子权值 + 她的右儿子权值。

1. 建树

根据这个思路,我们就可以建树了,我们设一个结构体 treetree[i].ltree[i].r 分别表示这个点代表的线段的左右下标,tree[i].sum 表示这个节点表示的线段和。

我们知道,一颗从1开始编号的二叉树,结点 i 的左儿子和右儿子编号分别是 2×i 和 2×i+1。

再根据刚才的性质,得到式子:tree[ i ].sum = tree[i∗2].sum + tree [i∗2+1].sum ,就可以建一颗线段树了!代码如下(这里以区间求和的查询为例):

1
2
3
4
5
6
7
8
9
10
11
12
13
void build(int i,int l,int r){//递归建树
tree[i].l=l;tree[i].r=r;
if(l==r){//如果这个节点是叶子节点
tree[i].sum=input[l];
return ;
}
//二分建树
int mid=(l+r)/2;
build(i*2,l,mid);//分别构造左子树和右子树
build(i*2+1,mid+1,r);
tree[i].sum=tree[i*2].sum+tree[i*2+1].sum;//刚才我们发现的性质return ;
}

2. 简单(无pushdown)的线段树

2.1 单点修改,区间查询

线段树的查询方法(范围查询):

  1. 如果这个区间被完全包括在目标区间里面,直接返回这个区间的值
  2. 如果这个区间的左儿子和目标区间有交集,那么搜索左儿子
  3. 如果这个区间的右儿子和目标区间有交集,那么搜索右儿子

这里以区间求和的查询为例:

1
2
3
4
5
6
7
8
9
10
int search(int i,int l,int r){
if(tree[i].l>=l && tree[i].r<=r)//如果这个区间被完全包括在目标区间里面,直接返回这个区间的值
return tree[i].sum;
if(tree[i].r<l || tree[i].l>r) return 0;//如果这个区间和目标区间毫不相干,返回0
int s=0;
if(tree[i*2].r>=l) s+=search(i*2,l,r);//如果这个区间的左儿子和目标区间又交集,那么搜索左儿子
if(tree[i*2+1].l<=r) s+=search(i*2+1,l,r);//如果这个区间的右儿子和目标区间又交集,那么搜索右儿子
return s;
}

注意几个if条件和传参:

  1. 传参:节点编号+左范围+右范围
  2. 完全包含->直接返回sum;
  3. 毫不相干->返回0;
  4. 左儿子/右儿子有交集->递归搜索下去

单点更新:

怎么修改这个区间的单点,其实这个相对简单很多,你要把区间的第dis位加上k,那么你从根节点开始,看这个dis是在左儿子还是在右儿子,在哪往哪跑,

然后返回的时候,还是按照tree[i].sum=tree[i*2].sum+tree[i*2+1].sum的原则,更新所有路过的点。

整个过程也是递归的

1
2
3
4
5
6
7
8
9
10
inline void add(int i,int dis,int k){
if(tree[i].l==tree[i].r){//如果是叶子节点,那么说明找到了
tree[i].sum+=k;
return ;
}
if(dis<=tree[i*2].r) add(i*2,dis,k);//在哪往哪跑
else add(i*2+1,dis,k);
tree[i].sum=tree[i*2].sum+tree[i*2+1].sum;//返回更新
return ;
}

2.2 区间修改,单点查询

区间修改和单点查询,我们的思路就变为:如果把这个区间加上 k ,相当于把这个区间涂上一个 k 的标记,然后单点查询的时候,就从上跑到下,把沿路的标记加起来就好。

这里面给区间贴标记的方式与上面的区间查找类似,原则还是那三条,只不过第一条:如果这个区间被完全包括在目标区间里面,直接将这个区间的值+k。(lazy标记)

具体做法很像,这里贴上代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
void modify(int p, int l, int r, int k) 
{
if(tr[p].l >= l && tr[p].r <= r) {
tr[p].num += k;
return ;
}
int mid = tr[p].l + tr[p].r >> 1;
if(l <= mid) modify(p << 1, l, r, k);
if(r > mid) modify(p << 1 | 1, l, r, k);
}
/*
inline void add(int i,int l,int r,int k){
if(tree[i].l>=l && tree[i].r<=r){//如果这个区间被完全包括在目标区间里面,讲这个区间标记k
tree[i].sum+=k;
return ;
}
if(tree[i*2].r>=l)
add(i*2,l,r,k);
if(tree[i*2+1].l<=r)
add(i*2+1,l,r,k);
}
*/

单点查询了,这个更好理解了,就是dis在哪往哪跑,把路径上所有的标记(加上lazy标记)加上就好了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void query(int p, int x)
{
ans += tr[p].num;//一路加起来
if(tr[p].l == tr[p].r) return;
int mid = tr[p].l + tr[p].r >> 1;
if(x <= mid) query(p << 1, x);
else query(p << 1 | 1, x);
}
/*
void search(int i,int dis){
ans+=tree[i].sum;//一路加起来
if(tree[i].l==tree[i].r)
return ;
if(dis<=tree[i*2].r)
search(i*2,dis);
if(dis>=tree[i*2+1].l)
search(i*2+1,dis);
}
*/

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 7;

int n, m, s, t;
int ans;
int a[maxn];
struct segment_tree
{
struct node
{
int l, r;
int num;
}tr[maxn * 4];

void build(int p, int l, int r)
{
tr[p] = {l, r, 0};
if(l == r) {
tr[p].num = a[l];
return ;
}
int mid = (l + r)/2;
build(p*2, l, mid);
build(p*2+1, mid + 1, r);
}
//区间修改
void modify(int p, int l, int r, int k)
{
//完全在区间里面
if(tr[p].l >= l && tr[p].r <= r) {
tr[p].num += k;
return ;
}
int mid = (tr[p].l + tr[p].r)/2;
if(l <= mid) modify(p*2, l, r, k);
if(r > mid) modify(p*2+1, l, r, k);
}
//单点查询
void query(int p, int x)
{
ans += tr[p].num;
if(tr[p].l == tr[p].r) return;
int mid = (tr[p].l + tr[p].r)/2;
if(x <= mid) query(p*2, x);
else query(p*2+1, x);
}
}ST;

int main()
{
cin >> n >> m;
for (int i = 1; i <= n; ++ i) {
scanf("%d", &a[i]);
}
ST.build(1, 1, n);
for (int i = 1; i <= m; ++ i) {
int c;
scanf("%d", &c);
if(c == 1) {
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
ST.modify(1, x, y, z);
}
else {
ans = 0;
int x;
scanf("%d", &x);
ST.query(1, x);
printf("%d\n", ans);
}
}
return 0;
}
/*
int main()
{
n = 100;
for (int i = 1; i <= n; ++ i) {
a[i] = i;
}
ST.build(1, 1, n);
m = 10;
for (int i = 1; i <= m; ++ i) {
int l = 1, r = 100;
ST.modify(1, l, r, 10000);
ans = 0;
// query(p, x), p = 1, x 为想要查询的点的下标
ST.query(1, 50); // 单点查询 下标为 50 的点的值,ans = 50 + 10000 * i
cout << i << " " << ans << '\n';
ans = 0;
ST.query(1, 100); // 单点查询 ans = 100 + 10000 * i
cout << i << " " << ans << '\n';
}
return 0;
}
*/