hihoCoder Trie图(AC自动机)

http://hihocoder.com/contest/hiho4/problem/1

原来AC自动机是Trie图啊。。。
Trie图与KMP的作用相似,都是求字符串匹配,而且感觉求法也是类似,不过KMP是单个模式串求匹配,Trie图是多个模式串求匹配,为了实现这个功能,Tire图多了一个前缀指针,类似KMP的失配数组

首先按照Tire树建立好一个树,然后按照深度优先求取前面深度的前缀指针,这个前缀指针是根据最大后缀求得。
hihoCoder Trie图(AC自动机)
c字符后面的节点A的前缀指向B’,因为abc的最大后缀和bc相同

查找时,如果一个节点的子节点没有后续的串,那么就跳到他的前缀指针,在查找直至到根节点。

代码:

代码借鉴

https://www.cnblogs.com/vongang/archive/2012/07/24/2606494.html

typedef struct Trie{
        Trie *fail;
        Trie *next[26];
        int cnt;
        Trie() {
            memset(next, 0, sizeof(next));
            fail = NULL;
            cnt = 0;
        }
}TrieNode, *LinkTrie;

LinkTrie root;
int head, tail;

void init() {
    root = new Trie();
    head = tail = 0;
}

void insert(char *st) {
    LinkTrie p = root;
    while(*st) {
        if(p->next[*st-'a'] == NULL)
            p->next[*st-'a'] = new Trie();
        p = p->next[*st-'a'];
        st ++;
    }
    p->cnt++;
}

void build() {
    root->fail = NULL;
    deque<LinkTrie> q;
    q.push_back(root);
    while(!q.empty()) {
        LinkTrie tmp = q.front();
        LinkTrie p = NULL;
        q.pop_front();
        for (int i = 0; i < 26; i ++) {
            if(tmp->next[i] != NULL) {
                if(tmp == root) tmp->next[i]->fail = root;
                else {
                    p = tmp->fail;
                    while(p != NULL) {
                        if(p->next[i] != NULL) {
                            tmp->next[i]->fail = p->next[i];
                            break;
                        }
                        p = p->fail;
                    }
                    if(p == NULL) tmp->next[i]->fail = root;
                }
                q.push_back(tmp->next[i]);
            }
        }
    }
}

int search(char *st) {
    int cnt = 0, t;
    LinkTrie p = root;
    while(*st) {
        t = *st - 'a';
        while(p->next[t] == NULL && p != root)
            p = p->fail;
        p = p->next[t];
        if(p == NULL) p = root;
        LinkTrie tmp = p;
        while(tmp != root && tmp->cnt != -1) {//一个模式串只被计算一次
            cnt += tmp->cnt;
            tmp->cnt = -1;
            tmp = tmp->fail;
        }
        st ++;
    }
    return cnt;
}

AC代码:

#include<bits/stdc++.h>
using namespace std;
#define LL long long
const int inf = 0x3f3f3f3f;
const int maxn = 1e5 + 5;

typedef struct Trie{
        Trie *fail;
        Trie *next[26];
        int cnt;
        Trie() {
            memset(next, 0, sizeof(next));
            fail = NULL;
            cnt = 0;
        }
}TrieNode, *LinkTrie;

LinkTrie root;
int head, tail;

void init() {
    root = new Trie();
    head = tail = 0;
}

void insert(char *st) {
    LinkTrie p = root;
    while(*st) {
        if(p->next[*st-'a'] == NULL)
            p->next[*st-'a'] = new Trie();
        p = p->next[*st-'a'];
        st ++;
    }
    p->cnt++;
}

void build() {
    root->fail = NULL;
    deque<LinkTrie> q;
    q.push_back(root);
    while(!q.empty()) {
        LinkTrie tmp = q.front();
        LinkTrie p = NULL;
        q.pop_front();
        for (int i = 0; i < 26; i ++) {
            if(tmp->next[i] != NULL) {
                if(tmp == root) tmp->next[i]->fail = root;
                else {
                    p = tmp->fail;
                    while(p != NULL) {
                        if(p->next[i] != NULL) {
                            tmp->next[i]->fail = p->next[i];
                            break;
                        }
                        p = p->fail;
                    }
                    if(p == NULL) tmp->next[i]->fail = root;
                }
                q.push_back(tmp->next[i]);
            }
        }
    }
}

int search(char *st) {
    int cnt = 0, t;
    LinkTrie p = root;
    while(*st) {
        t = *st - 'a';
        while(p->next[t] == NULL && p != root)
            p = p->fail;
        p = p->next[t];
        if(p == NULL) p = root;
        LinkTrie tmp = p;
        while(tmp != root && tmp->cnt != -1) {//一个模式串只被计算一次
            cnt += tmp->cnt;
            tmp->cnt = -1;
            tmp = tmp->fail;
        }
        st ++;
    }
    return cnt;
}

int main(int argc, char const *argv[]) {
    init();
    char s[1000005];
    int n;
    scanf("%d", &n);
    while(n --) {
        scanf("%s", s);
        insert(s);
    }
    build();
    scanf("%s", s);
    printf("%s\n", search(s) ? "YES" : "NO");
    return 0;
}