CF1249F Maximum Weight Subset 题解 / 长链剖分复习

佚名 / 2025-01-15 / 原文

CF1249F Maximum Weight Subset 题解

题目大意

给定一个 \(n\) 个节点的树,每个节点有点权 \(a_i\)。从中选出若干节点,要求这些点两两之间距离大于给定常数 \(k\),使得点权和最大。

Solve

给出一种线性做法。前置知识:长链剖分优化 DP。

考虑一个 DP:设 \(f(u,d)\) 表示在 \(u\) 的子树里选点,选出的点距离 \(u\) 号点的最短距离为 \(d\),这种情况下的最大点权和。暴力转移是简单的:

\[f(u,\min(j,i+1))\longleftarrow f(u,j)+f(v,i) \]

总复杂度为 \(O(n^3)\)

第二维和深度(距离)有关,容易长链剖分优化到 \(O(n^2)\),即 \(u\) 号节点直接继承其重儿子的 \(f\) 值,对于轻儿子,枚举上式种的 \(i,j\) 转移。复杂度约为 \(O(n^2)\)。这部分的代码如下,用指针实现:

void dp(int u,int fa)
{
	if(son[u])	dp(son[u],u);//先遍历重儿子
	f[u][0]=a[u];
	for(int i=k+1;i<h[u];i=-~i)
		f[u][0]=max(f[u][0],f[u][i]+a[u]);
	for(int v:e[u])
		if(v!=fa&&v!=son[u])
		{
			dp(v,u);
			for(int i=0;i<h[u];i=-~i)	now[i]=f[u][i];//用上个版本的
			for(int i=0;i<h[v];i=-~i)//h[v] 为 v 节点的链长
			{
				f[u][i+1]=max(f[u][i+1],f[v][i]);
				for(int j=max(k-i,0ll);j<h[u];j=-~j)
					f[u][min(i+1,j)]=max(f[u][min(i+1,j)],now[j]+f[v][i]);
			}
		}
}

瓶颈在于对 \(u\) 所在链长的枚举。所以我们考虑对 \(\min(i+1,j)\) 讨论,因为它的最大值为 \(v\) 所在链长,均摊是 \(n\) 级别的。

  • \(j\leq i\)\(\min(i+1,j)=j\),此时 \(1\leq j\leq h(v)\),可以暴力枚举 \(j\),那么可以对这个 \(j\) 贡献的 \(i\) 需要满足 \(j\leq i<h(v),i+j+1>k\iff i\geq k-j\),这是一段后缀,所以我们可以维护 \(g(v)\)\(f(v)\) 的后缀最大值,就有转移:

\[f(u,j)\longleftarrow f(u,j)+g(v,\max(j,k-j)) \]

  • \(j>i\)\(\min(i+1,j)=i+1\),由于本身就有 \(0\leq i<h(v)\),此时我们可以枚举 \(i\),那么可以对这个 \(i\) 贡献的 \(j\) 需要满足 \(i<j<h(u),i+j+1>k\iff j\geq k-i\),有转移:

\[f(u,i+1)\longleftarrow g(u,\max(i+1,k-i))+f(v,i) \]

至此,总复杂度为 \(O(\sum链长)=O(n)\)

Code

#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read()
{
	short f=1;
	int x=0;
	char c=getchar();
	while(c<'0'||c>'9')	{if(c=='-')	f=-1;c=getchar();}
	while(c>='0'&&c<='9')	x=(x<<1)+(x<<3)+(c^48),c=getchar();
	return x*f;
}
const int N=2e5+10;
int n,k,a[N];
vector<int>e[N];
int mem[2][N]/*内存池*/,*f[N],*g[N],h[N],son[N],tim;
void get_son(int u,int fa)
{
	h[u]=1;
	for(int i:e[u])
		if(i!=fa)
			get_son(i,u),h[u]=max(h[u],h[i]+1);
	if(h[u]>h[son[fa]])	son[fa]=u;
}
void get_chain(int u,int fa)//分配内存
{
	tim=-~tim;
	f[u]=mem[0]+tim;g[u]=mem[1]+tim;
	if(son[u])	get_chain(son[u],u);
	for(int i:e[u])
		if(i!=fa&&i!=son[u])	get_chain(i,u);
}
void dp(int u,int fa)
{
	if(son[u])	dp(son[u],u);
	f[u][0]=a[u]+(k+1<h[u]?g[u][k+1]:0);
	g[u][0]=max(f[u][0],h[u]>1?g[u][1]:0);
	for(int v:e[u])
		if(v!=fa&&v!=son[u])
		{
			dp(v,u);
			for(int i=0;i<h[v];i=-~i)
				if(max(i,k-i)<h[v])
					f[u][i]=max(f[u][i],f[u][i]+g[v][max(i,k-i)]);
			for(int i=0;i<h[v];i=-~i)
			{
				f[u][i+1]=max(f[u][i+1],f[v][i]);
				if(max(i+1,k-i)<h[u])
					f[u][i+1]=max(f[u][i+1],f[v][i]+g[u][max(i+1,k-i)]);
			}
			for(int i=h[v];i>=0;i=~-i)
				g[u][i]=max(i+1<h[u]?g[u][i+1]:0,f[u][i]);
		}
}
signed main()
{
	n=read();k=read();
	for(int i=1;i<=n;i=-~i)	a[i]=read();
	for(int i=1,u,v;i<n;i=-~i)
		u=read(),v=read(),
		e[u].push_back(v),e[v].push_back(u);
	get_son(1,0);get_chain(1,0);dp(1,0);
	return printf("%lld",*max_element(f[1],f[1]+h[1])),0;
}