PNR#7 T2 排列计数 题解
记 \(a\) 重排得到 \(b\)。
首先把原序列重排是不影响答案的,于是我们把原序列划分为若干极长的公差为 \(k\) 的等差子序列,又发现我们其实只关心每段子序列的长度,所以设子序列有 \(m\) 段,第 \(i\) 段的长度为 \(v_i\)。
考虑容斥。钦定重排之后有 \(t\) 个位置 \(i\) 满足 \(|b_i-b_{i-1}|=k\)。我们把满足这样的位置全部缩成一段,于是序列 \(b\) 剩下了 \((n-t)\) 段。设 \(f_{i,j}\) 表示考虑到前 \(i\) 个等差序列,当前的 \(b\) 序列里有 \(j\) 段的方案数,这里段之间是有序的,最后乘上 \(j!\) 即可。
对于 \(f\) 的转移,枚举把第 \(i\) 个等差序列分成 \(d\) 段,但是注意到这里长度 \(\ge 2\) 的段可以 reverse,所以还需要预处理一个 \(g_{i,j,1/2}\) 表示把长为 \(i\) 的等差序列划分为 \(j\) 段,且第 \(j\) 段的长度 \(=1/\ge 2\)的方案数,转移是 \(g_{i,j,1}=g_{i-1,j-1,1}+g_{i-1,j-1,2},g_{i,j,2}=2g_{i-1,j,1}+g_{i-1,j,2}\)。\(f\) 有转移 \(f_{i,j}=f_{i-1,j-d}\times (g_{v_i,d,1}+g_{v_i,d,2})\)。
最后答案为 \(\sum_{j=1}^{n}(-1)^{n-j}f_{m,j}j!\)。
#include<bits/stdc++.h>
bool Mst;
#define rep(x,qwq,qaq) for(int x=(qwq);x<=(qaq);++x)
#define per(x,qwq,qaq) for(int x=(qwq);x>=(qaq);--x)
using namespace std;
template <int MOD>
struct modint {
int val;
static int norm(const int& x) {
return x < 0 ? x + MOD : x;
}
static constexpr int get_mod() {
return MOD;
}
modint() : val(0) {}
modint(const int& m) : val(norm(m)) {}
modint(const long long& m) : val(norm(m % MOD)) {}
modint operator-() const {
return modint(norm(-val));
}
bool operator==(const modint& o) {
return val == o.val;
}
bool operator<(const modint& o) {
return val < o.val;
}
modint& operator+=(const modint& o) {
return val = (1ll * val + o.val) % MOD, *this;
}
modint& operator-=(const modint& o) {
return val = norm(1ll * val - o.val), *this;
}
modint& operator*=(const modint& o) {
return val = static_cast<int>(1ll * val * o.val % MOD), *this;
}
modint& operator/=(const modint& o) {
return *this *= o.inv();
}
modint& operator^=(const modint& o) {
return val ^= o.val, *this;
}
modint& operator>>=(const modint& o) {
return val >>= o.val, *this;
}
modint& operator<<=(const modint& o) {
return val <<= o.val, *this;
}
modint operator-(const modint& o) const {
return modint(*this) -= o;
}
modint operator+(const modint& o) const {
return modint(*this) += o;
}
modint operator*(const modint& o) const {
return modint(*this) *= o;
}
modint operator/(const modint& o) const {
return modint(*this) /= o;
}
modint operator^(const modint& o) const {
return modint(*this) ^= o;
}
bool operator!=(const modint& o) {
return val != o.val;
}
modint operator>>(const modint& o) const {
return modint(*this) >>= o;
}
modint operator<<(const modint& o) const {
return modint(*this) <<= o;
}
friend std::istream& operator>>(std::istream& is, modint& a) {
long long v;
return is >> v, a.val = norm(v % MOD), is;
}
friend std::ostream& operator<<(std::ostream& os, const modint& a) {
return os << a.val;
}
friend std::string tostring(const modint& a) {
return std::to_string(a.val);
}
template <typename T>
friend modint qpow(const modint a, const T& b) {
assert(b >= 0);
modint x = a, res = 1;
for (T p = b; p; x *= x, p >>= 1)
if (p & 1) res *= x;
return res;
}
modint inv() const {
return qpow(*this,MOD-2);
}
};
using M998 = modint<998244353>;
using mint = M998;
#define maxn 5010
#define mod 998244353
template<typename Tp>
mint qp(mint x,Tp y) {
assert(y>=0);
mint res=1;
while(y) {
if(y&1)res=res*x;
x=x*x;
y>>=1;
}
return res;
}
mint inv(mint x) {
return qp(x,mod-2);
}
struct Combinatorics {
#define Lim 2000000
mint fac[Lim+10],invfac[Lim+10];
Combinatorics() {
fac[0]=invfac[0]=1;
rep(i,1,Lim)fac[i]=fac[i-1]*i;
invfac[Lim]=inv(fac[Lim]);
per(i,Lim-1,1)invfac[i]=invfac[i+1]*(i+1);
}
mint C(int n,int m){
if(n<m||n<0||m<0)return 0;
return fac[n]*invfac[m]*invfac[n-m];
}
mint A(int n,int m){
if(n<m||n<0||m<0)return 0;
return fac[n]*invfac[n-m];
}
} comb;
int n,m,k;
int a[maxn];
int fa[maxn];
mint f[maxn][maxn];
mint g[maxn][maxn][3];//1 -> 1 2-> >=2
int V[maxn];
bool Med;
signed main() {
cerr<<(&Mst-&Med)/1024.0/1024.0<<" MB\n";
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n>>k;
g[1][1][1]=1;
rep(i,2,n){
rep(j,1,n){
g[i][j][1]=g[i-1][j-1][2]+g[i-1][j-1][1];
g[i][j][2]=g[i-1][j][1]*2+g[i-1][j][2];
}
}
rep(i,1,n)cin>>a[i];
map<int,int>mp;
int tot=0;
rep(i,1,n) {
mp[a[i]]=i;
int p=mp[a[i]-k];
if(p)fa[i]=fa[mp[a[i]-k]];
else fa[i]=++tot;
++V[fa[i]];
}
int P=0;
f[0][0]=1;
rep(i,1,tot) {
int s=V[i];
rep(j,0,P) { //f[i-1][j]->f[i][j+d]
rep(d,1,s) { //枚举分成 d 个块
f[i][j+d]+=f[i-1][j]*(g[s][d][1]+g[s][d][2]);
}
}
P+=s;
}
mint ans=0;
rep(i,1,n)ans+=f[tot][i]*comb.fac[i]*((n-i)&1?-1:1);
cout<<ans<<'\n';
return 0;
}