2021 ICPC 网络赛 第二场 L Euler Function(势能线段树,欧拉函数,状态压缩)

jujujujuluo / 2024-07-17 / 原文

2021 ICPC 网络赛 第二场 L Euler Function

题意

给定序列,定义两个操作

  • \(l,r,x\)对区间\([l,r]\)的数乘\(x\)
  • \(l,r\)\(\sum \phi {a}_{i}\)

思路

注意欧拉函数的性质,若\(i\bmod p= 0\)\(\phi (i * p)=p*\phi (i)\),否则\(\phi(i * p) = (p - 1) * \phi (i)\)

因为\(x,w\)的值都小于\(100\),因此我们可以对线段树维护区间所有质因子的交集,当区间内交集有\(prime\),那么直接对区间乘\(prime\),否则递归到叶子节点,乘\(prime - 1\)(次数不会很多)

实现可以考虑直接用std::bitset维护每个区间的质因子状态,我这里用的是离散化之后再状压。

#include <bits/stdc++.h>
using namespace std;

#define int long long
const int N = 1e5 + 10;
int a[N];
std::vector<int> minp, primes;
vector<int> fac[110];
int id[110], phi[110];
const int mod = 998244353;
struct node {
    int sum;
    int mul;
    int st;
}tr[N << 2];
void sieve(int n) {
    minp.assign(n + 1, 0);
    primes.clear();
    
    phi[1] = 1;
    for (int i = 2; i <= n; i++) {
        if (minp[i] == 0) {
            minp[i] = i;
            id[i] = primes.size();
            primes.push_back(i);
            phi[i] = i - 1;
        }
        
        for (auto p : primes) {
            if (i * p > n) {
                break;
            }
            minp[i * p] = p;
            if (p == minp[i]) {
                phi[i * p] = phi[i] * p;
                break;
            } else {
                phi[i * p] = phi[i] * (p - 1);            
            }
        }
    }
}
void pushup(node &u, node &l, node &r) {
    u.sum = (l.sum + r.sum) % mod;
    u.st = l.st & r.st;
}
void pushup(int k) {
    pushup(tr[k], tr[k + k], tr[k + k + 1]);
}
void pushdown(int k) {
    if (tr[k].mul > 1) {
        auto &u = tr[k], &l = tr[k + k], &r = tr[k + k + 1];
        int x = u.mul;
        l.sum = l.sum * x % mod;
        r.sum = r.sum * x % mod;
        l.mul = l.mul * x % mod;
        r.mul = r.mul * x % mod;
        u.mul = 1;
    }
}
void build(int k, int l, int r) {
    tr[k].mul = 1;
    if (l == r) {
        tr[k].sum = phi[a[l]];
        for (auto &x : fac[a[l]]) tr[k].st |= (1ll << id[x]);
        return ;
    }
    int mid = l + r >> 1;
    build(k + k, l, mid);
    build(k + k + 1, mid + 1, r);
    pushup(k);
}
void rangemodify(int k, int l, int r, int ql, int qr, int w) {
    if (l == r) {
        if ((tr[k].st & (1ll << id[w]))) {
            tr[k].sum *= w;
            tr[k].sum %= mod;
        } else {
            tr[k].sum *= w - 1;
            tr[k].sum %= mod;
        }
        tr[k].st |= (1 << id[w]);
        return ;
    }
    if (l >= ql && r <= qr) {
        //关键,只有该区间完全包含这个质因子才retrun,否则继续递归下去
        if ((tr[k].st & (1ll << id[w]))) {
            tr[k].mul = tr[k].mul * w % mod;
            tr[k].sum = tr[k].sum * w % mod;
            return ;
        }
    }
    pushdown(k);
    int mid = l + r >> 1;
    if (ql <= mid) rangemodify(k + k, l, mid, ql, qr, w);
    if (qr > mid) rangemodify(k + k + 1, mid + 1, r, ql, qr, w);
    pushup(k);
}
int rangesum(int k, int l, int r, int ql, int qr) {
    if (l >= ql && r <= qr) return tr[k].sum;
    pushdown(k);
    int mid = l + r >> 1;
    int res = 0;
    if (ql <= mid) res += rangesum(k + k, l, mid, ql, qr), res %= mod;
    if (qr > mid) res += rangesum(k + k + 1, mid + 1, r, ql, qr), res %= mod;
    return res;
}
void solve() {
    sieve(110);
    for (int i = 1; i <= 100; i ++) {
        int x = i;
        for (auto p : primes) {
            if (i < p) continue;
            while (x % p == 0) {
                x /= p;
                fac[i].emplace_back(p);
            }
        }
        if (x > 1) fac[i].emplace_back(x);
    }

    int n, m;
    cin >> n >> m;
    for (int i = 1; i <= n; i ++) {
        cin >> a[i];
    }
    build(1, 1, n);

    while (m --) {
        int op;
        cin >> op;
        if (op) {
            int l, r;
            cin >> l >> r;
            cout << rangesum(1, 1, n, l, r) << "\n";
        } else {
            int l, r, w;
            cin >> l >> r >> w;
            for (auto x : fac[w]) {
                rangemodify(1, 1, n, l, r, x);
            }
        }   
    }
}
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int t;
    t = 1;
    // std::cin >> t;

    while (t--) {
        solve();
    }

    return 0;
}