[模板] AC自动机

注意:请在了解Trie树和KMP算法后再学习AC自动机

AC自动机可以用来解决多模式串匹配问题,与KMP算法相似,都是通过Fail指针进行失配边的操作。而因为AC自动机所进行的匹配是多模式串的,所以需要将模式串建立一个Trie树。

Fail指针的建立将在下面代码注释中介绍

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<vector>
#define LL long long

using namespace std;
const int MAXN = 10000000 + 10;

struct Node{
    Node* Next[26]; //trie指针
    Node* fail;  //next指针
    int sum; //可能重复,出现几次
};

Node* root; //根节点
Node* queue[MAXN]; //求解next时需要的队列
int head,tail;
char key[100]; 
char P[MAXN];

void Insert(char *s){ //Trie树常规操作
    Node *tmp = root;

    for(int i=0;s[i];i++){
        int v = s[i] - 'a';

        if(tmp->Next[v] == NULL){
            tmp->Next[v] = new Node();
            for(int i=0;i<26;i++)
                tmp->Next[v]->Next[i] = 0;
            tmp->Next[v]->fail = 0;
            tmp->Next[v]->sum = 0;
        }
        tmp = tmp->Next[v];
    } 
    tmp->sum++;
}

void getFail(){
    head = 0;
    tail = 0;
    queue[tail++] = root;

    while(head<tail){
        Node* tmp = queue[head++];

        for(int i=0;i<26;i++){
            if(tmp -> Next[i]){ //当前节点存在Next[i]
                if(tmp == root){
                    tmp->Next[i]->fail = root; //根节点的儿子Fail指向根节点
                }else{
                    Node* p = tmp->fail; //p 指向当前节点的Fail

                    while(p){
                        if(p->Next[i]){
                            tmp->Next[i]->fail = p->Next[i]; //若p也存在Next[i],建立Fail
                            break;
                        }

                        p = p->fail; //不存在Next[i] 递归查找
                    } 

                    if(p == NULL)
                        tmp->Next[i]->fail = root; //查找失败,Fail指向根
                }
                queue[tail++] = tmp->Next[i]; //入队
            }
        }
    }
}

int ans;

void AC(char *s){
    Node* tmp = root;

    for(int i=0; s[i];i++){
        int v = s[i] - 'a';
        while(!tmp->Next[v] && tmp!=root)
            tmp = tmp->fail; //失配
        tmp = tmp -> Next[v];

        if(!tmp)
            tmp = root; //匹配失败

        Node* p = tmp;

        while(p!=root){
            if(p->sum >= 0){
                ans += p->sum;
                p->sum = -1;
            }else
                break;

            p = p->fail;
        }

    }
}

int n;
int main(){
    root = new Node();
    for(int i=0;i<26;i++)
        root->Next[i] = 0;
    root->fail = 0;
    root->sum = 0;

    scanf("%d",&n);

    for(int i=0;i<n;i++){
        scanf("%s",key);
        Insert(key); 
    }

    scanf("%s",P);
    ans = 0;
    getFail();
    AC(P);

    printf("%d",ans);
    return 0;
}