主席树的区间修改
因为以前搞的主席树基本都忘了,故写一篇帮助记忆。
前置芝士:
主席树
我发现网上的大部分代码码风和我不同,我希望主席树的打法和线段树差不多,所以我找到了一个和线段树差不多的打法。
首先,主席树如果涉及到区间修改,会稍麻烦一些。为了不占用过多空间,我们常常使用一种叫标记永久化的技术。我们不再向下传递标记,相反,我们在查询时带着标记传递。
但是如果我们按照基础的线段树写出下传的代码,会打出:
ll update(ll nl, ll nr, ll l, ll r, ll pos, ll k) {
pos = clone(pos);
t[pos].v += (r - l + 1) * k;
if(nl <= l && r <= nr) {
t[pos].mark += k;
return pos;
}
ll mid = (l + r) >> 1;
if(nl <= mid)
t[pos].ls = update(nl, nr, l, mid, t[pos].ls, k);
if(mid < nr)
t[pos].rs = update(nl, nr, mid + 1, r, t[pos].rs, k);
return pos;
}
这样的话就会出现一个问题,假如我们修改的是下图的橙色区间,然鹅当前的pos为下图的蓝色区间,那么就会多修改下图的红色区间:
也就是说,假如我们要修改 \(\text{[nl, nr]}\),但是当前为 \(\text{[l,r]}\) ,为了保证不出现上图情况,我们可以修改 \(\text{[max(nl,l),min(nr,r)]}\)(感性理解)
那么我们只需要修改得出:
ll update(ll nl, ll nr, ll l, ll r, ll pos, ll k) {
pos = clone(pos);
t[pos].v += (min(nr, r) - max(nl, l) + 1) * k; //⭐⭐
if(nl <= l && r <= nr) {
t[pos].mark += k;
return pos;
}
ll mid = (l + r) >> 1;
if(nl <= mid)
t[pos].ls = update(nl, nr, l, mid, t[pos].ls, k);
if(mid < nr)
t[pos].rs = update(nl, nr, mid + 1, r, t[pos].rs, k);
return pos;
}
对的,只修改了赋值部分。
那么查询就是普通的写法:
ll query(ll nl, ll nr, ll l, ll r, ll pos, ll mark) {
if(nl <= l && r <= nr) {
return t[pos].v + (r - l + 1) * mark;
}
ll mid = (l + r) >> 1;
ll res = 0;
if(nl <= mid)
res += query(nl, nr, l, mid, t[pos].ls, mark + t[pos].mark);
if(mid < nr)
res += query(nl, nr, mid + 1, r, t[pos].rs, mark + t[pos].mark);
return res;
}
看,一点也不难写。
最后我贴出例题和代码,大家可以去打一下这道题(洛谷如果交不了可以试试vjudge):
TTM - To the moon
题面翻译
一个长度为 \(N\) 的数组 \(\{A\}\),\(4\) 种操作 :
C l r d
:区间 \([l,r]\) 中的数都加 \(d\) ,同时当前的时间戳加 \(1\)。
Q l r
:查询当前时间戳区间 \([l,r]\) 中所有数的和 。
H l r t
:查询时间戳 \(t\) 区间 \([l,r]\) 的和 。
B t
:将当前时间戳置为 \(t\) 。所有操作均合法 。
ps:刚开始时时间戳为 \(0\)
输入格式,一行 \(N\) 和 \(M\),接下来 \(M\) 行每行一个操作
输出格式:对每个查询输出一行表示答案
数据保证:\(1\le N,M\le 10^5\),\(|A_i|\le 10^9\),\(1\le l \le r \le N\),\(|d|\le10^4\)。在刚开始没有进行操作的情况下时间戳为 \(0\),且保证
B
操作不会访问到未来的时间戳。由 @bztMinamoto @yzy1 提供翻译
题目描述
输入格式
n m A1 A2 ... An ... (here following the m operations. )
输出格式
... (for each query, simply print the result. )
样例 #1
样例输入 #1
10 5 1 2 3 4 5 6 7 8 9 10 Q 4 4 Q 1 10 Q 2 4 C 3 6 3 Q 2 4
样例输出 #1
4 55 9 15
样例 #2
样例输入 #2
2 4 0 0 C 1 1 1 C 2 2 -1 Q 1 2 H 1 2 1
样例输出 #2
0 1
代码:
#include <cstdio>
#include <algorithm>
#define ll long long
#define N 200000
using namespace std;
ll n, m;
ll a[N + 10];
ll rt[N + 10];
ll time = 0;
struct node {
ll v, ls, rs, mark;
} t[(N << 5) + 10];
ll tot;
ll build(ll l, ll r, ll pos) {
pos = ++tot;
if(l == r) {
t[pos].v = a[l];
return pos;
}
ll mid = (l + r) >> 1;
t[pos].ls = build(l, mid, pos);
t[pos].rs = build(mid + 1, r, pos);
t[pos].v = t[t[pos].ls].v + t[t[pos].rs].v;
return pos;
}
ll clone(ll pos) {
t[++tot] = t[pos];
return tot;
}
ll update(ll nl, ll nr, ll l, ll r, ll pos, ll k) {
pos = clone(pos);
t[pos].v += (min(nr, r) - max(nl, l) + 1) * k;
if(nl <= l && r <= nr) {
t[pos].mark += k;
return pos;
}
ll mid = (l + r) >> 1;
if(nl <= mid)
t[pos].ls = update(nl, nr, l, mid, t[pos].ls, k);
if(mid < nr)
t[pos].rs = update(nl, nr, mid + 1, r, t[pos].rs, k);
return pos;
}
ll query(ll nl, ll nr, ll l, ll r, ll pos, ll mark) {
if(nl <= l && r <= nr) {
return t[pos].v + (r - l + 1) * mark;
}
ll mid = (l + r) >> 1;
ll res = 0;
if(nl <= mid)
res += query(nl, nr, l, mid, t[pos].ls, mark + t[pos].mark);
if(mid < nr)
res += query(nl, nr, mid + 1, r, t[pos].rs, mark + t[pos].mark);
return res;
}
int main() {
scanf("%lld %lld", &n, &m);
for(ll i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
}
rt[0] = build(1, n, 0);
for(ll i = 1; i <= m; i++) {
char op[5];
ll l, r, d;
scanf("%s", op);
if(op[0] == 'C') {
scanf("%lld %lld %lld", &l, &r, &d);
rt[time+1] = update(l, r, 1, n, rt[time], d);
time++;
}
else if(op[0] == 'Q') {
scanf("%lld %lld", &l, &r);
printf("%lld\n", query(l, r, 1, n, rt[time], 0));
}
else if(op[0] == 'H') {
scanf("%lld %lld %lld", &l, &r, &d);
printf("%lld\n", query(l, r, 1, n, rt[d], 0));
}
else if(op[0] == 'B') {
scanf("%lld", &time);
}
}
}