P9437 『XYGOI round1』一棵树 题解

Svemit / 2023-08-05 / 原文

赛时一眼换根 dp,然后调调调了大概 1h+。
题目传送门

什么是换根 dp

在大多数树形 dp 中,我们只考虑对根的贡献,而一部分题目需要算出对所有点的贡献,一个比较显然的做法是对每个点都跑一次树形 dp,但是大大增加了时间复杂度,是我们不能接受的。

树形 dp 中的换根 dp 问题又被称为二次扫描,通常不会指定根结点,并且根结点的变化会对一些值,例如子结点深度和、点权和等产生影响。

换根 dp 的常用套路钦定某一个点为根,然后从根开始进行两次 dfs,第一次记录自己子树内中的节点对自己的贡献,第二次记录不属于自己子树中的点对自己的贡献。

Solution

不妨钦定 1 为根。
考虑令 \(f_u = \sum\limits_{v \in subtree(u)} w(v, u)\)\(g_u = (\sum\limits_{v \notin subtree(u)} w(v, u)) + w_u\)。为什么要加上 \(w_u\)? 是为了方便转移(不加好像也不是很麻烦)。

首先我们要实现一个位数的函数。

LL mylog(LL v) {
	if(v == 0) return 10;
	int res = 0;
	while(v) res ++, v /= 10;
	return pow(10, res);
} 

计算 \(f\) 数组

考虑从儿子 \(v\) 转移到 \(u\),显然为 \(f_v \times mylog(w_u) + siz_v \times w_u\)

得到 \(f_u = (\sum\limits_{v \in subtree(u)} f_v) \times mylog(w_u) + w_u \times siz_u\)

void dfs1(int u, int fa) {
	siz[u] = 1;
	for(auto v : e[u])
		if(v != fa)	{
			dfs1(v, u);
			siz[u] += siz[v];
			f[u] = (f[u] + f[v]) % mod;
		}
	f[u] = (f[u] * mylog(w[u]) % mod + w[u] * siz[u] % mod) % mod;
}

计算 \(g\) 数组

如何计算子树外的点对自己的贡献呢?
不妨先想想从根节点转移到第二层节点。

\(1\) 转移到 \(u\)

先算出除 \(subtree(u)\) 外的点对 \(1\) 的贡献,显然为 \(f_1\) 减去 \(u\)\(1\) 的贡献 \(f_u * mylog(w_1) + siz_u * w_1\)

再考虑 \(u\) 对儿子 \(v\) 的贡献,为 \(f_u + g_u\) 减去 \(v\)\(u\) 的贡献,再把这些继续用计算 \(f\) 的方法转移到 \(v\) 上去。

void dfs2(int u, int fa) {
	for(auto v : e[u])
		if(v != fa)	{
			g[v] = ((f[u] + g[u] - val(v, u) - w[u]) % mod * mylog(w[v]) % mod + (siz[1] - siz[v] + 1) % mod * w[v] % mod) % mod;
			dfs2(v, u);
		}
}
AC code
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e6 + 5, INF = 0x3f3f3f3f;
const LL mod = 998244353;
int n, m;
LL w[N];
vector<int> e[N];
LL f[N], g[N], siz[N];
LL mylog(LL v) {
	if(v == 0) return 10;
	int res = 0;
	while(v) res ++, v /= 10;
	return pow(10, res);
} 
LL val(int u, int v) {
	return (f[u] * mylog(w[v]) + siz[u] * w[v]) % mod;
}
void dfs1(int u, int fa) {
	siz[u] = 1;
	for(auto v : e[u])
		if(v != fa)	{
			dfs1(v, u);
			siz[u] += siz[v];
			f[u] = (f[u] + f[v]) % mod;
		}
	f[u] = (f[u] * mylog(w[u]) % mod + w[u] * siz[u] % mod) % mod;
}
void dfs2(int u, int fa) {
	for(auto v : e[u])
		if(v != fa)	{
			g[v] = ((f[u] + g[u] - val(v, u) - w[u]) % mod * mylog(w[v]) % mod + (siz[1] - siz[v] + 1) % mod * w[v] % mod) % mod;
			dfs2(v, u);
		}
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    for(int i = 1; i <= n; i ++) cin >> w[i];
    for(int i = 2; i <= n; i ++) {
    	int p;
    	cin >> p;
    	e[p].push_back(i), e[i].push_back(p);
  	}
  	dfs1(1, 0);
  	g[1] = w[1];
  	dfs2(1, 0);
  	LL res = 0;
  	for(int i = 1; i <= n; i ++)
  		res = (res + f[i] + g[i] - w[i]) % mod;
  	cout << res << '\n';
    return 0;
}