洛谷 P3523 题解

lzy20091001 / 2024-10-06 / 原文

洛谷 P3523 [POI2011] DYN-Dynamite

分析

二分答案,问题转化为:对于给定的 \(K\),选择尽可能少的节点,使得所有关键节点都被「覆盖」。

对于一个关键节点,「覆盖」的定义为:存在一个被选择的点与这个关键节点的距离不大于 \(K\)

方便起见,我们指定 \(1\) 号节点是这棵树的根节点。

我们使用树形 DP 的思路,自底而上、无后效性地求解。换句话说,显然地,当我们考虑一个节点 \(u\) 时,默认 \(u\) 的子树内已经达到最优。

对于每个节点 \(u\),我们维护两个信息:

  • \(f _ u\)\(u\) 的子树内离 \(u\) 最远的 未被覆盖的关键节点\(u\) 之间的距离。特别地,若 \(u\) 的子树内不存在 未被覆盖的关键节点,则 \(f _ u = -\infty\)

  • \(g _ u\)\(u\) 的子树内离 \(u\) 最近的 被选择的节点\(u\) 之间的距离。特别地,若 \(u\) 的子树内不存在 被选择的节点,则 \(g _ u = \infty\)

初始化 \(f _ u = - \infty, g _ u = \infty\),然后从 \(u\) 的儿子节点转移:

\[f _ u = \max _ {v \in \operatorname{son}(u)} \{ f _ v + 1 \}, \\ g _ u = \min _ {v \in \operatorname{son}(u)} \{ g _ v + 1 \}. \]

现在我们对儿子节点的信息进行汇总:

  • \(f _ u + g _ u \le K\),那么 \(u\) 的子树内所有关键节点都被覆盖了,更新 \(f _ u \gets -\infty\)
  • \(d _ u = 1\)\(u\) 本身是一个关键节点,则更新 \(f _ u \gets \max \{f _ u, 0\}\)

接下来决策要不要选择 \(u\),我们有贪心结论:\(u \ne 1\) 时,当且仅当 \(f _ u = K\) 时选择 \(u\)

略证:必要性显然;充分及最优性则是因为若 \(f _ u < K\) 我们就可以选择 \(u\) 的祖先,而 \(u\) 的祖先一定不比 \(u\) 劣,因为 \(u\) 的祖先可以覆盖更大的范围。

当然 $u = 1 $ 是个例外,根节点没有祖先,所以若 \(f _ 1 \ne -\infty\) 就一定要选 \(1\) 号节点。

如果选择了 \(u\),那么同样也要更新:

\[f _ u \gets -\infty, \\ g _ u \gets 0. \]

总结一下:在 \([0, n]\) 内二分 \(K\),每次树形 DP 求解最少需要选择几个点,若不大于 \(m\) 则判定为可行,反之亦然。

复杂度分析

二分复杂度为 \(\operatorname{O}(\log n)\),每次树形 DP 判定复杂度为 \(\operatorname{O}(n)\),总时间复杂度为 \(\operatorname{O}(n \log n)\)

实现

递归常数略大,可能需要卡常。

#include <iostream>
#include <vector>

const int N = 300'000, INF = 1e9;

int n, m, cnt;
bool d[N + 5];
int f[N + 5], g[N + 5];
std::vector<int> edge[N + 5];

void dfs(int u, int fa, int k)
{
    f[u] = -INF;
    g[u] = INF;
    for (auto v : edge[u])
        if (v != fa)
        {
            dfs(v, u, k);
            f[u] = std::max(f[u], f[v] + 1);
            g[u] = std::min(g[u], g[v] + 1);
        }
    if (f[u] + g[u] <= k)
        f[u] = -INF;
    if (d[u] && g[u] > k)
        f[u] = std::max(f[u], 0);
    if (f[u] == k)
    {
        f[u] = -INF;
        g[u] = 0;
        cnt++;
    }
}

bool check(int k)
{
    cnt = 0;
    dfs(1, 0, k);
    if (f[1] >= 0)
        cnt++;
    return cnt <= m;
}

int binary_search()
{
    int l = 0, r = n, res = 0;
    while (l <= r)
    {
        int mid = (l + r) / 2;
        if (check(mid))
        {
            r = mid - 1;
            res = mid;
        }
        else
            l = mid + 1;
    }
    return res;
}

int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    std::cin >> n >> m;
    for (int i = 1; i <= n; i++)
        std::cin >> d[i];
    for (int i = 1; i < n; i++)
    {
        int u, v;
        std::cin >> u >> v;
        edge[u].push_back(v);
        edge[v].push_back(u);
    }

    std::cout << binary_search() << "\n";

    return 0;
}