Solution - Codeforces 1628D2 Game on Sum (Hard Version)

rizynvu / 2024-10-19 / 原文

首先来考虑 Easy

注意到的是最后输出的答案要求是模意义下的,这说明对于实数二分的做法都已经用不了了。
注意到 \(n, m\le 3000\) 的数据范围,于是一个想法就是考虑 DP 之类的算法。

考虑到 B 选了 \(+ / -\) 实际上就代表着下一轮的 \(m\) 是否会 \(-1\),于是可以设状态为 \(f_{i, j}\) 表示还有 \(i\) 轮,还有 \(j\)\(+\) 要选的最优值。

那么边界值容易知道是 \(f_{i, 0} = 0, f_{i, i} = ik\),即如果只剩 \(-\) 那么肯定是 A 全选 \(0\),如果全剩 \(+\) 那么肯定是 A 全选 \(k\)
那么就来考虑 \(f_{i, j}(j\in (0, i))\) 的转移,考虑到 A 选了一个 \(x\) 后,B 会选 \(+ / -\),因为 B 想最小,于是 B 选出来的一定是 \(\min\{f_{i - 1, j - 1} + x, f_{i - 1, j} - x\}\)
因为 A 要最大化,相当于是 \(f_{i, j} = \max\{\min\{f_{i - 1, j - 1} + x, f_{i - 1, j} - x\}\}(x\in [0, k])\)
于是从函数图象分析,非常显然的是肯定是取 \(x = \frac{f_{i - 1, j} - f_{i - 1, j - 1}}{2}\) 最优,此时有 \(f_{i, j} = \frac{f_{i - 1, j - 1} + f_{i - 1, j}}{2}\)

其实上面还漏了一步,为什么这个 \(x\) 一定能被选出来?
那么实际上就是说明 \(f_{i - 1, j} - f_{i - 1, j - 1}\le 2k\),这个还是比较显然的,因为再劣也劣不过把 \(-k\) 改为 \(+k\)\(\Delta = 2k\)

于是 Easy 就做完了,时间复杂度 \(\mathcal{O}(nm)\)

接下来考虑 Hard

那么一个想法是 DP 转移的值实际上都是有 \(f_{i, i}\) 的边界条件推来的,于是可以对于每个 \((i, i)\) 单独统计对 \((n, m)\) 的贡献。

首先考虑到的是 DP 中从 \(i\to i + 1\) 就会有 \(\frac{1}{2}\) 的系数,所以首先就有个 \(\frac{1}{2^{n - i}}\) 的系数。
其次考虑到 \((i, i)\) 剩下时候就可以任意走 \((+1, +1)\) 或者 \((+1, 0)\),但是不能走到 \((j, j)\)
但是发现不能走到 \((j, j)\) 是好算的,因这个的充要条件就是走到了 \((i + 1, i + 1)\)
于是只需要钦定第一步走到是 \((+1, 0)\) 就可以了,方案数为 \(\binom{n - i - 1}{m - i}\)

于是就在 \(\mathcal{O}(n)\) 的时间复杂度做完了。

注意特判 \(n = m\)

#include<bits/stdc++.h>
using ll = long long;
constexpr ll mod = 1e9 + 7, inv2 = mod + 1 >> 1;
inline ll qpow(ll a, ll b, ll v = 1) {
   while (b)
      b & 1 && ((v *= a) %= mod), b >>= 1, (a *= a) %= mod;
   return v;
}
const int maxn = 1e6 + 10, N = 1e6;
ll fac[maxn], ifac[maxn], pw[maxn];
inline void init() {
   for (int i = fac[0] = 1; i <= N; i++) fac[i] = fac[i - 1] * i % mod;
   ifac[N] = qpow(fac[N], mod - 2);
   for (int i = N; i; i--) ifac[i - 1] = ifac[i] * i % mod;
   for (int i = pw[0] = 1; i <= N; i++) pw[i] = pw[i - 1] * inv2 % mod;
}
inline ll binom(int n, int m) {return fac[n] * ifac[m] % mod * ifac[n - m] % mod;}
inline void solve() {
   int n, m; ll k; scanf("%d%d%lld", &n, &m, &k);
   if (n == m) return printf("%lld\n", m * k % mod), void();
   ll ans = 0;
   for (int i = 1; i <= m; i++)
      (ans += 1ll * i * k % mod * binom(n - i - 1, m - i) % mod * pw[n - i]) %= mod;
   printf("%lld\n", ans);
}
int main() {
   init();
   int T; scanf("%d", &T);
   for (int id = 1; id <= T; id++) solve();
   return 0;
}