【题解】P3401 洛谷树

Starrykiller / 2024-01-22 / 原文

先考虑子问题:给定序列 \(\left\{a_i\right\}\),多次询问给定 \([l,r]\),求

\[\sum_{l\leq p\leq q\leq r} \bigoplus _{i=p}^q a_i \]

其中 \(\oplus\) 表示按位异或运算。

考虑拆位,这样就只需要考虑 \(\texttt{01}\) 串的问题了。

考虑用线段树维护,具体地,我们在一个 node 上维护以下信息。

以下设这个节点管辖区间为 \([l,r]\)

  • \(val\)\(\sum_{l\leq p\leq q\leq r} \bigoplus _{i=p}^q a_i\)
  • \(ed[0/1]\):有多少个以 \(r\) 结尾的区间的 \(\operatorname{xor}\) 和为 \(0/1\)。形式化地说,就是 \(\sum_{l\leq p\leq r}[\left\{\bigoplus_{i=p}^{r} a_i\right\}=0/1]\)
  • \(st[0/1]\):有多少个以 \(l\) 开头的区间的 \(\operatorname{xor}\) 和为 \(0/1\)。形式化地说,就是 \(\sum_{l\leq q\leq r}[\left\{\bigoplus_{i=l}^{q} a_i\right\}=0/1]\)
  • \(sum\)\([l,r]\) 的异或和。形式化地说,就是 \(\bigoplus_{i=l}^r a_i\)

信息合并:

  • \(val=val_l+val_r+ed_l[1]\cdot st_r[0]+ed_l[0]\cdot st_r[1]\)
  • \(ed[0]=ed_r[0]+ed_l[sum_r]\)
  • \(ed[1]=ed_r[1]+ed_l[1\operatorname{xor}sum_r]\)
  • \(st[0]=st_l[0]+st_r[sum_l]\)
  • \(st[1]=st_l[1]+st_r[1\operatorname{xor}sum_l]\)
  • \(sum=sum_l\operatorname{xor}sum_r\)

正确性显然。

于是对于树链的情况我们也采用类似的做法即可。边权转点权是平凡的。

需要注意的是,\(\mathrm{merge}\) 操作不满足交换律。所以需要注意下最后拆出来的两段(\(u\to \operatorname{lca}(u,v)\)\(\operatorname{lca}(u,v)\to v\))(钦定右边不含左端点)应该如何合并。

然后就做完了。时间复杂度 \(\mathcal{\Theta}(n\log^2 n\log V)\),其中 \(\log V=10\)。(所以出题人才会 \(n\) 开到 \(3\times 10^4\)。)

现在是 1:01。我看我什么时候写完。

现在是 1:29。交上去 WA \(\texttt{30pts}\) 了。怎么回事呢。

现在是 1:34。AC \(\texttt{100pts}\)。原因是向上跳的时候信息合并想假了。

先睡了,早睡早起好身体。

#include <bits/stdc++.h>

using namespace std;

#define int long long 
struct infonode_base {
    mutable int val, st[2], ed[2], sum;
    infonode_base() { val=sum=st[0]=st[1]=ed[0]=ed[1]=0; }
    infonode_base(int v) {
        *this=infonode_base();
        st[val=sum=v]=ed[v]=1;
    }
    void reverse() { swap(st,ed); }
};

infonode_base operator+ (const infonode_base& a, const infonode_base& b) {
    infonode_base c;
    c.val=a.val+b.val+a.ed[1]*b.st[0]+
        a.ed[0]*b.st[1];
    c.ed[0]=b.ed[0]+a.ed[b.sum];
    c.ed[1]=b.ed[1]+a.ed[1^b.sum];
    c.st[0]=a.st[0]+b.st[a.sum];
    c.st[1]=a.st[1]+b.st[1^a.sum];
    c.sum=a.sum^b.sum;
    return c;
}

struct infonode {
    mutable infonode_base v[11];
    mutable int val;
    infonode() { val=0; }
    infonode(int x) {
        val=0;
        for (int i=0; i<=10; ++i)
            v[i]=infonode_base((x>>i)&1),
            val+=v[i].val*(1<<i);
    }
    void reverse() {
        for (int i=0; i<=10; ++i) v[i].reverse();
    }
};
infonode operator+(const infonode& a, const infonode& b) {
    infonode c;
    for (int i=0; i<=10; ++i)
        c.val+=(c.v[i]=a.v[i]+b.v[i]).val*(1<<i);
    return c;
}

constexpr int MAXN=3e4+10;

struct SG {
    #define ll(p) tr[p].l
    #define rr(p) tr[p].r
    #define ls(p) (p<<1)
    #define rs(p) (p<<1|1)
    #define val(p) tr[p].v
    struct {
        int l, r;
        infonode v;
    } tr[MAXN<<2];
    void pushup(int p) {
        val(p)=val(ls(p))+val(rs(p));
    }
    void build(int l, int r, int a[], int p=1) {
        ll(p)=l, rr(p)=r;
        if (l==r) {
            val(p)=infonode(a[l]);
            return;
        }
        int mid=(l+r)>>1;
        build(l,mid,a,ls(p)); build(mid+1,r,a,rs(p));
        pushup(p);
    }
    infonode query(int l, int r, int p=1) {
        int cl=ll(p), cr=rr(p);
        if (l<=cl&&cr<=r) return val(p);
        int mid=(cl+cr)>>1; infonode ans;
        if (l<=mid) ans=query(l,r,ls(p));
        if (r>mid) ans=ans+query(l,r,rs(p));
        return ans;
    }
    void change(int pos, int v, int p=1) {
        int cl=ll(p), cr=rr(p);
        if (cl==cr) return val(p)=infonode(v), void();
        int mid=(cl+cr)>>1;
        if (pos<=mid) change(pos,v,ls(p));
        else change(pos,v,rs(p));
        pushup(p);
    }
} sg;

int n,q;

struct {
    int nxt, to, w;
} e[MAXN<<1]; int head[MAXN], tot;
void add(int u, int v, int w) {
    e[++tot]={head[u],v,w}; head[u]=tot;
}

namespace hld {
    int dfn[MAXN], rnk[MAXN], dep[MAXN], fa[MAXN], top[MAXN], dnum,
        siz[MAXN], hson[MAXN], hw[MAXN];
    int b[MAXN];
    
    void dfs1(int u, int d) {
        dep[u]=d; siz[u]=1;
        for (int i=head[u]; i; i=e[i].nxt) {
            int v=e[i].to, w=e[i].w;
            if (dep[v]) continue;
            fa[v]=u; dfs1(v,d+1);
            siz[u]+=siz[v];
            if (siz[v]>siz[hson[u]]) hson[u]=v, hw[u]=w;
        }
    }

    void dfs2(int u, int t) {
        dfn[++dnum]=u; rnk[u]=dnum; top[u]=t;
        if (!hson[u]) return;
        dfs2(hson[u],t); b[rnk[hson[u]]]=hw[u];
        for (int i=head[u]; i; i=e[i].nxt) {
            int v=e[i].to, w=e[i].w;
            if (rnk[v]) continue;
            dfs2(v,v);
            b[rnk[v]]=w;
        }
    }

    void init() {
        dfs1(1,1); dfs2(1,1);
        sg.build(1,n,b);
    }

    void change(int u, int v, int x) {
        if (dep[u]<dep[v]) swap(u,v);
        sg.change(rnk[u],x);
    }

    int query(int u, int v) {
        infonode uu, vv;
        while (top[u]!=top[v]) {
            if (dep[top[u]]>dep[top[v]]) {
                auto cur=sg.query(rnk[top[u]],rnk[u]);
                uu=cur+uu; u=fa[top[u]];
            }        
            else {
                auto cur=sg.query(rnk[top[v]],rnk[v]);
                vv=cur+vv; v=fa[top[v]];
            }
        }
        if (u!=v) {
            if (dep[u]<dep[v]) {
                auto cur=sg.query(rnk[u]+1,rnk[v]);
                vv=cur+vv;
            }
            else {
                auto cur=sg.query(rnk[v]+1,rnk[u]);
                uu=cur+uu;
            }
        }
        uu.reverse();
        auto res=uu+vv;
        return res.val;         
    }
};
using hld::query; using hld::change;

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr); cout.tie(nullptr);
    cin>>n>>q;
    for (int i=1,u,v,w; i<n; ++i) {
        cin>>u>>v>>w;
        add(u,v,w), add(v,u,w);
    }
    hld::init();
    int op,u,v,x;
    while (q--) {
        cin>>op>>u>>v;
        if (op==2) {
            cin>>x; 
            change(u,v,x);
        }
        else cout<<query(u,v)<<'\n';
    }
}