———————————————————————————————————————————————————————————————————————————————————————————
概要
本文介绍了线段树的基本原理和实现,举了一些典型的可以用线段树解决的问题并进行分析和解答。
—BY Ahui2667d(张云辉)
本文代码适用于c++
本文代码中的通用宏定义
#define ll long long #define ls(p) (p<<1) #define rs(p) (p<<1|1) #define mid ((l+r)>>1)
线段树的简单理解及实现
问题的引入:
下面考虑这样一个问题:给定一个数组a[N], 多次查询 [ l , r ] 的区间和。
当然我们可以用树状数组实现快速查询,但是如果我们增加多次对 [ l, r ] 的区间修改呢?树状数组的单点修改会使复杂度过高,从而TLE。
为了解决对区间的修改与查询的操作,我们引入线段树。
线段树的原理:
简单说,线段树就是通过将长度非一的区间二分递归以维护区间特征值的方法。
如下图,对于给定的数组a[1]--a[5] :
-
令d [ i ] 存储某一区间的特征值(这里以区间和为例子),d[1]存储的是 [ 1 , 5 ] 的区间和,d[2]存储的是 [ 1 , 3 ] 的区间和,d[3]存储的是 [ 4 , 5 ] 的区间和。
-
在图中的树中d[i]的左右子节点分别为 d [ i * 2 ] 和 d [ i * 2 + 1 ] 。
-
d[i]存储 [ l , r ] 的特征值,那么左右节点分别存储 [ l , mid ] , [ mid+1 ,r ]的特征值( mid=(l+r)/2 )

代码实现
注意:由于线段树的存储方式,d的长度是a的长度的至少四倍
上传
void push_up(int p)//根据左右子节点更新d[p];
{d[p] = d[ls(p)] + d[rs(p)];
}
建树
void build(int p, int l, int r)//建立[l,r]的线段树,当前访问的节点为p
{t[p] = 0;if (l == r) return d[p] = a[l], void();build(ls(p), l, mid), build(rs(p), mid + 1, r);push_up(p);
}
查询
ll ask(int p, int l, int r, int L, int R)//查询[L,R]的区间和,当前访问节点p,访问区间[l,r];
{if (l > R || r < L)return 0;if (L <= l && r <= R)return d[p];{//push_down(p, l, r);暂时先不管它return ask(ls(p), l, mid, L, R) + ask(rs(p), mid + 1, r, L, R);}
}
修改
如果我们每一次对区间的修改都进行到叶节点的话,时间复杂度过高。
我们设法当且仅当查询/访问到某个节点时,我们才对它修改。
//利用t[p]来记录t[p]是否作了修改
void add(int p, int l, int r, ll x)//对节点p进行修改:每个元素加x;当前访问[l,r]区间;
{d[p] = (d[p] + x * (r - l + 1));t[p] = t[p] + x;
}
void push_down(int p, int l, int r)//下传,对d[p]的左右子节点进行修改;当前访问区间[l,r];
{if (!t[p]) return ;//t[p]标记为0,则不对d[p]作修改;add(ls(p), l, mid, t[p]), add(rs(p), mid + 1, r, t[p]);//分别修改d[p]的左右子节点;t[p] = 0;//清空t[p]标记;
}void add(int p, int l, int r, int L, int R, ll x)//当前访问节点p,访问区间[l,r];对[L,R]区间中每个元素加x;
{if (l > R || r < L)//访问区间[l,r]与修改区间[L,R]没有交集;return ;if (L <= l && r <= R)//访问区间[l,r]包含于修改区间[L,R];{add(p, l, r, x);return;}else{push_down(p, l, r);//先对左右子节点进行修改add(ls(p), l, mid, L, R, x), add(rs(p), mid + 1, r, L, R, x);push_up(p);}
}
现在我们知道为什么在ask函数中设置push_down了。
例题
均来自Luogu
———————————————————————————————————————————————————————————————————————————————————————————
P3372 【模板】线段树 1
题目描述
如题,已知一个数列 \(\{a_i\}\),你需要进行下面两种操作:
- 将某区间每一个数加上 \(k\)。
- 求出某区间每一个数的和。
输入格式
第一行包含两个整数 \(n, m\),分别表示该数列数字的个数和操作的总个数。
第二行包含 \(n\) 个用空格分隔的整数 \(a_i\),其中第 \(i\) 个数字表示数列第 \(i\) 项的初始值。
接下来 \(m\) 行每行包含 \(3\) 或 \(4\) 个整数,表示一个操作,具体如下:
1 x y k:将区间 \([x, y]\) 内每个数加上 \(k\)。2 x y:输出区间 \([x, y]\) 内每个数的和。
输出格式
输出包含若干行整数,即为所有操作 2 的结果。
输入输出样例 #1
输入 #1
5 5
1 5 4 2 3
2 2 4
1 2 3 2
2 3 4
1 1 5 1
2 1 4
输出 #1
11
8
20
说明/提示
对于 \(15\%\) 的数据:\(n \le 8\),\(m \le 10\)。
对于 \(35\%\) 的数据:\(n \le {10}^3\),\(m \le {10}^4\)。
对于 \(100\%\) 的数据:\(1 \le n, m \le {10}^5\),\(a_i,k\) 为正数,且任意时刻数列的和不超过 \(2\times 10^{18}\)。
【样例解释】

———————————————————————————————————————————————————————————————————————————————————————————
分析
直接由上面的代码即可得出。
代码实现
#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
const int maxn = 1e5+10;
ll a[maxn], d[maxn << 2], t[maxn << 2];
int n, m;
using namespace std;
void push_up(int p)//上传,根据左右子节点更新d[p];
{d[p] = d[ls(p)] + d[rs(p)];
}
void build(int p, int l, int r)//建立[l,r]的线段树,当前访问的节点为p
{t[p] = 0;if (l == r) return d[p] = a[l], void();build(ls(p), l, mid), build(rs(p), mid + 1, r);push_up(p);
}
//利用t[p]来记录t[p]是否作了修改
void add(int p, int l, int r, ll x)//对节点p进行修改:每个元素加x;当前访问[l,r]区间;
{d[p] = (d[p] + x * (r - l + 1));t[p] = t[p] + x;
}
void push_down(int p, int l, int r)//下传,对d[p]的左右子节点进行修改;当前访问区间[l,r];
{if (!t[p]) return ;//t[p]标记为0,则不对d[p]作修改;add(ls(p), l, mid, t[p]), add(rs(p), mid + 1, r, t[p]);//分别修改d[p]的左右子节点;t[p] = 0;//清空t[p]标记;
}void add(int p, int l, int r, int L, int R, ll x)//当前访问节点p,访问区间[l,r];对[L,R]区间中每个元素加x;
{if (l > R || r < L)//访问区间[l,r]与修改区间[L,R]没有交集;return ;if (L <= l && r <= R)//访问区间[l,r]包含于修改区间[L,R];{add(p, l, r, x);return;}else{push_down(p, l, r);//先对左右子节点进行修改add(ls(p), l, mid, L, R, x), add(rs(p), mid + 1, r, L, R, x);push_up(p);}
}
ll ask(int p, int l, int r, int L, int R)//查询[L,R]的区间和,当前访问节点p,访问区间[l,r];
{if (l > R || r < L)return 0;if (L <= l && r <= R)return d[p];{push_down(p, l, r);暂时先不管它return ask(ls(p), l, mid, L, R) + ask(rs(p), mid + 1, r, L, R);}
}
int main()
{ios::sync_with_stdio(0);cin >> n >> m;for (int i = 1; i <= n; i++)cin >> a[i];build(1, 1, n);while (m--){ll op, x, y, k;cin >> op >> x >> y;if (op == 1){cin >> k;add(1, 1, n, x, y, k);}elsecout << ask(1, 1, n, x, y) << "\n";}return 0;
}
———————————————————————————————————————————————————————————————————————————————————————————
P3373 【模板】线段树 2
题目描述
如题,已知一个数列 \(a\),你需要进行下面三种操作:
- 将某区间每一个数乘上 \(x\);
- 将某区间每一个数加上 \(x\);
- 求出某区间每一个数的和。
输入格式
第一行包含三个整数 \(n,q,m\),分别表示该数列数字的个数、操作的总个数和模数。
第二行包含 \(n\) 个用空格分隔的整数,其中第 \(i\) 个数字表示数列第 \(i\) 项的初始值 \(a_i\)。
接下来 \(q\) 行每行包含若干个整数,表示一个操作,具体如下:
操作 \(1\): 格式:1 x y k 含义:将区间 \([x,y]\) 内每个数乘上 \(k\)。
操作 \(2\): 格式:2 x y k 含义:将区间 \([x,y]\) 内每个数加上 \(k\)。
操作 \(3\): 格式:3 x y 含义:输出区间 \([x,y]\) 内每个数的和对 \(m\) 取模所得的结果。
输出格式
输出包含若干行整数,即为所有操作 \(3\) 的结果。
输入输出样例 #1
输入 #1
5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4
输出 #1
17
2
说明/提示
【数据范围】
对于 \(30\%\) 的数据:\(n \le 8\),\(q \le 10\)。
对于 \(70\%\) 的数据:$n \le 10^3 \(,\)q \le 10^4$。
对于 \(100\%\) 的数据:\(1 \le n \le 10^5\),\(1 \le q \le 10^5,1\le a_i,k\le 10^4\)。
除样例外,\(m = 571373\)。
(数据已经过加强 _)
样例说明:

故输出应为 \(17\)、\(2\)(\(40 \bmod 38 = 2\))。
———————————————————————————————————————————————————————————————————————————————————————————
分析
-
本题与上题增加了对区间乘法修改的需求
-
考虑如何实现区间乘法:显然 [ l , r ] 中每一个元素乘k,那么区间和乘k;
-
由于本题需要实现区间加法的修改,我们还需要考虑下传时加法和乘法的顺序
容易想到,如果修改操作中先乘再加,那么代码中先乘再加
如果修改操作中先加再乘,我们考虑和上面代码中的顺序保持一致:不难发现 (d+x)y=dy+xy
因此我们将乘法标记初始化为1,加法标记初始化为0;
先让加法标记和区间和乘上乘法标记,再让区间和与加法标记相加即可。
代码实现
只需综合分析,对上一题代码稍作修改
#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
const int maxn = 1e5+10;
ll a[maxn], d[maxn << 2], t[maxn << 2], t1[maxn << 2];
int n, m, pp ;
using namespace std;
void push_up(int p)
{d[p] = (d[ls(p)] + d[rs(p)]) % pp;
}
void build(int p, int l, int r)
{t[p] = 0;t1[p] = 1;if (l == r) return d[p] = a[l], void();build(ls(p), l, mid), build(rs(p), mid + 1, r);push_up(p);
}
void add(int p, int l, int r, ll x)
{d[p] = (d[p] + x * ((r - l + 1) % pp)) % pp;t[p] = (t[p] + x) % pp;
}
void mul(int p, int l, int r, ll x)//区间乘法
{d[p] = (d[p] * (x % pp)) % pp;t[p] = ( t[p] * (x % pp )) % pp;t1[p] = (t1[p] * (x % pp )) % pp;
}
void push_down(int p, int l, int r)
{if ((t1[p] == 1) && !t[p]) return ;mul(ls(p), l, mid, t1[p]), mul(rs(p), mid + 1, r, t1[p]);//下传乘法标记add(ls(p), l, mid, t[p]), add(rs(p), mid + 1, r, t[p]);//下传加法标记t[p] = 0;t1[p] = 1;
}
void add(int p, int l, int r, int L, int R, ll x)
{if (l > R || r < L)return ;if (L <= l && r <= R){add(p, l, r, x);return;}else{push_down(p, l, r);add(ls(p), l, mid, L, R, x), add(rs(p), mid + 1, r, L, R, x);push_up(p);}
}
void mul(int p, int l, int r, int L, int R, ll x)//区间乘法
{if (l > R || r < L)return ;if (L <= l && r <= R){mul(p, l, r, x);return;}else{push_down(p, l, r);mul(ls(p), l, mid, L, R, x), mul(rs(p), mid + 1, r, L, R, x);push_up(p);}
}
ll ask(int p, int l, int r, int L, int R)
{if (l > R || r < L)return 0;if (L <= l && r <= R)return d[p];{push_down(p, l, r);return (ask(ls(p), l, mid, L, R) + ask(rs(p), mid + 1, r, L, R)) % pp;}
}
int main()
{ios::sync_with_stdio(0);cin >> n >> m >> pp;for (int i = 1; i <= n; i++)cin >> a[i];build(1, 1, n);while (m--){ll op, x, y, k;cin >> op >> x >> y;if (op == 1){cin >> k;mul(1, 1, n, x, y, k);}else if (op == 2){cin >> k;add(1, 1, n, x, y, k);}elsecout << ask(1, 1, n, x, y) % pp << "\n";}return 0;
}
———————————————————————————————————————————————————————————————————————————————————————————
P1471 方差
题目背景
滚粗了的 HansBug 在收拾旧数学书,然而他发现了什么奇妙的东西。
题目描述
蒟蒻 HansBug 在一本数学书里面发现了一个神奇的数列,包含 \(N\) 个实数。他想算算这个数列的平均数和方差。
输入格式
第一行包含两个正整数 \(N,M\),分别表示数列中实数的个数和操作的个数。
第二行包含 \(N\) 个实数,其中第 \(i\) 个实数表示数列的第 \(i\) 项。
接下来 \(M\) 行,每行为一条操作,格式为以下三种之一:
操作 \(1\):1 x y k ,表示将第 \(x\) 到第 \(y\) 项每项加上 \(k\),\(k\) 为一实数。
操作 \(2\):2 x y ,表示求出第 \(x\) 到第 \(y\) 项这一子数列的平均数。
操作 \(3\):3 x y ,表示求出第 \(x\) 到第 \(y\) 项这一子数列的方差。
输出格式
输出包含若干行,每行为一个实数,即依次为每一次操作 \(2\) 或操作 \(3\) 所得的结果(所有结果四舍五入保留 \(4\) 位小数)。
输入输出样例 #1
输入 #1
5 5
1 5 4 2 3
2 1 4
3 1 5
1 1 1 1
1 2 2 -1
3 1 5
输出 #1
3.0000
2.0000
0.8000
说明/提示
关于方差:对于一个有 \(n\) 项的数列 \(A\),其方差 \(s^2\) 定义如下: