P9437 『XYGOI round1』一棵树 题解
赛时一眼换根 dp,然后调调调了大概 1h+。
题目传送门
什么是换根 dp
在大多数树形 dp 中,我们只考虑对根的贡献,而一部分题目需要算出对所有点的贡献,一个比较显然的做法是对每个点都跑一次树形 dp,但是大大增加了时间复杂度,是我们不能接受的。
树形 dp 中的换根 dp 问题又被称为二次扫描,通常不会指定根结点,并且根结点的变化会对一些值,例如子结点深度和、点权和等产生影响。
换根 dp 的常用套路钦定某一个点为根,然后从根开始进行两次 dfs,第一次记录自己子树内中的节点对自己的贡献,第二次记录不属于自己子树中的点对自己的贡献。
Solution
不妨钦定 1 为根。
考虑令 \(f_u = \sum\limits_{v \in subtree(u)} w(v, u)\),\(g_u = (\sum\limits_{v \notin subtree(u)} w(v, u)) + w_u\)。为什么要加上 \(w_u\)? 是为了方便转移(不加好像也不是很麻烦)。
首先我们要实现一个位数的函数。
LL mylog(LL v) {
if(v == 0) return 10;
int res = 0;
while(v) res ++, v /= 10;
return pow(10, res);
}
计算 \(f\) 数组
考虑从儿子 \(v\) 转移到 \(u\),显然为 \(f_v \times mylog(w_u) + siz_v \times w_u\)。
得到 \(f_u = (\sum\limits_{v \in subtree(u)} f_v) \times mylog(w_u) + w_u \times siz_u\) 。
void dfs1(int u, int fa) {
siz[u] = 1;
for(auto v : e[u])
if(v != fa) {
dfs1(v, u);
siz[u] += siz[v];
f[u] = (f[u] + f[v]) % mod;
}
f[u] = (f[u] * mylog(w[u]) % mod + w[u] * siz[u] % mod) % mod;
}
计算 \(g\) 数组
如何计算子树外的点对自己的贡献呢?
不妨先想想从根节点转移到第二层节点。
\(1\) 转移到 \(u\)。
先算出除 \(subtree(u)\) 外的点对 \(1\) 的贡献,显然为 \(f_1\) 减去 \(u\) 对 \(1\) 的贡献 \(f_u * mylog(w_1) + siz_u * w_1\)。
再考虑 \(u\) 对儿子 \(v\) 的贡献,为 \(f_u + g_u\) 减去 \(v\) 对 \(u\) 的贡献,再把这些继续用计算 \(f\) 的方法转移到 \(v\) 上去。
void dfs2(int u, int fa) {
for(auto v : e[u])
if(v != fa) {
g[v] = ((f[u] + g[u] - val(v, u) - w[u]) % mod * mylog(w[v]) % mod + (siz[1] - siz[v] + 1) % mod * w[v] % mod) % mod;
dfs2(v, u);
}
}
AC code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e6 + 5, INF = 0x3f3f3f3f;
const LL mod = 998244353;
int n, m;
LL w[N];
vector<int> e[N];
LL f[N], g[N], siz[N];
LL mylog(LL v) {
if(v == 0) return 10;
int res = 0;
while(v) res ++, v /= 10;
return pow(10, res);
}
LL val(int u, int v) {
return (f[u] * mylog(w[v]) + siz[u] * w[v]) % mod;
}
void dfs1(int u, int fa) {
siz[u] = 1;
for(auto v : e[u])
if(v != fa) {
dfs1(v, u);
siz[u] += siz[v];
f[u] = (f[u] + f[v]) % mod;
}
f[u] = (f[u] * mylog(w[u]) % mod + w[u] * siz[u] % mod) % mod;
}
void dfs2(int u, int fa) {
for(auto v : e[u])
if(v != fa) {
g[v] = ((f[u] + g[u] - val(v, u) - w[u]) % mod * mylog(w[v]) % mod + (siz[1] - siz[v] + 1) % mod * w[v] % mod) % mod;
dfs2(v, u);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n;
for(int i = 1; i <= n; i ++) cin >> w[i];
for(int i = 2; i <= n; i ++) {
int p;
cin >> p;
e[p].push_back(i), e[i].push_back(p);
}
dfs1(1, 0);
g[1] = w[1];
dfs2(1, 0);
LL res = 0;
for(int i = 1; i <= n; i ++)
res = (res + f[i] + g[i] - w[i]) % mod;
cout << res << '\n';
return 0;
}