abc370E Avoid K Partition

chenfy27的刷题记录 / 2024-10-07 / 原文

有长度为N的数组A[i]和整数K,需要将A划分成连续子数组,要求每个子数组之和不能为K。问有多少种方案,答案对998244353取模。

分析:如果不考虑和不为K的限制,就是个O(n^2)的dp,通过前缀和可以优化成O(n)。现要求子数组和不为K,可以用容斥思想先全部加上,然后减去不符合条件的部分。对于A[i],考虑j属于[0,i-1],该区间所有dp[j]之和记为sum,区间[j+1,i]之和不是K用前缀和来表示就是pre[i]-pre[j]!=K,移项得pre[j]!=pre[i]-K,可以用一个map来维护前缀和对应的dp和,便于容斥时相减。

#include <bits/stdc++.h>
using i64 = long long;

template<int MOD>
struct MInt {
    i64 x;
    int norm(i64 u) const {u%=MOD; if(u<0) u+=MOD; return u;}
    MInt(i64 v=0):x(norm(v)) {}
    int val() const {return x;}
    MInt operator-() const {return MInt(norm(MOD-x));}
    MInt inv() const {assert(x!=0); return power(MOD-2);}
    MInt &operator*=(const MInt &o) {x=norm(x*o.x); return *this;}
    MInt &operator+=(const MInt &o) {x=norm(x+o.x); return *this;}
    MInt &operator-=(const MInt &o) {x=norm(x-o.x); return *this;}
    MInt &operator/=(const MInt &o) {*this *= o.inv(); return *this;}
    friend MInt operator*(const MInt &a, const MInt &b) {MInt ans=a; ans*=b; return ans;}
    friend MInt operator+(const MInt &a, const MInt &b) {MInt ans=a; ans+=b; return ans;}
    friend MInt operator-(const MInt &a, const MInt &b) {MInt ans=a; ans-=b; return ans;}
    friend MInt operator/(const MInt &a, const MInt &b) {MInt ans=a; ans/=b; return ans;}
    friend std::istream &operator>>(std::istream &is, MInt &a) {i64 u; is>>u; a=MInt(u); return is;}
    friend std::ostream &operator<<(std::ostream &os, const MInt &a) {os<<a.val(); return os;}
    MInt power(i64 b) const {i64 r=1, t=x; while(b){if(b&1) r=r*t%MOD; t=t*t%MOD; b/=2;} return MInt(r);}
};
using mint = MInt<998244353>;

void solve() {
	i64 N, K;
	std::cin >> N >> K;
	std::vector<i64> A(N + 1), B(N + 1);
	for (int i = 1; i <= N; i++) {
		std::cin >> A[i];
	}
	std::partial_sum(A.begin(), A.end(), B.begin());

	mint sum = 0;
	std::map<i64,mint> cnt;
	std::vector<mint> dp(N + 1);
	dp[0] = cnt[0] = sum = 1;
	for (int i = 1; i <= N; i++) {
		dp[i] = sum - cnt[B[i] - K];
		sum += dp[i];
		cnt[B[i]] += dp[i];
	}
	std::cout << dp[N] << "\n";
}

int main() {
	std::cin.tie(0)->sync_with_stdio(0);
	int t = 1;
	while (t--) solve();
	return 0;
}