SAM/广义 SAM 非常偷懒的解释

SFlyer / 2024-03-19 / 原文

这里是模板题 P6139。

进行了一个广义 SAM 的学习。离线部分 OiWiki 讲得很好,但是在线部分没有。我做一些补充。

首先 SAM 是什么?简单来说,

  • SAM 是一个有向无环图,节点是状态,对应一个 endpos(\(S\) 子串 \(t\) 的结尾位置集合)。边标有字符。

  • SAM 是有 \(t_0\) 初始状态,若干个结尾状态,是 \(\mathcal{O}(n)\) 的,\(\le 4\) 常数。

  • 本质不同字串个数=路径条数,可以 dp。

  • 一个状态中,定义 \(len(v)\)\(t_0\)\(v\) 的最长路径长度。

  • 一个状态中,定义 \(fa(v)\)\(v\) 最长的后缀,使出现位置多余 \(v\)

  • endpos 要么无交要么包含,可以构成后缀树。同一个 endpos 对应的字符串长度是一个区间,\([len(fa(v))+1,len(v)]\)

我们当作你会 SAM 了(你去 OiWiki 看一下模板代码),广义 SAM 是什么呢?

首先看一下代码。

pos[0]=1;
for (int j=1; j<s.size(); j++){
	pos[j]=ins(s[j]-'a',pos[j-1]);
}

这个是主函数里的。因为是多个串,\(lst\) 是什么你只能是上一个的结尾,不能直接 ++tot 求得。

int ins(int c,int lst){
	// (1)
	if (t[lst].ch[c] && t[t[lst].ch[c]].len==t[lst].len+1){
		return t[lst].ch[c];
	}
	//
	int p=lst,np=++tot,fl=0;
	t[np].len=t[p].len+1;
	while (p && !t[p].ch[c]){
		t[p].ch[c]=np;
		p=t[p].fa;
	}
	if (!p){
		t[np].fa=1;
		return np;
	}
	int q=t[p].ch[c];
	if (t[q].len==t[p].len+1){
		t[np].fa=q;
		return np;
	}
	// (2)
	if (p==lst){
		fl=1;
		np=0;
		tot--;
	}
	//
	int nq=++tot;
	t[nq]=t[q];
	t[nq].len=t[p].len+1;
	t[q].fa=t[np].fa=nq;
	while (p && t[p].ch[c]==q){
		t[p].ch[c]=nq;
		p=t[p].fa;
	}
	// (3)
	return fl?nq:np;
	//
}

这个是拓展的部分。发现和普通的 SAM 相比,有 \(3\) 个不同的部分。为什么呢?

(1)处是因为如果有一个连续的转移,没必要新建了。

(2)处是因为:这个是有 \(c\) 的儿子,但是不是连续的,那么,我们只需要建新状态,不需要建 \(lst\rightarrow c\)。这个清空 \(np\) 然后 \(tot--\) 就可以了。

(3)处是因为:很容易理解,你删了 \(np\) 就得返回 \(nq\)

然后就做完了。

Code
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

const int N = 2e6+6;

struct sam {
	int fa,len,ch[26];
} t[N];

int tot=1,pos[N];

int ins(int c,int lst){
	if (t[lst].ch[c] && t[t[lst].ch[c]].len==t[lst].len+1){
		return t[lst].ch[c];
	}
	int p=lst,np=++tot,fl=0;
	t[np].len=t[p].len+1;
	while (p && !t[p].ch[c]){
		t[p].ch[c]=np;
		p=t[p].fa;
	}
	if (!p){
		t[np].fa=1;
		return np;
	}
	int q=t[p].ch[c];
	if (t[q].len==t[p].len+1){
		t[np].fa=q;
		return np;
	}
	if (p==lst){
		fl=1;
		np=0;
		tot--;
	}
	int nq=++tot;
	t[nq]=t[q];
	t[nq].len=t[p].len+1;
	t[q].fa=t[np].fa=nq;
	while (p && t[p].ch[c]==q){
		t[p].ch[c]=nq;
		p=t[p].fa;
	}
	return fl?nq:np;
}

int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);

	int n;
	cin>>n;
	for (int i=1; i<=n; i++){
		string s;
		cin>>s;
		s=" "+s;
		pos[0]=1;
		for (int j=1; j<s.size(); j++){
			pos[j]=ins(s[j]-'a',pos[j-1]);
		}
	}
	ll ans=0;
	for (int i=2; i<=tot; i++){
		ans+=t[i].len-t[t[i].fa].len;
	}
	cout<<ans<<"\n"<<tot<<"\n";
	return 0;
}