莫队学习笔记

crimsonawa / 2023-05-03 / 原文

莫队

在此膜拜莫涛大佬以及同机房的莫队@Zkl21 。

普通莫队

先来考虑一个极其简单的问题:

给你一个序列 a,有多组询问,每次询问 [l, r] 的和

一眼前缀和,但是我们也可以用莫队大材小用地做这道题。

我们可以维护一个左端点 \(L\)\(R\),我们可以发现,维护了这两个端点以及 \([L, R]\) 之间的信息之后就可以很轻松的计算出 \([L, R + 1]\)\([L, R - 1]\)\([L - 1, R]\)\([L + 1, R]\) 这些区间的信息,于是就可以每个询问都这么转移。

当然可以构造数据来卡掉它,比如下面的数据:

1 1000000
999999 1000000
1 1000000
999999 1000000
1 1000000
999999 1000000
...

每次询问都朴素转移的话,就和暴力没什么区别了,所以我们需要将所有询问离线下来,再排个序,最后按顺序输出即可。

排序方法

这里用的是一种非常简单的排序规则:分块。按照左端点右端点所在块来排序。块长一般取 \(\sqrt N\),这样可以保证时间复杂度为 \(O(N\sqrt N)\)

bool operator < (const Q& q)
{
    if(pos[l] != pos[q.l]) return pos[l] < pos[q.l];
    return pos[r] < pos[q.r]; 
}

注意

上面说的区间 \([L, R]\) 向别的区间转移时,有 4 个操作:

while(R < q[i].r) add(++ R);
while(L > q[i].l) add(-- L);
while(R > q[i].r) sub(R --);
while(L < q[i].l) sub(L ++);

这四个操作的顺序是有一些讲究的,比如我们现在区间为 \([2, 4]\),要转移到 \([9, 10]\),假如我们采取了错误的顺序转移,可能会出现例如 \([9, 4]\) 这样的情况,这时维护的区间为负数,有的时候可能是不会出问题的,但是有时候如果用数据结构维护可能会导致越界等一系列问题,因此我们要采取正确的循环顺序。下面的表是从 OIwiki 上拿过来的,里面有所有的循环顺序以及是否正确:

循环顺序 正确性 反例或注释
l--,l++,r--,r++ 错误 \(l<r<l'<r'\)
l--,l++,r++,r-- 错误 \(l<r<l'<r'\)
l--,r--,l++,r++ 错误 \(l<r<l'<r'\)
l--,r--,r++,l++ 正确 证明较繁琐
l--,r++,l++,r-- 正确
l--,r++,r--,l++ 正确
l++,l--,r--,r++ 错误 \(l<r<l'<r'\)
l++,l--,r++,r-- 错误 \(l<r<l'<r'\)
l++,r++,l--,r-- 错误 \(l<r<l'<r'\)
l++,r++,r--,l-- 错误 \(l<r<l'<r'\)
l++,r--,l--,r++ 错误 \(l<r<l'<r'\)
l++,r--,r++,l-- 错误 \(l<r<l'<r'\)

全部 24 种排列中只有 6 种是正确的,其中有 2 种的证明较繁琐,这里只给出其中 4 种的证明。

这 4 种正确写法的共同特点是,前两步先扩大区间(l--r++),后两步再缩小区间(l++r--)。这样写,前两步是扩大区间,可以保持 \(l\le r+1\);执行完前两步后,\(l\le l'\le r'\le r\) 一定成立,再执行后两步只会把区间缩小到 \([l',r']\),依然有 \(l\le r+1\),因此这样写是正确的。

例题

P1494 [国家集训队] 小 Z 的袜子

这道题要统计区间相同数对的数量,用莫队可以轻松处理。

假设我们知道了 \([L, R]\) 区间内每种颜色的数量 \(num\),当我们要增加一个颜色为 \(c\) 的元素时,相同颜色的数对 \(cnt\) 会增加 \(num_c\) 个,同时再让 \(num_c\)\(1\)。减去颜色时也同理。

然后每个询问的答案就是 \(\frac{cnt_i}{C^2_{r - l + 1}}\)

代码:


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N = 1e5 + 10;

int n, m;
int a[N];
int pos[N];

struct Q
{
    int op, l, r;
    bool operator < (const Q& q)
    {
        if(pos[l] != pos[q.l]) return pos[l] < pos[q.l];
        return pos[r] < pos[q.r]; 
    }
}q[N];

struct fraction
{
    ll up, down;

    ll gcd(ll a, ll b)
    {
        if(!b) return a;
        return gcd(b, a % b);
    }

    void divide()
    {
        ll d = gcd(up, down);
        up /= d, down /= d;
    }

    void print()
    {
        printf("%lld/%lld\n", up, down);
    }

    fraction operator + (const fraction& a)
    {
        fraction ret;
        ret.down = down * a.down;
        ret.up = up * a.down + a.up * down;
        ret.divide();
        return ret;
    }

    fraction operator - (const fraction& a)
    {
        fraction res;
        res.down = down * a.down;
        res.up = up * a.down - a.up * down;
        res.divide();
        return res;
    }

    fraction operator * (const fraction& a)
    {
        fraction res;
        res.down = a.down * down;
        res.up = a.up * up;
        res.divide();
        return res;
    }

    fraction operator / (const fraction& a)
    {
        fraction res;
        res.down = a.up * down;
        res.up = a.down * up;
        res.divide();
        return res;
    }
}ans[N];    //我也不知道我写这么一大堆有什么用

ll cnt;
int num[N];

void add(int pos)
{
    cnt += num[a[pos]];
    num[a[pos]] ++;
}

void sub(int pos)
{
    num[a[pos]] --;
    cnt -= num[a[pos]];
}

ll C(ll n, ll m)
{
    ll res = 1;
    for(int i = n - m + 1; i <= n; i ++ ) res *= i;
    for(int i = 1; i <= m; i ++ ) res /= i;
    return res;
}

int main()
{
    scanf("%d%d", &n, &m);
    int _ = sqrt(n);
    for(int i = 1; i <= n; i ++ )
    {
        pos[i] = (i - 1) / _;
        scanf("%d", &a[i]);
    }
    
    for(int i = 1; i <= m; i ++ )
    {
        int l, r;
        scanf("%d%d", &l, &r);
        q[i] = {i, l, r};
    }

    sort(q + 1, q + m + 1);

    int L = 1, R = 0;
    for(int i = 1; i <= m; i ++ )
    {
        while(R < q[i].r) add(++ R);
        while(L > q[i].l) add(-- L);
        while(R > q[i].r) sub(R --);
        while(L < q[i].l) sub(L ++);
        if(q[i].l == q[i].r) ans[q[i].op] = {0, 1};
        else
        {
            ans[q[i].op] = {cnt, C(q[i].r - q[i].l + 1, 2)};
            ans[q[i].op].divide();
        }
    }

    for(int i = 1; i <= m; i ++ )
        ans[i].print();
    
    return 0;
}

P2709 小B的询问

这个题的信息更好维护,开个桶即可。

代码:


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N = 5e4 + 10;
int a[N], cnt[N];
int pos[N];
ll ans[N];
int n, m, k;
ll res;

void add(int pos)
{
    res -= (cnt[a[pos]] * cnt[a[pos]]);
    cnt[a[pos]] ++;
    res += (cnt[a[pos]] * cnt[a[pos]]);
}

void sub(int pos)
{
    res -= (cnt[a[pos]] * cnt[a[pos]]);
    cnt[a[pos]] --;
    res += (cnt[a[pos]] * cnt[a[pos]]);
}

struct Q
{
    int l, r, k;
    bool operator < (const Q &q) const
    {
        if(pos[l] == pos[q.l]) return pos[r] < pos[q.r];
        else return pos[l] < pos[q.l];
    }
}q[N];

int main()
{
    n = read(), m = read(), k = read();

    int siz = sqrt(n);

    for(int i = 1; i <= n; i ++ ) 
    {
        a[i] = read();
        pos[i] = i / siz;
    }

    for(int i = 0; i < m; i ++ )
    {
        q[i].l = read(), q[i].r = read();
        q[i].k = i;
    }

    sort(q, q + m);

    int l = 1, r = 0;

    for(int i = 0; i < m; i ++ )
    {
        while(q[i].l < l) add(-- l);
        while(q[i].r > r) add(++ r);
        while(q[i].l > l) sub(l ++);
        while(q[i].r < r) sub(r --);
        ans[q[i].k] = res;
    }

    for(int i = 0; i < m; i ++ )
        cout << ans[i] << endl;

    return 0;
}

带修莫队

带修莫队可以处理修改,同时时间复杂度也会相应的增加一部分,一般为 \(O(n^{\frac{5}{3}})\)

带修莫队与普通莫队不同的一点是,增加了一维时间 \(T\),因此我们在每次从 \([L, R, T]\) 向询问 \([L_i, R_i, T_i]\) 转移时,首先先将 \(T\) 挪到该询问的时间,同时将 \([T_i, T]\)\([T, T_i]\) 这之间的修改全部撤销或者执行。之后再进行区间左右端点的转移即可。

时间复杂度

带修莫队一般块长取到 \(n^{\frac{2}{3}}\),这样可以保证时间复杂度为最优。具体证明我也不会,从 OIwiki 上搬过来的证明见下:

带修莫队排序的第二关键字是右端点所在块编号,不同于普通莫队。

想一想,如果不把右端点分块:

  • 乱序的右端点对于每个询问会移动 \(n\) 次。
  • 有序的右端点会带来乱序的时间,每次询问会移动 \(t\) 次。

无论哪一种情况,带来的时间开销都无法接受。

接下来分析时间复杂度。

设块长为 \(s\),则有 \(\frac{n}{s}\) 个块。对于块 \(i\) 和块 \(j\),记有 \(q_{i,j}\) 个询问的左端点位于块 \(i\),右端点位于块 \(j\)

每「组」左右端点不换块的询问 \((i,j)\),端点每次移动 \(O(s)\) 次,时间单调递增,\(O(t)\)

左右端点换块的时间忽略不计。

表示一下就是:

\[\begin{aligned} &\sum_{i=1}^{\frac{n}{s}}\sum_{j=i+1}^{\frac{n}{s}}(q_{i,j}\cdot s+t)\\ =&ms+(\frac{n}{s})^2t\\ =&ms+\frac{n^2t}{s^2} \end{aligned} \]

考虑求导求此式极小值。设 \(f(s)=ms+\frac{n^2t}{s^2}\)。那 \(f'(s)=m-\frac{2n^2t}{s^3}=0\)

\(s=\sqrt[3]{\frac{2n^2t}{m}}=\frac{2^\frac{1}{3}n^\frac23t^\frac13}{m^\frac13}=s_0\)

也就是当块长取 \(\frac{n^\frac23t^\frac13}{m^\frac13}\) 时有最优时间复杂度 \(O(n^\frac23m^\frac23t^\frac13)\)

常说的 \(O(n^\frac53)\) 便是把 \(n,m,t\) 当做同数量级的时间复杂度。

实际操作中还是推荐设定 \(n^{\frac{2}{3}}\) 为块长。

例题

P1903 [国家集训队] 数颜色 / 维护队列

本题为带修莫队板子题,开个桶维护即可。


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N = 1e6 + 10;
int a[N], pos[N];
int n, m;
struct query
{
    int op, l, r, time;
    bool operator < (const query& Q)
    {
        if(pos[l] == pos[Q.l])
        {
            if(pos[r] == pos[Q.r])
                return time < Q.time;
            return pos[r] < pos[Q.r];
        }
        return pos[l] < pos[Q.l];
    }
}q[N];

int qtt;
struct change
{
    int p, v;
}c[N];

int num[N], ans[N];
int cnt, now;
inline void add(int p)
{
    if(!num[a[p]]) cnt ++;
    num[a[p]] ++;
}

inline void sub(int p)
{
    num[a[p]] --;
    if(!num[a[p]]) cnt --;
}

int main()
{
    n = read(), m = read();
    int _ = pow(n, 2.0 / 3.0);
    for(int i = 1; i <= n; i ++ ) 
    {
        pos[i] = (i - 1) / _;
        a[i] = read();
    }

    for(int i = 1; i <= m; i ++ )
    {
        char op[5];
        scanf("%s", op);
        if(op[0] == 'Q')
        {
            int l = read(), r = read();
            qtt ++;
            q[qtt] = {qtt, l, r, i};
        }
        else
        {
            int p = read(), v = read();
            c[i] = {p, v};
        }
    }

    sort(q + 1, q + qtt + 1);

    int L = 1, R = 0;
    for(int i = 1; i <= qtt; i ++ )
    {
        while(now < q[i].time)
        {
            now ++;
            if(c[now].p >= L && c[now].p <= R)
            {
                sub(c[now].p);
                swap(c[now].v, a[c[now].p]);
                add(c[now].p);
            }
            else if(c[now].p) swap(c[now].v, a[c[now].p]);
        }
        while(now > q[i].time)
        {
            if(c[now].p >= L && c[now].p <= R)
            {
                sub(c[now].p);
                swap(c[now].v, a[c[now].p]);
                add(c[now].p);
            }
            else if(c[now].p) swap(c[now].v, a[c[now].p]);
            now --;
        }

        while(R < q[i].r) add(++ R);
        while(L > q[i].l) add(-- L);
        while(R > q[i].r) sub(R --);
        while(L < q[i].l) sub(L ++);
        ans[q[i].op] = cnt;
    }

    for(int i = 1; i <= qtt; i ++ ) 
        printf("%d\n", ans[i]);
    
    return 0;
}

树上莫队

树上莫队一般有两种形式:将树跑出括号序再在括号序上跑莫队、直接在树上跑莫队。接下来先介绍第一种。

补充:括号序

补充这玩意的原因是我直到学这个之前我一直不知道括号序是什么

括号序即为对一颗树进行 dfs 的过程中,刚刚进入时加入一次,在退出时再加入一次,求括号序的代码如下:

void dfs1(int u, int father)
{
    id[++ idexx] = u;
    for(int i = h[u]; i != -1; i = ne[i])
    {
        int j = e[i];
        if(j == father) continue;
        dfs1(j, u);
    }
    id[++ idexx] = u;
}

比如下面这棵树:

iL1jE5.png

它的括号序即为:\(1, 2, 3, 3, 4, 4, 2, 5, 5, 1\)

每相同的两个点之间即为该点的子树。

我们把括号标上:\((1, (2, (3, 3), (4, 4), 2), (5, 5), 1)\)

例题

P4074 [WC2013] 糖果公园

本题将括号序跑下来,再进行一个带修莫队。

\(vis_u\) 表示 \(u\) 该点是否有贡献,出现一次时就将 \(vis_u\) 异或上 \(1\) 即可。

注意:可能两个点括号序之间的点会少一个 \(lca\),所以需要特判一个点是否为 \(lca\),如果都不是则加上 \(lca\) 的贡献。同时在对询问进行处理时记得处理好左右端点。

代码:


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

inline int read()
{
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

const int N = 1e6 + 10;
int h[N], e[N], ne[N], idx;
int v[N], w[N], c[N];
int pos[N];
int n, m, Q;

void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}

struct query
{
    int op, l, r, time;
    bool operator < (const query& Q) const
    {
        if(pos[l] == pos[Q.l])
        {
            if(pos[r] == pos[Q.r])
                return time < Q.time;
            return pos[r] < pos[Q.r];
        }
        return pos[l] < pos[Q.l];
    }
}q[N];
int qtt;

int last[N];
struct change
{
    int x, last, to;
}ch[N];

int fa[N], dep[N], son[N], siz[N];
int f[N], g[N], idexx;
int id[N];
void dfs1(int u, int father)
{
    f[u] = ++ idexx;
    id[idexx] = u;
    fa[u] = father, dep[u] = dep[father] + 1;
    siz[u] = 1;
    int maxsize = -1;
    for(int i = h[u]; i != -1; i = ne[i])
    {
        int j = e[i];
        if(j == father) continue;
        dfs1(j, u);
        siz[u] += siz[j];
        if(siz[j] > maxsize)
        {
            maxsize = siz[j];
            son[u] = j;
        }
    }
    g[u] = ++ idexx;
    id[idexx] = u;
}

int top[N];
void dfs2(int u, int t)
{
    top[u] = t;
    if(!son[u]) return;
    dfs2(son[u], t);
    for(int i = h[u]; i != -1; i = ne[i])
    {
        int j = e[i];
        if(j == fa[u] || j == son[u]) continue;
        dfs2(j, j);
    }
}

int tim;

inline int lca(int x, int y)
{
    while(top[x] != top[y])
    {
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    if(dep[x] < dep[y]) return x;
    return y;
}

int L = 1, R = 0, T = 0;
bool vis[N];
ll num[N], ans[N];
ll cnt;
void add(int p)
{
    if(vis[p])
    {
        cnt -= (ll)v[c[p]] * w[num[c[p]]];
        num[c[p]] --;
    }
    else
    {
        num[c[p]] ++;
        cnt += (ll)v[c[p]] * w[num[c[p]]];
    }
    vis[p] ^= 1;
}

void modify(int x, int t)
{
    if(vis[x])
    {
        add(x);
        c[x] = t;
        add(x);
    }
    else c[x] = t;
}

int main()
{
    memset(h, -1, sizeof h);

    n = read(), m = read(), Q = read();
    for(int i = 1; i <= m; i ++ ) v[i] = read();
    for(int i = 1; i <= n; i ++ ) w[i] = read();

    for(int i = 1; i < n; i ++ )
    {
        int a = read(), b = read();
        add(a, b), add(b, a);
    }

    dfs1(1, 1);
    dfs2(1, 1);

    int _ = pow(idexx, 2.0 / 3.0);
    for(int i = 1; i <= idexx; i ++ ) pos[i] = (i - 1) / _;
    
    for(int i = 1; i <= n; i ++ )
    {
        c[i] = read();
        last[i] = c[i];
    }

    for(int i = 1; i <= Q; i ++ )
    {
        int op = read(), x = read(), y = read();
        if(op == 0)
        {
            tim ++;
            ch[tim] = {x, last[x], y};
            last[x] = y;
        }
        else
        {
            qtt ++;
            q[qtt].op = qtt;
            q[qtt].time = tim;
            if(f[x] > f[y]) swap(x, y);
            if(lca(x, y) == x) q[qtt].l = f[x];
            else q[qtt].l = g[x];
            q[qtt].r = f[y];
        }
    }

    sort(q + 1, q + qtt + 1);

    L = 1, R = 0, T = 0;
    for(int i = 1; i <= qtt; i ++ )
    {
        while(T < q[i].time)
        {
            T ++;
            modify(ch[T].x, ch[T].to);
        }
        while(T > q[i].time)
        {
            modify(ch[T].x, ch[T].last);
            T --;
        }
        while(R < q[i].r) add(id[++ R]);
        while(L > q[i].l) add(id[-- L]);
        while(R > q[i].r) add(id[R --]);
        while(L < q[i].l) add(id[L ++]);
        int x = id[L], y = id[R];
        int anc = lca(x, y);
        if(x != anc && y != anc)
        {
            add(anc);
            ans[q[i].op] = cnt;
            add(anc);
        }
        else ans[q[i].op] = cnt;
    }

    for(int i = 1; i <= qtt; i ++ ) 
        printf("%lld\n", ans[i]);

    return 0;
}

未完待续