POJ 2778 AC自动机 + 矩阵乘法

什么鬼?,C++ AC, G++ 疯狂RE。。看题目就知道是递推的方法,但是肯定不能dp,看n可以达到20亿,自然想到用矩阵的乘法来递推。

题目链接

题意

给你一些字符串代表有病毒的字符串,现在给你一个数字N,让你求长度为N且不包含病毒的字符串的数目。

思路

我们首先需要构建一个矩阵,问题是矩阵的元素代表什么,我们用矩阵的元素代表从一个点转移到另一个点的路线数目,那么矩阵的K次幂,就代表了从这个点经过K次路径(可以重复)到达另一个点的路线的数目。因此我们把构建出的矩阵求出K次幂,然后把第一行累加一下就可。因为起点一定是0点,而末尾的点可能是任意一个顶点。

问题是矩阵的构建了。我们需要用到AC自动机的fai指针了,我们首先把每个病毒字符串的末尾标记一下,然后我们在bfs中加上一句,如果该点的fail指针所指向的字母是被标记的,那么这个字母也要被标记。因为fail指针指向的是公共的后缀,如果这个后缀都被标记了,那么自然这个字符串中一定包含了病毒串。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <ctime>
#include <cmath>
#include <cstdio>
#include <vector>
#include <bitset>
#include <string>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#include <functional>
#define eps 1e-8
#define PI acos(-1.0)

using namespace std;
typedef long long ll;
struct Matrix{
ll m[101][101];
}M;
const int mod = 100000;
int n, cnt, nxt[105][5], fail[105], mark[105], id[5];
long long m;
int newnode()
{
++cnt;
for(int i = 0; i < 4; ++i)
{
nxt[cnt][i] = -1;
}
mark[cnt] = 0;
return cnt;
}



void init()
{
id['A'] = 0;
id['T'] = 1;
id['C'] = 2;
id['G'] = 3;
cnt = -1;
newnode();
}


void insert(char s[])
{
int sz = strlen(s);
int now = 0;
for(int i = 0; i < sz; ++i)
{
int num = id[s[i]];
if(nxt[now][num] == -1)
nxt[now][num] = newnode();
now = nxt[now][num];
}
mark[now] = 1;
}


void build()
{
fail[0] = 0;
queue<int>Q;
for(int i = 0; i < 4; ++i)
{
if(nxt[0][i] == -1)
nxt[0][i] = 0;
else{
fail[nxt[0][i]] = 0;
Q.push(nxt[0][i]);
}
}
while(!Q.empty()){
int now = Q.front();
Q.pop();
if(mark[fail[now]])
mark[now] = 1;
for(int i = 0; i < 4; ++i)
{
if(nxt[now][i] == -1)
{
nxt[now][i] = nxt[fail[now]][i];
}
else{
fail[nxt[now][i]] = nxt[fail[now]][i];
Q.push(nxt[now][i]);
}
}
}
}


void build_matrix()
{
for(int i = 0; i <= cnt; ++i)
{
for(int j = 0; j < 4; ++j)
{
if(!mark[i] && !mark[nxt[i][j]])
M.m[i][nxt[i][j]]++;
}
}
}

Matrix mul(Matrix A, Matrix B)
{
Matrix tmp;
memset(tmp.m, 0, sizeof(tmp));
for(int i = 0; i <= cnt; ++i)
{
for(int j = 0; j <= cnt; ++j)
{
tmp.m[i][j] = 0;
for(int k = 0; k <= cnt; ++k)
{
tmp.m[i][j] = (tmp.m[i][j] + A.m[i][k] * B.m[k][j]) % mod;
}
}
}
return tmp;
}

Matrix pow(Matrix A, long long n)
{
Matrix res;
memset(res.m, 0, sizeof(res));
for(int i = 0; i <= cnt; ++i)
res.m[i][i] = 1;
while(n)
{
if(n & 1)
res = mul(res, A);
A = mul(A,A);
n >>= 1;
}
return res;
}
int main()
{
scanf("%d %lld",&n,&m);
init();
char s[11];
for(int i = 1; i <= n; ++i)
{
scanf("%s",s);
insert(s);
}
build();
build_matrix();
Matrix C = pow(M, m);
ll sum = 0;
for(int i = 0; i <= cnt; ++i)
sum = (sum + C.m[0][i]) % mod;
cout << sum << '\n';
}