PNR#7 T2 排列计数 题解

LazyBreeze / 2024-11-09 / 原文

\(a\) 重排得到 \(b\)

首先把原序列重排是不影响答案的,于是我们把原序列划分为若干极长的公差为 \(k\) 的等差子序列,又发现我们其实只关心每段子序列的长度,所以设子序列有 \(m\) 段,第 \(i\) 段的长度为 \(v_i\)

考虑容斥。钦定重排之后有 \(t\) 个位置 \(i\) 满足 \(|b_i-b_{i-1}|=k\)。我们把满足这样的位置全部缩成一段,于是序列 \(b\) 剩下了 \((n-t)\) 段。设 \(f_{i,j}\) 表示考虑到前 \(i\) 个等差序列,当前的 \(b\) 序列里有 \(j\) 段的方案数,这里段之间是有序的,最后乘上 \(j!\) 即可。

对于 \(f\) 的转移,枚举把第 \(i\) 个等差序列分成 \(d\) 段,但是注意到这里长度 \(\ge 2\) 的段可以 reverse,所以还需要预处理一个 \(g_{i,j,1/2}\) 表示把长为 \(i\) 的等差序列划分为 \(j\) 段,且第 \(j\) 段的长度 \(=1/\ge 2\)的方案数,转移是 \(g_{i,j,1}=g_{i-1,j-1,1}+g_{i-1,j-1,2},g_{i,j,2}=2g_{i-1,j,1}+g_{i-1,j,2}\)\(f\) 有转移 \(f_{i,j}=f_{i-1,j-d}\times (g_{v_i,d,1}+g_{v_i,d,2})\)

最后答案为 \(\sum_{j=1}^{n}(-1)^{n-j}f_{m,j}j!\)

#include<bits/stdc++.h>
bool Mst;
#define rep(x,qwq,qaq) for(int x=(qwq);x<=(qaq);++x)
#define per(x,qwq,qaq) for(int x=(qwq);x>=(qaq);--x)
using namespace std;
template <int MOD>
struct modint {
	int val;
	static int norm(const int& x) {
		return x < 0 ? x + MOD : x;
	}
	static constexpr int get_mod() {
		return MOD;
	}
	modint() : val(0) {}
	modint(const int& m) : val(norm(m)) {}
	modint(const long long& m) : val(norm(m % MOD)) {}
	modint operator-() const {
		return modint(norm(-val));
	}
	bool operator==(const modint& o) {
		return val == o.val;
	}
	bool operator<(const modint& o) {
		return val < o.val;
	}
	modint& operator+=(const modint& o) {
		return val = (1ll * val + o.val) % MOD, *this;
	}
	modint& operator-=(const modint& o) {
		return val = norm(1ll * val - o.val), *this;
	}
	modint& operator*=(const modint& o) {
		return val = static_cast<int>(1ll * val * o.val % MOD), *this;
	}
	modint& operator/=(const modint& o) {
		return *this *= o.inv();
	}
	modint& operator^=(const modint& o) {
		return val ^= o.val, *this;
	}
	modint& operator>>=(const modint& o) {
		return val >>= o.val, *this;
	}
	modint& operator<<=(const modint& o) {
		return val <<= o.val, *this;
	}
	modint operator-(const modint& o) const {
		return modint(*this) -= o;
	}
	modint operator+(const modint& o) const {
		return modint(*this) += o;
	}
	modint operator*(const modint& o) const {
		return modint(*this) *= o;
	}
	modint operator/(const modint& o) const {
		return modint(*this) /= o;
	}
	modint operator^(const modint& o) const {
		return modint(*this) ^= o;
	}
	bool operator!=(const modint& o) {
		return val != o.val;
	}
	modint operator>>(const modint& o) const {
		return modint(*this) >>= o;
	}
	modint operator<<(const modint& o) const {
		return modint(*this) <<= o;
	}
	friend std::istream& operator>>(std::istream& is, modint& a) {
		long long v;
		return is >> v, a.val = norm(v % MOD), is;
	}
	friend std::ostream& operator<<(std::ostream& os, const modint& a) {
		return os << a.val;
	}
	friend std::string tostring(const modint& a) {
		return std::to_string(a.val);
	}
	template <typename T>
	friend modint qpow(const modint a, const T& b) {
		assert(b >= 0);
		modint x = a, res = 1;
		for (T p = b; p; x *= x, p >>= 1)
			if (p & 1) res *= x;
		return res;
	}
	modint inv() const {
		return qpow(*this,MOD-2);
	}
};
using M998 = modint<998244353>;

using mint = M998;
#define maxn 5010
#define mod 998244353
template<typename Tp>
mint qp(mint x,Tp y) {
	assert(y>=0);
	mint res=1;
	while(y) {
		if(y&1)res=res*x;
		x=x*x;
		y>>=1;
	}
	return res;
}
mint inv(mint x) {
	return qp(x,mod-2);
}

struct Combinatorics {
#define Lim 2000000
	mint fac[Lim+10],invfac[Lim+10];
	Combinatorics() {
		fac[0]=invfac[0]=1;
		rep(i,1,Lim)fac[i]=fac[i-1]*i;
		invfac[Lim]=inv(fac[Lim]);
		per(i,Lim-1,1)invfac[i]=invfac[i+1]*(i+1);
	}
	mint C(int n,int m){
		if(n<m||n<0||m<0)return 0;
		return fac[n]*invfac[m]*invfac[n-m];
	}
	mint A(int n,int m){
		if(n<m||n<0||m<0)return 0;
		return fac[n]*invfac[n-m];
	}
} comb;
int n,m,k;
int a[maxn];
int fa[maxn];
mint f[maxn][maxn];
mint g[maxn][maxn][3];//1 -> 1 2-> >=2 
int V[maxn];
bool Med;
signed main() {
	cerr<<(&Mst-&Med)/1024.0/1024.0<<" MB\n";
	ios::sync_with_stdio(false);
	cin.tie(0),cout.tie(0);
	cin>>n>>k;
	g[1][1][1]=1;
	rep(i,2,n){
		rep(j,1,n){
			g[i][j][1]=g[i-1][j-1][2]+g[i-1][j-1][1];
			g[i][j][2]=g[i-1][j][1]*2+g[i-1][j][2];
		}
	}
	rep(i,1,n)cin>>a[i];
	map<int,int>mp;
	int tot=0;
	rep(i,1,n) {
		mp[a[i]]=i;
		int p=mp[a[i]-k];
		if(p)fa[i]=fa[mp[a[i]-k]];
		else fa[i]=++tot;
		++V[fa[i]];
	}
	int P=0;
	f[0][0]=1;
	rep(i,1,tot) {
		int s=V[i];
		rep(j,0,P) { //f[i-1][j]->f[i][j+d]
			rep(d,1,s) { //枚举分成 d 个块
			    f[i][j+d]+=f[i-1][j]*(g[s][d][1]+g[s][d][2]);
			}
	    }
	    P+=s;
	}
	mint ans=0;
	rep(i,1,n)ans+=f[tot][i]*comb.fac[i]*((n-i)&1?-1:1);
	cout<<ans<<'\n';
	return 0;
}