线段树解题技巧

Oracle's Blog / 2023-07-26 / 原文

前言

线段树是一种在 \(\log\) 时间内维护区间信息的数据结构,其维护的信息具有区间可加性。

区间可加性,也就是由区间 \(A\) 和区间 \(B\),可以推出 \(A\cup B\)

上面说到的区间,指的是区间内维护的信息。

如区间和,区间平方和,区间最值,区间最大子段,区间最长连续子段,这类问题就是具有区间可加性的。

关于线段树维护的题目,分为两类,一类是好维护的,一类是不好维护的,体现在修改与查询的关系并不大。下面分这两类进行分析。

好维护

好维护的信息通常是由修改可以推出查询,比如修改是将一个区间加上某个数,查询是查区间和,这时可以直接由修改推出查询。

P3373 【模板】线段树 2

比单纯的区间加稍微复杂一点。

这题显然是好维护的,对于一个区间加上一个数,很典,乘上一个数,考虑添加一个乘法懒标记。记 \(tag1\) 为加法标记,\(tag2\) 为乘法标记。

这时我们要考虑,加法标记和乘法标记的优先级。对于一个运算:

\[((x+1)\times 4+6)\times 7 \]

不难发现,\(1\) 乘了 \(4\times 7\),但是 \(6\) 只乘了 \(7\),这提示我们不能直接将 \(tag1\) 累加至 \(sum\),再用 \(tag2\) 去乘,此时我们将这个柿子的顺序变换一下:

\[x\times 4\times 7+1\times 4\times 7+6\times 7 \]

这启示我们加法标记 \(tag1\) 存的实际是 \(1\times 4\times 7+6\times 7\)\(tag2\) 存的是 \(4\times 7\),在最后计算 \(sum\) 时,采取先乘后加的方法。对于维护 \(tag\) 也是类似。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef unsigned long long ULL;
LL read() {
    LL sum=0,flag=1; char c=getchar();
    while(c<'0'||c>'9') {if(c=='-') flag=-1; c=getchar();}
    while(c>='0'&&c<='9') {sum=sum*10+c-'0'; c=getchar();}
    return sum*flag;
}

const int N=1e5+10;
int n,q,m;
LL tr[N<<2],tag1[N<<2],tag2[N<<2];

void add(int nd,int l,int r,LL x1,LL x2) {
    tag1[nd]=(tag1[nd]*x2+x1)%m;
    tag2[nd]=(tag2[nd]*x2)%m;
    tr[nd]=(tr[nd]*x2%m+(r-l+1)*x1%m)%m;
}

void pushdown(int nd,int l,int r) {
    int mid=l+r>>1;
    add(nd<<1,l,mid,tag1[nd],tag2[nd]);
    add(nd<<1|1,mid+1,r,tag1[nd],tag2[nd]);
    tag1[nd]=0; tag2[nd]=1;
}

void pushup(int nd) {
    tr[nd]=(tr[nd<<1]+tr[nd<<1|1])%m;
}

void change(int nd,int l,int r,int x,int y,LL x1,LL x2) {
    if(r<x||l>y) return ;
    if(l>=x&&r<=y) return add(nd,l,r,x1,x2);
    pushdown(nd,l,r);
    int mid=l+r>>1;
    change(nd<<1,l,mid,x,y,x1,x2);
    change(nd<<1|1,mid+1,r,x,y,x1,x2);
    pushup(nd);
}

LL ask(int nd,int l,int r,int x,int y) {
    if(r<x||l>y) return 0;
    if(l>=x&&r<=y) return tr[nd];
    pushdown(nd,l,r);
    int mid=l+r>>1;
    return (ask(nd<<1,l,mid,x,y)+ask(nd<<1|1,mid+1,r,x,y))%m;
}

int main() {
    // freopen("a.in","r",stdin);
    // freopen("a.out","w",stdout);

    n=read(); q=read(); m=read();
    for(int i=1;i<=n*4;i++) {
        tag1[i]=0;
        tag2[i]=1;
    }
    for(int i=1;i<=n;i++) {
        LL x=read();
        change(1,1,n,i,i,x,1);
    }
    while(q--) {
        int opt=read(),x=read(),y=read();
        LL k;
        if(opt==1) {
            k=read();
            change(1,1,n,x,y,0,k);
        }
        else if(opt==2) {
            k=read();
            change(1,1,n,x,y,k,1);
        }
        else {
            cout<<ask(1,1,n,x,y)<<'\n';
        }
    }

    return 0;
}

P1471 方差

对于平均数,这是很好维护的,只需维护区间和即可。

对于方差,我们利用高中数学知识将其化成如下形式:

\[s^2=\frac{1}{n} \sum_{i=1}^{n} (A_i-\overline{A})^2=\frac{1}{n} \sum_{i=1}^{n} A_i^2 -\overline{A}^2 \]

对于这个玩意,维护区间平方和即可,考虑修改对查询的影响,若给\(a_{l\sim r}+k\),那么区间平方和为 \((a_{l}+k)^2+...+(a_{r}+k)^2\)\(a_{l}^2+...+a_{r}^2+2k(a_l+...+a_r)+(r-l+1)\times k^2\).

对于上面的式子,显然是好维护的,维护区间平方和,区间和,即可实现更新。

点击查看代码
#include<bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef unsigned long long ULL;
LL read() {
    LL sum=0,flag=1; char c=getchar();
    while(c<'0'||c>'9') {if(c=='-') flag=-1; c=getchar();}
    while(c>='0'&&c<='9') {sum=sum*10+c-'0'; c=getchar();}
    return sum*flag;
}

const int N=1e5+10;
int n,m;
double sum1[N<<2],sum2[N<<2],tag[N<<2];
struct node {
    double s1,s2;
};

void add(int nd,int l,int r,double k) {
    tag[nd]+=k;
    sum2[nd]=sum2[nd]+2*k*sum1[nd]+(double)(r-l+1)*k*k;
    sum1[nd]+=(r-l+1)*k;
}

void pushdown(int nd,int l,int r) {
    int mid=l+r>>1;
    if(!tag[nd]) return ;
    add(nd<<1,l,mid,tag[nd]);
    add(nd<<1|1,mid+1,r,tag[nd]);
    tag[nd]=0;
}

void pushup(int nd) {
    sum1[nd]=sum1[nd<<1]+sum1[nd<<1|1];
    sum2[nd]=sum2[nd<<1]+sum2[nd<<1|1];
}

void change(int nd,int l,int r,int x,int y,double k) {
    if(r<x||l>y) return ;
    if(l>=x&&r<=y) return add(nd,l,r,k);
    int mid=l+r>>1;
    pushdown(nd,l,r);
    change(nd<<1,l,mid,x,y,k);
    change(nd<<1|1,mid+1,r,x,y,k);
    pushup(nd);
}

node query(int nd,int l,int r,int x,int y) {
    if(r<x||l>y) return {0,0};
    if(l>=x&&r<=y) return {sum1[nd],sum2[nd]};
    pushdown(nd,l,r);
    int mid=l+r>>1;
    node x1=query(nd<<1,l,mid,x,y);
    node x2=query(nd<<1|1,mid+1,r,x,y);
    return {x1.s1+x2.s1,x1.s2+x2.s2};
}

int main() {
    // freopen("a.in","r",stdin);
    // freopen("a.out","w",stdout);

    n=read(); m=read();
    for(int i=1;i<=n;i++) {
        double x; cin>>x;
        change(1,1,n,i,i,x);
    }
    while(m--) {
        int opt=read(),x=read(),y=read();
        double k;
        if(opt==1) {
            cin>>k;
            change(1,1,n,x,y,k);
        }
        else {
            node ans=query(1,1,n,x,y);
            if(opt==2) {
                printf("%.4lf\n",ans.s1*1.0/(y*1.0-x*1.0+1.0));
            }
            else {
                double avg=ans.s1*1.0/((y-x+1)*1.0);
                double kkk=ans.s2*1.0/((y-x+1)*1.0);
                printf("%.4lf\n",kkk-avg*avg);
            }
        }
    }

    return 0;
}

P4513 小白逛公园