题解 Luogu P4248 [AHOI2013]差异
这是一个 SAM 做法。
显然只要求 \(\sum\limits_{1\le i < j \le n}\operatorname{lcp}(i,j)\)。
考虑 \(T_i, T_j\) 在 SAM 上的两条链。显然,这两条链可以被表示为 \(1 \rightarrow P, P \rightarrow x, P \rightarrow y\) 的形式,其中 \(x,y\) 为 \(T_i, T_j\) 在 SAM 上的终止节点。
枚举 \(P\),算 \(P\) 对答案的贡献。分两步讨论。
-
\(1 \rightarrow P\),这部分数量就是 \(1 \rightarrow P\) 的路径数量,可以一遍拓扑排序求出。
-
\(P \rightarrow x, y\),令 \(cnt\) 为 \(P\) 到终止节点的方案数,显然这一部分是 \(C_{cnt}^2\),\(cnt\) 也可拓扑求出。把两者相乘就是 \(P\) 的贡献。
时间复杂度 \(O(|\sum|n)\),瓶颈在构造 SAM。
#include <bits/stdc++.h>
using namespace std;
void file(){
freopen("1.in", "r", stdin);
freopen("1.out", "w", stdout);
}
using ll = long long;
using ull = unsigned long long;
const int nMax = 5e5 + 5, nCnt = (nMax << 1), kMax = 26;
struct SAM{
int lst, cnt;
array<int, nCnt> len, fa, in, d;
array<ll, nCnt> dp, ct;
array<array<int, kMax>, nCnt> nxt;
array<vector<int>, nCnt> G;
SAM(){
lst = cnt = 1;
dp.fill(0);
ct.fill(0);
len.fill(0);
fa.fill(0);
in.fill(0);
d.fill(0);
for(auto &k : nxt) k.fill(0);
}
void insert(char c){
int k = c - 'a', cur = ++cnt, p = lst, q;
len[cur] = len[p] + 1;
lst = cur;
while(p && (!nxt[p][k])){
nxt[p][k] = cur;
p = fa[p];
}
if(!p){
fa[cur] = 1;
return;
}
q = nxt[p][k];
if(len[q] == len[p] + 1){
fa[cur] = q;
return;
}
int t = ++cnt;
len[t] = len[p] + 1;
nxt[t] = nxt[q];
fa[t] = fa[q];
fa[q] = fa[cur] = t;
while(p && (nxt[p][k] == q)){
nxt[p][k] = t;
p = fa[p];
}
}
void getg(){
for(int i = lst; i; i = fa[i]) dp[i] = 1;
ct[1] = 1;
for(int i = 1; i <= cnt; i++)
for(int to : nxt[i])
if(to){
G[to].emplace_back(i);
in[i]++;
d[to]++;
}
}
void topo(){
getg();
queue<int> q;
for(int i = 1; i <= cnt; i++)
if(!in[i]) q.emplace(i);
while(q.size()){
int cur = q.front();
q.pop();
for(int to : G[cur]){
dp[to] += dp[cur];
if(!(--in[to]))
q.emplace(to);
}
}
for(int i = 1; i <= cnt; i++)
if(!d[i]) q.emplace(i);
while(q.size()){
int cur = q.front();
q.pop();
for(int to : nxt[cur]){
ct[to] += ct[cur];
if(!(--d[to]))
q.emplace(to);
}
}
}
}T;
int main(){
// file();
ios::sync_with_stdio(0);
string s;
cin >> s;
int n = s.size();
for(char c : s) T.insert(c);
T.topo();
ll res = 0;
for(int i = 1; i <= n; i++)
res += 1ll * (n - i) * i + 1ll * (n + i + 1) * (n - i) / 2;
for(int i = 2; i <= T.cnt; i++)
res -= T.dp[i] * (T.dp[i] - 1) * T.ct[i];
cout << res << endl;
return 0;
}