莫队学习笔记
莫队
在此膜拜莫涛大佬以及同机房的莫队@Zkl21 。
普通莫队
先来考虑一个极其简单的问题:
给你一个序列 a,有多组询问,每次询问 [l, r] 的和
一眼前缀和,但是我们也可以用莫队大材小用地做这道题。
我们可以维护一个左端点 \(L\) 和 \(R\),我们可以发现,维护了这两个端点以及 \([L, R]\) 之间的信息之后就可以很轻松的计算出 \([L, R + 1]\),\([L, R - 1]\),\([L - 1, R]\),\([L + 1, R]\) 这些区间的信息,于是就可以每个询问都这么转移。
当然可以构造数据来卡掉它,比如下面的数据:
1 1000000
999999 1000000
1 1000000
999999 1000000
1 1000000
999999 1000000
...
每次询问都朴素转移的话,就和暴力没什么区别了,所以我们需要将所有询问离线下来,再排个序,最后按顺序输出即可。
排序方法
这里用的是一种非常简单的排序规则:分块。按照左端点右端点所在块来排序。块长一般取 \(\sqrt N\),这样可以保证时间复杂度为 \(O(N\sqrt N)\)。
bool operator < (const Q& q)
{
if(pos[l] != pos[q.l]) return pos[l] < pos[q.l];
return pos[r] < pos[q.r];
}
注意
上面说的区间 \([L, R]\) 向别的区间转移时,有 4 个操作:
while(R < q[i].r) add(++ R);
while(L > q[i].l) add(-- L);
while(R > q[i].r) sub(R --);
while(L < q[i].l) sub(L ++);
这四个操作的顺序是有一些讲究的,比如我们现在区间为 \([2, 4]\),要转移到 \([9, 10]\),假如我们采取了错误的顺序转移,可能会出现例如 \([9, 4]\) 这样的情况,这时维护的区间为负数,有的时候可能是不会出问题的,但是有时候如果用数据结构维护可能会导致越界等一系列问题,因此我们要采取正确的循环顺序。下面的表是从 OIwiki 上拿过来的,里面有所有的循环顺序以及是否正确:
循环顺序 | 正确性 | 反例或注释 |
---|---|---|
l--,l++,r--,r++ |
错误 | \(l<r<l'<r'\) |
l--,l++,r++,r-- |
错误 | \(l<r<l'<r'\) |
l--,r--,l++,r++ |
错误 | \(l<r<l'<r'\) |
l--,r--,r++,l++ |
正确 | 证明较繁琐 |
l--,r++,l++,r-- |
正确 | |
l--,r++,r--,l++ |
正确 | |
l++,l--,r--,r++ |
错误 | \(l<r<l'<r'\) |
l++,l--,r++,r-- |
错误 | \(l<r<l'<r'\) |
l++,r++,l--,r-- |
错误 | \(l<r<l'<r'\) |
l++,r++,r--,l-- |
错误 | \(l<r<l'<r'\) |
l++,r--,l--,r++ |
错误 | \(l<r<l'<r'\) |
l++,r--,r++,l-- |
错误 | \(l<r<l'<r'\) |
全部 24 种排列中只有 6 种是正确的,其中有 2 种的证明较繁琐,这里只给出其中 4 种的证明。
这 4 种正确写法的共同特点是,前两步先扩大区间(
l--
或r++
),后两步再缩小区间(l++
或r--
)。这样写,前两步是扩大区间,可以保持 \(l\le r+1\);执行完前两步后,\(l\le l'\le r'\le r\) 一定成立,再执行后两步只会把区间缩小到 \([l',r']\),依然有 \(l\le r+1\),因此这样写是正确的。
例题
P1494 [国家集训队] 小 Z 的袜子
这道题要统计区间相同数对的数量,用莫队可以轻松处理。
假设我们知道了 \([L, R]\) 区间内每种颜色的数量 \(num\),当我们要增加一个颜色为 \(c\) 的元素时,相同颜色的数对 \(cnt\) 会增加 \(num_c\) 个,同时再让 \(num_c\) 加 \(1\)。减去颜色时也同理。
然后每个询问的答案就是 \(\frac{cnt_i}{C^2_{r - l + 1}}\)。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10;
int n, m;
int a[N];
int pos[N];
struct Q
{
int op, l, r;
bool operator < (const Q& q)
{
if(pos[l] != pos[q.l]) return pos[l] < pos[q.l];
return pos[r] < pos[q.r];
}
}q[N];
struct fraction
{
ll up, down;
ll gcd(ll a, ll b)
{
if(!b) return a;
return gcd(b, a % b);
}
void divide()
{
ll d = gcd(up, down);
up /= d, down /= d;
}
void print()
{
printf("%lld/%lld\n", up, down);
}
fraction operator + (const fraction& a)
{
fraction ret;
ret.down = down * a.down;
ret.up = up * a.down + a.up * down;
ret.divide();
return ret;
}
fraction operator - (const fraction& a)
{
fraction res;
res.down = down * a.down;
res.up = up * a.down - a.up * down;
res.divide();
return res;
}
fraction operator * (const fraction& a)
{
fraction res;
res.down = a.down * down;
res.up = a.up * up;
res.divide();
return res;
}
fraction operator / (const fraction& a)
{
fraction res;
res.down = a.up * down;
res.up = a.down * up;
res.divide();
return res;
}
}ans[N]; //我也不知道我写这么一大堆有什么用
ll cnt;
int num[N];
void add(int pos)
{
cnt += num[a[pos]];
num[a[pos]] ++;
}
void sub(int pos)
{
num[a[pos]] --;
cnt -= num[a[pos]];
}
ll C(ll n, ll m)
{
ll res = 1;
for(int i = n - m + 1; i <= n; i ++ ) res *= i;
for(int i = 1; i <= m; i ++ ) res /= i;
return res;
}
int main()
{
scanf("%d%d", &n, &m);
int _ = sqrt(n);
for(int i = 1; i <= n; i ++ )
{
pos[i] = (i - 1) / _;
scanf("%d", &a[i]);
}
for(int i = 1; i <= m; i ++ )
{
int l, r;
scanf("%d%d", &l, &r);
q[i] = {i, l, r};
}
sort(q + 1, q + m + 1);
int L = 1, R = 0;
for(int i = 1; i <= m; i ++ )
{
while(R < q[i].r) add(++ R);
while(L > q[i].l) add(-- L);
while(R > q[i].r) sub(R --);
while(L < q[i].l) sub(L ++);
if(q[i].l == q[i].r) ans[q[i].op] = {0, 1};
else
{
ans[q[i].op] = {cnt, C(q[i].r - q[i].l + 1, 2)};
ans[q[i].op].divide();
}
}
for(int i = 1; i <= m; i ++ )
ans[i].print();
return 0;
}
P2709 小B的询问
这个题的信息更好维护,开个桶即可。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5e4 + 10;
int a[N], cnt[N];
int pos[N];
ll ans[N];
int n, m, k;
ll res;
void add(int pos)
{
res -= (cnt[a[pos]] * cnt[a[pos]]);
cnt[a[pos]] ++;
res += (cnt[a[pos]] * cnt[a[pos]]);
}
void sub(int pos)
{
res -= (cnt[a[pos]] * cnt[a[pos]]);
cnt[a[pos]] --;
res += (cnt[a[pos]] * cnt[a[pos]]);
}
struct Q
{
int l, r, k;
bool operator < (const Q &q) const
{
if(pos[l] == pos[q.l]) return pos[r] < pos[q.r];
else return pos[l] < pos[q.l];
}
}q[N];
int main()
{
n = read(), m = read(), k = read();
int siz = sqrt(n);
for(int i = 1; i <= n; i ++ )
{
a[i] = read();
pos[i] = i / siz;
}
for(int i = 0; i < m; i ++ )
{
q[i].l = read(), q[i].r = read();
q[i].k = i;
}
sort(q, q + m);
int l = 1, r = 0;
for(int i = 0; i < m; i ++ )
{
while(q[i].l < l) add(-- l);
while(q[i].r > r) add(++ r);
while(q[i].l > l) sub(l ++);
while(q[i].r < r) sub(r --);
ans[q[i].k] = res;
}
for(int i = 0; i < m; i ++ )
cout << ans[i] << endl;
return 0;
}
带修莫队
带修莫队可以处理修改,同时时间复杂度也会相应的增加一部分,一般为 \(O(n^{\frac{5}{3}})\)。
带修莫队与普通莫队不同的一点是,增加了一维时间 \(T\),因此我们在每次从 \([L, R, T]\) 向询问 \([L_i, R_i, T_i]\) 转移时,首先先将 \(T\) 挪到该询问的时间,同时将 \([T_i, T]\) 或 \([T, T_i]\) 这之间的修改全部撤销或者执行。之后再进行区间左右端点的转移即可。
时间复杂度
带修莫队一般块长取到 \(n^{\frac{2}{3}}\),这样可以保证时间复杂度为最优。具体证明我也不会,从 OIwiki 上搬过来的证明见下:
带修莫队排序的第二关键字是右端点所在块编号,不同于普通莫队。
想一想,如果不把右端点分块:
- 乱序的右端点对于每个询问会移动 \(n\) 次。
- 有序的右端点会带来乱序的时间,每次询问会移动 \(t\) 次。
无论哪一种情况,带来的时间开销都无法接受。
接下来分析时间复杂度。
设块长为 \(s\),则有 \(\frac{n}{s}\) 个块。对于块 \(i\) 和块 \(j\),记有 \(q_{i,j}\) 个询问的左端点位于块 \(i\),右端点位于块 \(j\)。
每「组」左右端点不换块的询问 \((i,j)\),端点每次移动 \(O(s)\) 次,时间单调递增,\(O(t)\)。
左右端点换块的时间忽略不计。
表示一下就是:
考虑求导求此式极小值。设 \(f(s)=ms+\frac{n^2t}{s^2}\)。那 \(f'(s)=m-\frac{2n^2t}{s^3}=0\)。
得 \(s=\sqrt[3]{\frac{2n^2t}{m}}=\frac{2^\frac{1}{3}n^\frac23t^\frac13}{m^\frac13}=s_0\)。
也就是当块长取 \(\frac{n^\frac23t^\frac13}{m^\frac13}\) 时有最优时间复杂度 \(O(n^\frac23m^\frac23t^\frac13)\)。
常说的 \(O(n^\frac53)\) 便是把 \(n,m,t\) 当做同数量级的时间复杂度。
实际操作中还是推荐设定 \(n^{\frac{2}{3}}\) 为块长。
例题
P1903 [国家集训队] 数颜色 / 维护队列
本题为带修莫队板子题,开个桶维护即可。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 10;
int a[N], pos[N];
int n, m;
struct query
{
int op, l, r, time;
bool operator < (const query& Q)
{
if(pos[l] == pos[Q.l])
{
if(pos[r] == pos[Q.r])
return time < Q.time;
return pos[r] < pos[Q.r];
}
return pos[l] < pos[Q.l];
}
}q[N];
int qtt;
struct change
{
int p, v;
}c[N];
int num[N], ans[N];
int cnt, now;
inline void add(int p)
{
if(!num[a[p]]) cnt ++;
num[a[p]] ++;
}
inline void sub(int p)
{
num[a[p]] --;
if(!num[a[p]]) cnt --;
}
int main()
{
n = read(), m = read();
int _ = pow(n, 2.0 / 3.0);
for(int i = 1; i <= n; i ++ )
{
pos[i] = (i - 1) / _;
a[i] = read();
}
for(int i = 1; i <= m; i ++ )
{
char op[5];
scanf("%s", op);
if(op[0] == 'Q')
{
int l = read(), r = read();
qtt ++;
q[qtt] = {qtt, l, r, i};
}
else
{
int p = read(), v = read();
c[i] = {p, v};
}
}
sort(q + 1, q + qtt + 1);
int L = 1, R = 0;
for(int i = 1; i <= qtt; i ++ )
{
while(now < q[i].time)
{
now ++;
if(c[now].p >= L && c[now].p <= R)
{
sub(c[now].p);
swap(c[now].v, a[c[now].p]);
add(c[now].p);
}
else if(c[now].p) swap(c[now].v, a[c[now].p]);
}
while(now > q[i].time)
{
if(c[now].p >= L && c[now].p <= R)
{
sub(c[now].p);
swap(c[now].v, a[c[now].p]);
add(c[now].p);
}
else if(c[now].p) swap(c[now].v, a[c[now].p]);
now --;
}
while(R < q[i].r) add(++ R);
while(L > q[i].l) add(-- L);
while(R > q[i].r) sub(R --);
while(L < q[i].l) sub(L ++);
ans[q[i].op] = cnt;
}
for(int i = 1; i <= qtt; i ++ )
printf("%d\n", ans[i]);
return 0;
}
树上莫队
树上莫队一般有两种形式:将树跑出括号序再在括号序上跑莫队、直接在树上跑莫队。接下来先介绍第一种。
补充:括号序
补充这玩意的原因是我直到学这个之前我一直不知道括号序是什么
括号序即为对一颗树进行 dfs 的过程中,刚刚进入时加入一次,在退出时再加入一次,求括号序的代码如下:
void dfs1(int u, int father)
{
id[++ idexx] = u;
for(int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if(j == father) continue;
dfs1(j, u);
}
id[++ idexx] = u;
}
比如下面这棵树:
它的括号序即为:\(1, 2, 3, 3, 4, 4, 2, 5, 5, 1\)
每相同的两个点之间即为该点的子树。
我们把括号标上:\((1, (2, (3, 3), (4, 4), 2), (5, 5), 1)\)
例题
P4074 [WC2013] 糖果公园
本题将括号序跑下来,再进行一个带修莫队。
用 \(vis_u\) 表示 \(u\) 该点是否有贡献,出现一次时就将 \(vis_u\) 异或上 \(1\) 即可。
注意:可能两个点括号序之间的点会少一个 \(lca\),所以需要特判一个点是否为 \(lca\),如果都不是则加上 \(lca\) 的贡献。同时在对询问进行处理时记得处理好左右端点。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read()
{
int x = 0, f = 1;
char ch = getchar();
while(ch < '0' || ch > '9')
{
if(ch == '-') f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
x = x * 10 + ch - '0';
ch = getchar();
}
return x * f;
}
const int N = 1e6 + 10;
int h[N], e[N], ne[N], idx;
int v[N], w[N], c[N];
int pos[N];
int n, m, Q;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
struct query
{
int op, l, r, time;
bool operator < (const query& Q) const
{
if(pos[l] == pos[Q.l])
{
if(pos[r] == pos[Q.r])
return time < Q.time;
return pos[r] < pos[Q.r];
}
return pos[l] < pos[Q.l];
}
}q[N];
int qtt;
int last[N];
struct change
{
int x, last, to;
}ch[N];
int fa[N], dep[N], son[N], siz[N];
int f[N], g[N], idexx;
int id[N];
void dfs1(int u, int father)
{
f[u] = ++ idexx;
id[idexx] = u;
fa[u] = father, dep[u] = dep[father] + 1;
siz[u] = 1;
int maxsize = -1;
for(int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if(j == father) continue;
dfs1(j, u);
siz[u] += siz[j];
if(siz[j] > maxsize)
{
maxsize = siz[j];
son[u] = j;
}
}
g[u] = ++ idexx;
id[idexx] = u;
}
int top[N];
void dfs2(int u, int t)
{
top[u] = t;
if(!son[u]) return;
dfs2(son[u], t);
for(int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if(j == fa[u] || j == son[u]) continue;
dfs2(j, j);
}
}
int tim;
inline int lca(int x, int y)
{
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
if(dep[x] < dep[y]) return x;
return y;
}
int L = 1, R = 0, T = 0;
bool vis[N];
ll num[N], ans[N];
ll cnt;
void add(int p)
{
if(vis[p])
{
cnt -= (ll)v[c[p]] * w[num[c[p]]];
num[c[p]] --;
}
else
{
num[c[p]] ++;
cnt += (ll)v[c[p]] * w[num[c[p]]];
}
vis[p] ^= 1;
}
void modify(int x, int t)
{
if(vis[x])
{
add(x);
c[x] = t;
add(x);
}
else c[x] = t;
}
int main()
{
memset(h, -1, sizeof h);
n = read(), m = read(), Q = read();
for(int i = 1; i <= m; i ++ ) v[i] = read();
for(int i = 1; i <= n; i ++ ) w[i] = read();
for(int i = 1; i < n; i ++ )
{
int a = read(), b = read();
add(a, b), add(b, a);
}
dfs1(1, 1);
dfs2(1, 1);
int _ = pow(idexx, 2.0 / 3.0);
for(int i = 1; i <= idexx; i ++ ) pos[i] = (i - 1) / _;
for(int i = 1; i <= n; i ++ )
{
c[i] = read();
last[i] = c[i];
}
for(int i = 1; i <= Q; i ++ )
{
int op = read(), x = read(), y = read();
if(op == 0)
{
tim ++;
ch[tim] = {x, last[x], y};
last[x] = y;
}
else
{
qtt ++;
q[qtt].op = qtt;
q[qtt].time = tim;
if(f[x] > f[y]) swap(x, y);
if(lca(x, y) == x) q[qtt].l = f[x];
else q[qtt].l = g[x];
q[qtt].r = f[y];
}
}
sort(q + 1, q + qtt + 1);
L = 1, R = 0, T = 0;
for(int i = 1; i <= qtt; i ++ )
{
while(T < q[i].time)
{
T ++;
modify(ch[T].x, ch[T].to);
}
while(T > q[i].time)
{
modify(ch[T].x, ch[T].last);
T --;
}
while(R < q[i].r) add(id[++ R]);
while(L > q[i].l) add(id[-- L]);
while(R > q[i].r) add(id[R --]);
while(L < q[i].l) add(id[L ++]);
int x = id[L], y = id[R];
int anc = lca(x, y);
if(x != anc && y != anc)
{
add(anc);
ans[q[i].op] = cnt;
add(anc);
}
else ans[q[i].op] = cnt;
}
for(int i = 1; i <= qtt; i ++ )
printf("%lld\n", ans[i]);
return 0;
}