题解 Luogu P4248 [AHOI2013]差异

cjoierzdc / 2023-06-23 / 原文

这是一个 SAM 做法。


显然只要求 \(\sum\limits_{1\le i < j \le n}\operatorname{lcp}(i,j)\)

考虑 \(T_i, T_j\) 在 SAM 上的两条链。显然,这两条链可以被表示为 \(1 \rightarrow P, P \rightarrow x, P \rightarrow y\) 的形式,其中 \(x,y\)\(T_i, T_j\) 在 SAM 上的终止节点。

枚举 \(P\),算 \(P\) 对答案的贡献。分两步讨论。

  • \(1 \rightarrow P\),这部分数量就是 \(1 \rightarrow P\) 的路径数量,可以一遍拓扑排序求出。

  • \(P \rightarrow x, y\),令 \(cnt\)\(P\) 到终止节点的方案数,显然这一部分是 \(C_{cnt}^2\)\(cnt\) 也可拓扑求出。把两者相乘就是 \(P\) 的贡献。

时间复杂度 \(O(|\sum|n)\),瓶颈在构造 SAM。

#include <bits/stdc++.h>
using namespace std;
void file(){
	freopen("1.in", "r", stdin);
	freopen("1.out", "w", stdout);
}
using ll = long long;
using ull = unsigned long long;
const int nMax = 5e5 + 5, nCnt = (nMax << 1), kMax = 26;
struct SAM{
	int lst, cnt;
	array<int, nCnt> len, fa, in, d;
	array<ll, nCnt> dp, ct;
	array<array<int, kMax>, nCnt> nxt;
	array<vector<int>, nCnt> G;
	SAM(){
		lst = cnt = 1;
		dp.fill(0);
		ct.fill(0);
		len.fill(0);
		fa.fill(0);
		in.fill(0);
		d.fill(0);
		for(auto &k : nxt) k.fill(0); 
	}
	void insert(char c){
		int k = c - 'a', cur = ++cnt, p = lst, q;
		len[cur] = len[p] + 1;
		lst = cur;
		while(p && (!nxt[p][k])){
			nxt[p][k] = cur;
			p = fa[p];
		}
		if(!p){
			fa[cur] = 1;
			return;
		}
		q = nxt[p][k];
		if(len[q] == len[p] + 1){
			fa[cur] = q;
			return;
		}
		int t = ++cnt;
		len[t] = len[p] + 1;
		nxt[t] = nxt[q];
		fa[t] = fa[q];
		fa[q] = fa[cur] = t;
		while(p && (nxt[p][k] == q)){
			nxt[p][k] = t;
			p = fa[p];
		}
	}
	void getg(){
		for(int i = lst; i; i = fa[i]) dp[i] = 1;
		ct[1] = 1;
		for(int i = 1; i <= cnt; i++)
			for(int to : nxt[i])
				if(to){
					G[to].emplace_back(i);
					in[i]++;
					d[to]++;
				}
	}
	void topo(){
		getg();
		queue<int> q;
		for(int i = 1; i <= cnt; i++)
			if(!in[i]) q.emplace(i);
		while(q.size()){
			int cur = q.front();
			q.pop();
			for(int to : G[cur]){
				dp[to] += dp[cur];
				if(!(--in[to]))
					q.emplace(to);
			}
		}
		for(int i = 1; i <= cnt; i++)
			if(!d[i]) q.emplace(i); 
		while(q.size()){
			int cur = q.front();
			q.pop();
			for(int to : nxt[cur]){
				ct[to] += ct[cur];
				if(!(--d[to]))
					q.emplace(to); 
			}
		}
	}
}T;
int main(){
//	file();
	ios::sync_with_stdio(0);
	string s;
	cin >> s;
	int n = s.size();
	for(char c : s) T.insert(c);
	T.topo();
	ll res = 0;
	for(int i = 1; i <= n; i++)
		res += 1ll * (n - i) * i + 1ll * (n + i + 1) * (n - i) / 2;
	for(int i = 2; i <= T.cnt; i++)
		res -= T.dp[i] * (T.dp[i] - 1) * T.ct[i]; 
	cout << res << endl; 
	return 0;
}