【题解】P3401 洛谷树
先考虑子问题:给定序列 \(\left\{a_i\right\}\),多次询问给定 \([l,r]\),求
其中 \(\oplus\) 表示按位异或运算。
考虑拆位,这样就只需要考虑 \(\texttt{01}\) 串的问题了。
考虑用线段树维护,具体地,我们在一个 node 上维护以下信息。
以下设这个节点管辖区间为 \([l,r]\)。
- \(val\):\(\sum_{l\leq p\leq q\leq r} \bigoplus _{i=p}^q a_i\)。
- \(ed[0/1]\):有多少个以 \(r\) 结尾的区间的 \(\operatorname{xor}\) 和为 \(0/1\)。形式化地说,就是 \(\sum_{l\leq p\leq r}[\left\{\bigoplus_{i=p}^{r} a_i\right\}=0/1]\)。
- \(st[0/1]\):有多少个以 \(l\) 开头的区间的 \(\operatorname{xor}\) 和为 \(0/1\)。形式化地说,就是 \(\sum_{l\leq q\leq r}[\left\{\bigoplus_{i=l}^{q} a_i\right\}=0/1]\)。
- \(sum\):\([l,r]\) 的异或和。形式化地说,就是 \(\bigoplus_{i=l}^r a_i\)。
信息合并:
- \(val=val_l+val_r+ed_l[1]\cdot st_r[0]+ed_l[0]\cdot st_r[1]\)。
- \(ed[0]=ed_r[0]+ed_l[sum_r]\)。
- \(ed[1]=ed_r[1]+ed_l[1\operatorname{xor}sum_r]\)。
- \(st[0]=st_l[0]+st_r[sum_l]\)。
- \(st[1]=st_l[1]+st_r[1\operatorname{xor}sum_l]\)。
- \(sum=sum_l\operatorname{xor}sum_r\)。
正确性显然。
于是对于树链的情况我们也采用类似的做法即可。边权转点权是平凡的。
需要注意的是,\(\mathrm{merge}\) 操作不满足交换律。所以需要注意下最后拆出来的两段(\(u\to \operatorname{lca}(u,v)\),\(\operatorname{lca}(u,v)\to v\))(钦定右边不含左端点)应该如何合并。
然后就做完了。时间复杂度 \(\mathcal{\Theta}(n\log^2 n\log V)\),其中 \(\log V=10\)。(所以出题人才会 \(n\) 开到 \(3\times 10^4\)。)
现在是 1:01。我看我什么时候写完。
现在是 1:29。交上去 WA \(\texttt{30pts}\) 了。怎么回事呢。
现在是 1:34。AC \(\texttt{100pts}\)。原因是向上跳的时候信息合并想假了。
先睡了,早睡早起好身体。
#include <bits/stdc++.h>
using namespace std;
#define int long long
struct infonode_base {
mutable int val, st[2], ed[2], sum;
infonode_base() { val=sum=st[0]=st[1]=ed[0]=ed[1]=0; }
infonode_base(int v) {
*this=infonode_base();
st[val=sum=v]=ed[v]=1;
}
void reverse() { swap(st,ed); }
};
infonode_base operator+ (const infonode_base& a, const infonode_base& b) {
infonode_base c;
c.val=a.val+b.val+a.ed[1]*b.st[0]+
a.ed[0]*b.st[1];
c.ed[0]=b.ed[0]+a.ed[b.sum];
c.ed[1]=b.ed[1]+a.ed[1^b.sum];
c.st[0]=a.st[0]+b.st[a.sum];
c.st[1]=a.st[1]+b.st[1^a.sum];
c.sum=a.sum^b.sum;
return c;
}
struct infonode {
mutable infonode_base v[11];
mutable int val;
infonode() { val=0; }
infonode(int x) {
val=0;
for (int i=0; i<=10; ++i)
v[i]=infonode_base((x>>i)&1),
val+=v[i].val*(1<<i);
}
void reverse() {
for (int i=0; i<=10; ++i) v[i].reverse();
}
};
infonode operator+(const infonode& a, const infonode& b) {
infonode c;
for (int i=0; i<=10; ++i)
c.val+=(c.v[i]=a.v[i]+b.v[i]).val*(1<<i);
return c;
}
constexpr int MAXN=3e4+10;
struct SG {
#define ll(p) tr[p].l
#define rr(p) tr[p].r
#define ls(p) (p<<1)
#define rs(p) (p<<1|1)
#define val(p) tr[p].v
struct {
int l, r;
infonode v;
} tr[MAXN<<2];
void pushup(int p) {
val(p)=val(ls(p))+val(rs(p));
}
void build(int l, int r, int a[], int p=1) {
ll(p)=l, rr(p)=r;
if (l==r) {
val(p)=infonode(a[l]);
return;
}
int mid=(l+r)>>1;
build(l,mid,a,ls(p)); build(mid+1,r,a,rs(p));
pushup(p);
}
infonode query(int l, int r, int p=1) {
int cl=ll(p), cr=rr(p);
if (l<=cl&&cr<=r) return val(p);
int mid=(cl+cr)>>1; infonode ans;
if (l<=mid) ans=query(l,r,ls(p));
if (r>mid) ans=ans+query(l,r,rs(p));
return ans;
}
void change(int pos, int v, int p=1) {
int cl=ll(p), cr=rr(p);
if (cl==cr) return val(p)=infonode(v), void();
int mid=(cl+cr)>>1;
if (pos<=mid) change(pos,v,ls(p));
else change(pos,v,rs(p));
pushup(p);
}
} sg;
int n,q;
struct {
int nxt, to, w;
} e[MAXN<<1]; int head[MAXN], tot;
void add(int u, int v, int w) {
e[++tot]={head[u],v,w}; head[u]=tot;
}
namespace hld {
int dfn[MAXN], rnk[MAXN], dep[MAXN], fa[MAXN], top[MAXN], dnum,
siz[MAXN], hson[MAXN], hw[MAXN];
int b[MAXN];
void dfs1(int u, int d) {
dep[u]=d; siz[u]=1;
for (int i=head[u]; i; i=e[i].nxt) {
int v=e[i].to, w=e[i].w;
if (dep[v]) continue;
fa[v]=u; dfs1(v,d+1);
siz[u]+=siz[v];
if (siz[v]>siz[hson[u]]) hson[u]=v, hw[u]=w;
}
}
void dfs2(int u, int t) {
dfn[++dnum]=u; rnk[u]=dnum; top[u]=t;
if (!hson[u]) return;
dfs2(hson[u],t); b[rnk[hson[u]]]=hw[u];
for (int i=head[u]; i; i=e[i].nxt) {
int v=e[i].to, w=e[i].w;
if (rnk[v]) continue;
dfs2(v,v);
b[rnk[v]]=w;
}
}
void init() {
dfs1(1,1); dfs2(1,1);
sg.build(1,n,b);
}
void change(int u, int v, int x) {
if (dep[u]<dep[v]) swap(u,v);
sg.change(rnk[u],x);
}
int query(int u, int v) {
infonode uu, vv;
while (top[u]!=top[v]) {
if (dep[top[u]]>dep[top[v]]) {
auto cur=sg.query(rnk[top[u]],rnk[u]);
uu=cur+uu; u=fa[top[u]];
}
else {
auto cur=sg.query(rnk[top[v]],rnk[v]);
vv=cur+vv; v=fa[top[v]];
}
}
if (u!=v) {
if (dep[u]<dep[v]) {
auto cur=sg.query(rnk[u]+1,rnk[v]);
vv=cur+vv;
}
else {
auto cur=sg.query(rnk[v]+1,rnk[u]);
uu=cur+uu;
}
}
uu.reverse();
auto res=uu+vv;
return res.val;
}
};
using hld::query; using hld::change;
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
cin>>n>>q;
for (int i=1,u,v,w; i<n; ++i) {
cin>>u>>v>>w;
add(u,v,w), add(v,u,w);
}
hld::init();
int op,u,v,x;
while (q--) {
cin>>op>>u>>v;
if (op==2) {
cin>>x;
change(u,v,x);
}
else cout<<query(u,v)<<'\n';
}
}