浅谈斜率优化

untitled0 / 2023-07-06 / 原文

如果一个 DP 的转移方程可以写成 \(f_i=\underset{j<i}{\min\!/\!\max}\>\{f_j+a_i\times b_j+c_i+d_j\}+C\) 的形式,那么可以运用斜率优化。

不妨设转移是 \(\min\),忽略那个常数 \(C\),设 \(g_{i,j}=f_j+a_i\times b_j+c_i+d_j\),即 \(f_i=\min\limits_{j<i}g_{i,j}\),式子可以化为 \(f_j+d_j=-a_i\times b_j+g_{i,j}-c_i\),设 \(y_j=f_j+d_j\)\(k=-a_i\)\(x_j=b_j\)\(t_j=g_{i,j}-c_i\),原式化为 \(y_j=kx_j+t_j\quad(*)\),这是一个一次函数的形式。

假设 \(f_i\) 是由 \(p\) 转移来的,即 \(f_i=g_{i,p}=\min_{j<i}g_{i,j}\),因为 \(t_j=g_{i,j}-c_i\),所以 \(t_p=\min_{j<i}t_j\)。 注意到 \((*)\) 式中 \(k\) 是一个定值,这说明,如果过每个点 \((x_j,y_j)\) 画斜率为 \(k\) 的直线 \(l_j\),则 \(l_p\)\(y\) 轴的截距是最小的,直观地说就是“在最下面”的。

(如图,假设有这些点,我们要画一条斜率为 \(-1\) 的直线(\(k=-1\)),则图中那条是最优的,其 \(y\) 轴截距是 \(5\),最小)

现在考虑如何快速找到这条最优直线:维护这些点的下凸壳,则与这个凸壳相切的直线是最优的。(这里“相切”的定义是,有且仅有一个交点 或 有无数个交点)

(如图,两种相切)

维护这个东西需要动态凸包,但是一般情况下并不需要:

  • 如果 \(x\) 单调,\(k\) 也单调,则决策点 \(p\) 只会单向移动,单调队列维护即可。
  • 如果 \(x\) 单调,用单调栈维护凸壳,然后二分即可。
  • 否则需要动态凸包 / 李超线段树。(我不会)

例题

  1. P5785 [SDOI2012] 任务安排

其实我觉得这个 \(n^2\) DP 挺难想到的。。。

\[f_i=\min_{j<i}\{f_j+t_i\times(c_i-c_j)+s\times(c_n-c_j)\} \]

其中 \(t\)\(c\) 是原题中 \(T\)\(C\) 的前缀和。提前计算了启动机器的代价。

注意不要用 double 表示斜率,容易因为精度 WA,把斜率式子化成乘法形式。

#include<bits/stdc++.h>
#define endl '\n'
#define rep(i, s, e) for(int i = s, i##E = e; i <= i##E; ++i)
#define per(i, s, e) for(int i = s, i##E = e; i >= i##E; --i)
#define F first
#define S second
#define int ll
#define gmin(x, y) (x = min(x, y))
#define gmax(x, y) (x = max(x, y))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double f128;
typedef pair<int, int> pii;
constexpr int N = 3e5 + 5;
int n, s, c[N], t[N], f[N];
int stk[N], tp;
// f[i] = min { f[j] + t[i] * (c[i] - c[j]) + s * (c[n] - c[j]) }
// f[i] = min { f[j] + t[i] * c[i] - t[i] * c[j] + s * c[n] - s * c[j] }
// f[j] - s * c[j] = t[i] * c[j] + f[i] - t[i] * c[i]
// 下凸,斜率单增
inline int Y(int i) { return f[i] - s * c[i]; }
inline int X(int i) { return c[i]; }
inline int K(int i) { return t[i]; }
inline int find(int k) {
    int l = 1, r = tp;
    while(l < r) {
        int mid = (l + r) / 2;
        if(Y(stk[mid]) - Y(stk[mid + 1]) <= k * (X(stk[mid]) - X(stk[mid + 1])))
            r = mid;
        else l = mid + 1;
    }
    return stk[l];
}
signed main() {
#ifdef ONLINE_JUDGE
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);
#endif
    cin >> n >> s;
    rep(i, 1, n) {
        cin >> t[i] >> c[i];
        t[i] += t[i - 1], c[i] += c[i - 1];
    }
    stk[++tp] = 0;
    rep(i, 1, n) {
        int p = find(K(i));
        f[i] = f[p] + t[i] * (c[i] - c[p]) + s * (c[n] - c[p]);
        while(tp > 1 && (Y(stk[tp]) - Y(stk[tp - 1])) * (X(i) - X(stk[tp])) >=
            (Y(i) - Y(stk[tp])) * (X(stk[tp]) - X(stk[tp - 1])))
            --tp;
        stk[++tp] = i;
    }
    cout << f[n] << endl;
    return 0;
}