WNJXYK
Thanks to the cruel world.
WNJXYKのBlog
后缀自动机总结
后缀自动机总结

后缀自动机的构建

简单说明

从Hihocoder里偷一张图和一个表格过来,这是字符串aabbabd的后缀自动机示意图,蓝色的是转移边,虚线是Fail边。

https://blog.wnjxyk.cn/wp-content/uploads/2018/08/SAM.png

状态 子串 endpos
S 空串 {0,1,2,3,4,5,6}
1 a {1,2,5}
2 aa {2}
3 aab {3}
4 aabb,abb,bb {4}
5 b {3,4,6}
6 aabba,abba,bba,ba {5}
7 aabbab,abbab,bbab,bab {6}
8 ab {3,6}
9 aabbabd,abbabd,bbabd,babd,abd,bd,d {7}

我们可以发现,后缀自动机中的一个状态表示的是一系列结尾位置相同的字符串。而转移边链接着这些不同的状态,例如状态6状态7之间连着一条转移边,令状态6中所有的字符串添加一个字符b之后转移到状态7,因为状态6状态7的结尾位置集合大小相同,数值差1,所以可以认为他们两个状态是经过字符c转移之后相同的两个状态。

得到了后缀自动机,你就得到了这些东西:
1. 按照结束位置定义的本质不同的子串
2. 后缀数组(正向DFS即可)
3. S的反串的后缀树(需要将Fail边反向连接)

变量说明

  1. node[x].len 后缀自动机中节点最长能够接受的字符串的长度
  2. node[x].pos 后缀自动机中节点能接受字符串集合的最靠前的最末位置
  3. node[x].cnt 后缀自动机中节点的出现次数
  4. node[x].nxt 后缀自动机中的转移边
  5. node[x].fail 后缀自动机中的Fail边

构造方法

构建后缀自动机是一个在线的算法,通过不断向当前字符串的右边添加字符得到。
首先后缀自动机有一个初始节点S0表示空串,所以它的基本信息为S0.len=0, S0.pos=0, S0.fail=S0, S0.nxt[]=NULL
后缀自动机始终能够接受这个串本身,所以新加入一个字符,它肯定要出现在后缀自动机这个有向无环图的从S0节点开始的最长链的最末尾。所以,我们使用一个last变量,来记录当前最长链的最末尾是哪个节点。

当新加入一个字符c的时候,我们先增长最长链,保证后缀自动机能接受这个加入c之后的字符串,假设我们新建的这个节点是NP,最长链的末尾节点是P。此时NP的两个参数lenpos都应该为P.len+1,因为它们分别表示最长接受长度与最靠前的最末位置,当前这个字符在字符串的最末尾,所以位置与长度应该就是最长链的长度+1。

P = last
NP.pos=NP.len=P.len+1
NP.nxt[]=NULL
NP.cnt=1

此时,我们要开始更新NP.fail的值,它应当是能够接受状态P表示的字符串的后缀集合中下一个字符能接受c的节点。所以我们在节点P的后缀集合中找到一个能够接受c的节点P',再令QP'经过转移边nxt[c]之后得到的节点。
同时,P节点中表示字符串的后缀集合的状态中不包含P'节点表示字符串的后缀集合状态的所有状态的转移边均指向NP,表示这些节点也能够接受字符c

while(P && P.nxt[c]==0) P.nxt[c]=np, P=P.fail

接下来分情况讨论:

  1. 如果当前找不到一个节点Q,那么我们简单的令NP.fail=S0即可。

  2. 如果找到了这样的一个节点Q,那么接下来要讨论Q节点的状态。

    • 如果Q.len=P’.len+1,那么就说明节点Q的表示的字符串和节点P'中的是本质相同的。就是说,节点P'中的所有字符串在加入一个字符c之后,全部都转移到了节点Q。在这种情况下,我们只需要简单的令NP.fail=Q即可。

    • 如果Q.len \neq P’.len+1,那么Q节点表示的字符串的结尾位置集合不仅仅可以由节点P'转移过来,可以通过其他节点转移得到。此时,我们就不能简单的令NP.fail指向Q了。解决方法是,我们建立一个节点Q的拷贝Q',插入在状态P'Q之间,令Q'的所有状态都是由P'加入字符c转移得到,再令NP.fail=Q'

Q'=Q
Q'.len=P'.len+1
Q'.fail=Q.fail
Q.fail=NP.fail=Q'
while(P && P.nxt[c]==Q) P.nxt[c]=Q', P=P.fail

后缀自动机模版

namespace SAM{
    const int N_CHAR=26;
    const int MAXN=1000000+50;

    struct Node{
        int nxt[N_CHAR], fail;
        int len; // Max Length of State
        int pos; // Appear Position of State, Indexed From 1
        int cnt; // Appear Count of State
    }node[MAXN*2];
    int numn, last, root;

    /**
     * Create New Node for SAM
     */
    inline int newNode(int l, int p){
        int x=++numn;
        for (int i=0; i<N_CHAR; i++) node[x].nxt[i]=0;
        node[x].cnt=node[x].fail=0; 
        node[x].len=l; 
        node[x].pos=p;
        return x;
    }

    /**
     * Init SAM
     */
    inline void init(){ 
        root=last=newNode(numn=0, 0); 
    }

    /**
     * Add Char Into SAM
     */
    inline void addChar(int c){
        int p=last, np=newNode(node[p].len+1, node[p].len+1);
        while(p && node[p].nxt[c]==0) node[p].nxt[c]=np, p=node[p].fail;
        if (p==0) node[np].fail=root; else{
            int q=node[p].nxt[c];
            if (node[p].len+1 == node[q].len){
                node[np].fail=q;
            }else{
                int nq=newNode(node[p].len+1, node[q].pos);
                for (int i=0; i<N_CHAR; i++) node[nq].nxt[i]=node[q].nxt[i];
                node[nq].fail=node[q].fail;
                node[q].fail=node[np].fail=nq;
                while(p && node[p].nxt[c]==q) node[p].nxt[c]=nq, p=node[p].fail;
            }
        }
        last=np;  node[np].cnt=1;
    }

    /**
     * Update Appear Count of States
     */
    inline void updateCount(){
        static int deg[MAXN*2]; for (int i=0; i<=numn; i++) deg[i]=0;
        static queue<int> que; while(!que.empty()) que.pop();
        for (int i=1; i<=numn; i++) deg[node[i].fail]++;
        for (int i=1; i<=numn; i++) if (deg[i]==0) que.push(i);
        while(!que.empty()){
            int x=que.front(); que.pop();
            int v=node[x].fail;
            node[v].cnt+=node[x].cnt;
            if (--deg[v]==0) que.push(v);
        }
    }
}

相关题目

1. 查找子串及其所有出现位置

在模式串T中查找子串P的所有出现位置,复杂度O(len(P)+len(Answer))

对T建立后缀自动机,并且记录Position表示每个状态的第一次出现位置。当节点NP由P后添加字符新建得到,NP.Position=P.Len;当结点NQ由结点Q拷贝得到,NQ.Position=Q.Position

首先在自动机上跑出字符串P,跑到状态S(也是结点S)。如果此时找不到对应状态,则子串P不存在于T中。
子串P的所有出现位置就是S.Position,即所有沿着Fail边跳跃最终能到达结点S的结点P的P.Position。为了做到这一点,我们只需要在建立自动机之后将Fail边反向建图,然后从结点S沿着反向Fail边遍历即可。

题目链接:http://hihocoder.com/problemset/problem/1441

2. 查找某个子串的出现次数

在模式串T中查找子串P及其出现次数,复杂度O(len(P))

建立T后缀自动机,记录Count表示当前状态的出现次数,如果当前结点新建得到,Count=1;如果当前结点拷贝得到,Count=0
然后我们按照Fail边的反向拓扑序,更新Count变量,即对于结点P:P->Fail.Count += P.Count

然后对于子串,我们只需要在自动机上跑出其对应状态结点P,然后查询P.Count即可。

3. 求所有子串数量与和

求一个串T中所有不同子串的数量与某种和,复杂度O(len(T))

直接计算:
建立后缀自动机后,T的任意子串对应后缀自动机上的一种状态,求所有不同子串的数量即为求后缀自动机上不同状态包含的子串的数量之和。
很明显一个状态P包含不同子串的数量为:P.Len - P->fail.Len
但是这样无法计算子串的某种和。

顺序计算:
后缀自动机上不同状态包含的子串的数量之和也可以通过动态规划计算出来。
按照后缀自动机的转移边的拓扑序进行DP,令Dp[P]表示状态P包含的子串数量,可以知道初始状态Dp[root]=1,转移方程

Dp[P]=\sum\limits_{<x, P> \in SAM} Dp[x]

对于求某种和(比如子串长度),思路与求不同子串的数量相同。我们按照拓扑序进行DP,令Dp[P]表示状态P包含的子串数量、Ans[P]表示状态P中所有字符串的某种和。可以很容易知道初始状态同上,转移方程如下

\begin{aligned}
Dp[P]&=\sum\limits_{<x, P> \in SAM} Dp[x]\\
ans[P]&=\sum\limits_{<x, P> \in SAM} ans[x]+Dp[x]
\end{aligned}

最终答案为\sum DP[x]\sum ans[x]

逆序计算:
后缀自动机上不同状态包含的子串的数量等于后缀自动机上不同路径的数量。
所以我们可以按照后缀自动机的转移边的反向拓扑顺序进行DP,令Dp[P]表示从节点P开始到结束的不同的路径数量,那么显然转移方程为:

Dp[P] = 1 + \sum\limits_{<P, x> \in SAM} Dp[x]

对于求某种和(例如子串长度),我们可以以仿照如此。Ans[P]表示从节点P到结尾的所有子串的长度和,转移方程如下:

Ans[P] = \sum\limits_{<P, x> \in SAM} Dp[x]+Ans[x]

最终答案为Dp[Root]Ans[Root]

题目链接:http://hihocoder.com/problemset/problem/1457

4. 查询第K小的子串

对于这个问题,我们可以参考之前的逆序计算不同子串的思路。
我们发现求第K小的子串,即为求后缀自动机上第K小的路径,所以我们同样计算一个Dp[P]表示从P节点向后的所有不同路径数量

Dp[P] = 1 + \sum\limits_{<P, x> \in SAM} Dp[x]

那么,我们就可以枚举子串的每一位,通过Dp[P]与K的关系来确定。

其他题目

  1. http://hihocoder.com/problemset/problem/1465
    求一个字符串S及其循环串,在模式串T中的匹配次数。对T建立后缀自动机,将S倍长扩展成S’,然后在后缀自动机是进行匹配,每次只匹配长度为len(S)的状态。注意,当前匹配长度需要使用一个变量进行跟踪,并在跳Fail与匹配时进行更新。

  2. http://hihocoder.com/problemset/problem/1466 Hihocoder : 后缀自动机六·重复旋律9
    在某个字符串后添加一个字符使其还是原串的子串,这样是一个合法的操作。给定两个字符串A与B,求这两个字符串的两个子串A’与B’使两个人轮流选择一个字符串进行合法操作不能进行的输的规则下先手胜。求这样的字符串对<A’, B’>中第K大的。
    对于一个字符串,我们后缀自动机上的一个状态P的SG函数相同。不同状态按照转移边的Mex求SG值。两个字符串对SG值进行XOR,求得答案。
    此时问题转化为一个有限制的后缀自动机上查询第K小的问题。

#include <cstdio>
#include <cstring>
#include <set>
#include <queue>
#include <cstdlib>
#include <cassert>
#define LL long long
using namespace std;


const int MOD=1e9+7;

const int N_CHAR=30;
const int MAXN=100000+50;
struct SAM{
    struct Node{
        int nxt[N_CHAR], fail;
        int len; // Max Length of State
        int pos; // Appear Position of State, Indexed From 1
        int cnt; // Appear Count of State
    }node[MAXN*2];
    int numn, last, root;

    /**
     * Create New Node for SAM
     */
    inline int newNode(int l, int p){
        int x=++numn;
        for (int i=0; i<N_CHAR; i++) node[x].nxt[i]=0;
        node[x].cnt=node[x].fail=0; 
        node[x].len=l; 
        node[x].pos=p;
        return x;
    }

    /**
     * Init SAM
     */
    inline void init(){ 
        root=last=newNode(numn=0, 0); 
    }

    /**
     * Add Char Into SAM
     */
    inline void addChar(int c){
        int p=last, np=newNode(node[p].len+1, node[p].len+1);
        while(p && node[p].nxt[c]==0) node[p].nxt[c]=np, p=node[p].fail;
        if (p==0) node[np].fail=root; else{
            int q=node[p].nxt[c];
            if (node[p].len+1 == node[q].len){
                node[np].fail=q;
            }else{
                int nq=newNode(node[p].len+1, node[q].pos);
                for (int i=0; i<N_CHAR; i++) node[nq].nxt[i]=node[q].nxt[i];
                node[nq].fail=node[q].fail;
                node[q].fail=node[np].fail=nq;
                while(p && node[p].nxt[c]==q) node[p].nxt[c]=nq, p=node[p].fail;
            }
        }
        last=np;  node[np].cnt=1;
    }

    /**
     * Update Appear Count of States
     */
    inline void updateCount(){
        static int deg[MAXN*2]; for (int i=0; i<=numn; i++) deg[i]=0;
        static queue<int> que; while(!que.empty()) que.pop();
        for (int i=1; i<=numn; i++) deg[node[i].fail]++;
        for (int i=1; i<=numn; i++) if (deg[i]==0) que.push(i);
        while(!que.empty()){
            int x=que.front(); que.pop();
            int v=node[x].fail;
            node[v].cnt+=node[x].cnt;
            if (--deg[v]==0) que.push(v);
        }
    }

    /**
     * SG
     */
    vector<int> edge[MAXN*2];
    int sg[MAXN*2]; LL cnt[MAXN*2];
    bool vis[MAXN*2][N_CHAR+1];
    int top[MAXN*2], tot;
    inline void calc(){
        static int deg[MAXN*2]; for (int i=1; i<=numn; i++) deg[i]=0;
        static queue<int> que; while(!que.empty()) que.pop();
        for (int i=1; i<=numn; i++){
            sg[i]=0; vis[i][N_CHAR]=false;
            for (int j=0; j<N_CHAR; j++){
                if (node[i].nxt[j]) ++deg[i], edge[node[i].nxt[j]].push_back(i);
                vis[i][j]=false;
            }
        }
        for (int i=1; i<=numn; i++) if (deg[i]==0) que.push(i);
        tot=0;
        while(!que.empty()){
            int x=que.front(); que.pop();
            top[++tot]=x;
            while(vis[x][sg[x]]) sg[x]++;
            for (auto v : edge[x]){
                vis[v][sg[x]]=true;
                if (--deg[v] == 0) que.push(v);
            }
        }
    }
}samA, samB;

char strA[MAXN], strB[MAXN];
LL sum[MAXN*2];
char ansA[MAXN], ansB[MAXN];
int la, lb;
LL K;

inline void Print(){
    for (int i=1; i<=la; i++) printf("%c", ansA[i]); printf("\n");
    for (int i=1; i<=lb; i++) printf("%c", ansB[i]); printf("\n");
}

bool dfsB(int x, LL k, int g){
    int sg=samB.sg[x];
    if (sg!=g) --k;
    if (k==0) { Print(); return true; }
    for (int j=0; j<26; j++){
        int v=samB.node[x].nxt[j];
        if (v && k){
            if (samB.cnt[v]>=k){
                ansB[++lb]='a'+j;
                return dfsB(v, k, g);;
            }else k-=samB.cnt[v];
        }
    }
    return false;
}

inline void recalc(int v){
    for (int i=1; i<=samB.tot; i++){
        int x=samB.top[i]; samB.cnt[x]=(v!=samB.sg[x]);
        for (int j=0; j<26; j++){
            int v=samB.node[x].nxt[j];
            if (v) samB.cnt[x]+=samB.cnt[v];
        }
        // printf("#%d -> %d\n", x, samB.cnt[x]);
    }
}

bool dfsA(int x, LL k){
    int sg=samA.sg[x]; 
    if (k<=sum[sg]){ recalc(sg); return dfsB(samB.root, k, sg); }

    k-=sum[sg];
    for (int j=0; j<26; j++){
        int v=samA.node[x].nxt[j];
        if (v && k){
            if (samA.cnt[v]>=k){
                ansA[++la]='a'+j;
                return dfsA(v, k);;
            }else k-=samA.cnt[v];
        }
    }
    return false;
}


int main(){
    //freopen("in.txt", "r", stdin);
    scanf("%lld", &K); 
    scanf("%s%s", strA, strB);

    samA.init(), samB.init();
    for (int i=0, l=strlen(strA); i<l; i++) samA.addChar(strA[i]-'a');
    for (int i=0, l=strlen(strB); i<l; i++) samB.addChar(strB[i]-'a');
    samA.calc(), samB.calc();

    for (int i=1; i<=samA.tot; i++){
        int x=samB.top[i];
        samB.cnt[x]=1;
        for (int j=0; j<=26; j++){
            int v=samB.node[x].nxt[j];
            if (v) samB.cnt[x]+=samB.cnt[v];
        }
        // printf("%d => %d\n", x, samB.cnt[x]);
    }

    // Calc SG Sum
    for (int i=1; i<=samB.numn; i++){
        int sg=samB.sg[i];
        if (i!=samB.root) 
            sum[sg]+=samB.node[i].len-samB.node[samB.node[i].fail].len;
        else ++sum[sg];
    }
    for (int i=0; i<=26; i++) sum[i]=samB.cnt[samB.root]-sum[i];

    for (int i=1; i<=samA.tot; i++){
        int x=samA.top[i];
        samA.cnt[x]=sum[samA.sg[x]];
        for (int j=0; j<=26; j++){
            int v=samA.node[x].nxt[j];
            if (v) samA.cnt[x]+=samA.cnt[v];
        }
        // printf("%d => %d\n", x, samA.cnt[x]);
    }

    la=lb=0;
    if (!dfsA(samA.root, K)) printf("NO\n");

    return 0;
}
赞赏
https://secure.gravatar.com/avatar/f83b57c055136369e9feba5d6671d6b5?s=256&r=g

WNJXYK

文章作者

一个蒟蒻

推荐文章

发表评论

textsms
account_circle
email

WNJXYKのBlog

后缀自动机总结
后缀自动机的构建 简单说明 从Hihocoder里偷一张图和一个表格过来,这是字符串aabbabd的后缀自动机示意图,蓝色的是转移边,虚线是Fail边。 状态 子串 endpos S 空…
扫描二维码继续阅读
2018-08-17
<--! http2https -->