bzoj

链接:https://www.lydsy.com/JudgeOnline/problem.php?id=2780
思路:喜闻乐见SAM, SA, AC都可以做的题,而且每种都是非常有代表性的,挨个讲一讲,篇幅比较长。

1.后缀自动机:

这个好像其实就算是广义后缀自动机?每次插完一个串后把last置为root,然后继续插下一个串。考虑询问串在SAM上跑,跑到最终节点,那么它fail树的子树上都是该子串的后缀,那么问询问串在多少个模板串中出现过,其实就是问最终节点的子树中有多少不同种类的模板串。这里一共有如下三种方法:
1.1:树上莫队:区间不同种类数,那么dfs序后转树上莫队是个比较常见的做法,复杂度O(n$sqrt{n}$)
1.2:SAM上暴力跳:在每次插模板串的每个字符后,都从该点开始沿fail树暴力跳到根节点,然后路上更新每个点是否出现过该种串,以及总的出现次数,最后找到询问串对应最终节点,访问答案即可。暴力跳的复杂度已经证明过也是O(n$sqrt{n}$)的。
1.3:dfs序 + 树状数组:区间种类数还有树状数组的做法,按右端点排序,然后统计出每种串相同的最近前面出现的位置,离线树状数组统计答案即可。复杂度O(nlogn)。

2.后缀数组:

后缀数组就是老套路,把所有串拼起来,中间用一个没出现过的字符分割来,然后找到每一个询问串对应的位置,两边二分找到最远的lcp >= 当前len的位置,在这个l到r区间里统计模板串不同出现次数,两种做法:
2.1:莫队:变成了区间里的莫队,统计即可。
2.2:dfs序 + 树状数组:同上

3.AC自动机:

AC自动机可能是最容易错的一种做法,因为它的fail树和SAM的fail树实在是容易混淆。AC自动机的fail树是最长后缀与前缀相同。我们对询问串建立AC自动机,拿模板串在上面跑,跑过的所有点都要++,然后再统计子树信息即可,方法同上。
3.1:莫队:变成了区间里的莫队,统计即可。
3.2:dfs序 + 树状数组:同上

代码:
SAM暴力跳:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

#define sigma_size 28
using namespace std;
const int maxn = 4e5 + 5;
const int mod = 1e9 + 7;
typedef long long ll;
int len[maxn * 2];
int f[maxn * 2]; //后缀链接(最短串前部减少一个字符所到达的状态)
int cnt[maxn * 2]; //被后缀连接的数
int ch[maxn * 2][sigma_size]; //状态转移(尾部加一个字符的下一个状态)(图)
int idx; //节点编号
int last; //最后节点
ll epos[maxn * 2]; // enpos数(该状态子串出现数量)
char s[maxn];
int vis[maxn * 2];
int res[maxn * 2];
int n, q;

void () { //初始化
last = idx = 1; //1表示root起始点 空集
f[1] = len[1] = 0;
memset(ch[1], 0, sizeof(ch[1]));
}

//SAM建图
void add(int c, int id) { //插入字符,为字符ascll码值
int x = ++idx; //创建一个新节点x;
len[x] = len[last] + 1; // 长度等于最后一个节点+1
epos[x] = 1; //接受节点子串除后缀连接还需加一
int p; //第一个有C转移的节点;
for (p = last; p && !ch[p][c]; p = f[p])ch[p][c] = x;//沿着后缀连接 将所有没有字符c转移的节点直接指向新节点
if (!p)f[x] = 1, cnt[1]++; //全部都没有c的转移 直接将新节点后缀连接到起点
else {
int q = ch[p][c]; //p通过c转移到的节点
if (len[p] + 1 == len[q]) //pq是连续的
f[x] = q, cnt[q]++; //将新节点后缀连接指向q即可,q节点的被后缀连接数+1
else {
int nq = ++idx; //不连续 需要复制一份q节点
len[nq] = len[p] + 1; //令nq与p连续
vis[nq] = vis[q]; res[nq] = res[q]; //别忘了复制多添加的信息。
f[nq] = f[q]; //因后面link[q]改变此处不加cnt
memcpy(ch[nq], ch[q], sizeof(ch[q])); //复制q的信息给nq
for (; p && ch[p][c] == q; p = f[p])
ch[p][c] = nq; //沿着后缀连接 将所有通过c转移为q的改为nq
f[q] = f[x] = nq; //将x和q后缀连接改为nq
cnt[nq] += 2; // nq增加两个后缀连接
}
}
last = x; //更新最后处理的节点
for(; vis[x] != id && x; x = f[x]){
vis[x] = id, res[x]++;
}
}

int main() {
scanf("%d%d", &n, &q);
init();
for(int i = 1; i <= n; i++){
scanf("%s", s + 1);
last = 1;
for(int j = 1, len = strlen(s + 1); j <= len; j++) add(s[j] - 'a', i);
}
while(q--){
scanf("%s", s + 1);
int now = 1;
int i, len = strlen(s + 1);
for(i = 1; i <= len; i++){
if(ch[now][s[i] - 'a']) now = ch[now][s[i] - 'a'];
else break;
}
if(i == len + 1) printf("%dn", res[now]);
else puts("0");
}
return 0;
}

SAM + dfs序 + 树状数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

#define sigma_size 28
using namespace std;
const int maxn = 1e5 + 5;
const int mod = 1e9 + 7;
typedef long long ll;
int len[maxn * 2];
int f[maxn * 2]; //后缀链接(最短串前部减少一个字符所到达的状态)
int cnt[maxn * 2]; //被后缀连接的数
int ch[maxn * 2][sigma_size]; //状态转移(尾部加一个字符的下一个状态)(图)
int idx; //节点编号
int last; //最后节点
char s[maxn];
vector<int> sta[maxn * 2];
int n, m;
int in[maxn * 2];
int sz;
vector<int> G[maxn * 2];
int pre[maxn * 2];
vector<int> ne[maxn * 2];
int out[maxn * 2];
int c[maxn * 2];
int id[maxn * 2];

struct node{
int l, r, id;
bool operator <(const node &x) const{
return r < x.r;
}
}q[maxn];

int res[maxn];

int lowbit(int x){
return x & (-x);
}

void update(int x, int d){
while(x < maxn * 2){
c[x] += d;
x += lowbit(x);
}
}

int query(int x){
int ret = 0;
while(x){
ret += c[x];
x -= lowbit(x);
}
return ret;
}

void () { //初始化
last = idx = 1; //1表示root起始点 空集
f[1] = len[1] = 0;
memset(ch[1], 0, sizeof(ch[1]));
}

//SAM建图
void add(int c, int id) { //插入字符,为字符ascll码值
int x = ++idx; //创建一个新节点x;
len[x] = len[last] + 1; // 长度等于最后一个节点+1
int p; //第一个有C转移的节点;
for (p = last; p && !ch[p][c]; p = f[p])ch[p][c] = x;//沿着后缀连接 将所有没有字符c转移的节点直接指向新节点
if (!p)f[x] = 1, cnt[1]++; //全部都没有c的转移 直接将新节点后缀连接到起点
else {
int q = ch[p][c]; //p通过c转移到的节点
if (len[p] + 1 == len[q]) //pq是连续的
f[x] = q, cnt[q]++; //将新节点后缀连接指向q即可,q节点的被后缀连接数+1
else {
int nq = ++idx; //不连续 需要复制一份q节点
len[nq] = len[p] + 1; //令nq与p连续
f[nq] = f[q]; //因后面link[q]改变此处不加cnt
memcpy(ch[nq], ch[q], sizeof(ch[q])); //复制q的信息给nq
for (; p && ch[p][c] == q; p = f[p])
ch[p][c] = nq; //沿着后缀连接 将所有通过c转移为q的改为nq
f[q] = f[x] = nq; //将x和q后缀连接改为nq
cnt[nq] += 2; // nq增加两个后缀连接
}
}
last = x; //更新最后处理的节点
sta[x].push_back(id);
}

void dfs(int u, int fa){
in[u] = ++sz;
id[sz] = u;
for(int i = 0; i < G[u].size(); i++){
int v = G[u][i];
if(v == fa) continue;
dfs(v, u);
}
out[u] = sz;
}

int main() {
scanf("%d %d", &n, &m);
init();
for(int i = 1; i <= n; i++){
scanf("%s", s);
last = 1;
for(int j = 0, len = strlen(s); j < len; j++) add(s[j] - 'a', i);
}
for(int i = 1; i <= idx; i++) {
G[f[i]].push_back(i);
}
dfs(1, 0);
for(int i = 1; i <= m; i++){
scanf("%s", s);
int now = 1;
for(int j = 0, len = strlen(s); j < len && now; j++){
now = ch[now][s[j] - 'a'];
}
q[i].l = in[now], q[i].r = out[now], q[i].id = i;
}

for(int i = 1; i <= sz; i++){
for(int j = 0; j < sta[id[i]].size(); j++){
if(pre[sta[id[i]][j]]) ne[i].push_back(pre[sta[id[i]][j]]);
pre[sta[id[i]][j]] = i;
}
}

sort(q + 1, q + m + 1);
int j = 1;
for(int i = 1; i <= m; i++){
if(q[i].r == 0) continue;
while(j <= sz && j <= q[i].r){
for(int k = 0; k < ne[j].size(); k++){
update(ne[j][k], -1);
}
update(j, sta[id[j]].size());
j++;
}
res[q[i].id] = query(q[i].r) - query(q[i].l - 1);
}
for(int i = 1; i <= m; i++) cout << res[i] << 'n';
return 0;
}

SA + 树状数组(有点卡常过不了)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#include <bits/stdc++.h>
using namespace std;
const int maxn = 600005;
int s[maxn];
int sa[maxn],t[maxn],t2[maxn],c[maxn], n,m, tt;
int r[maxn], h[maxn];
int d[maxn][25];
char ch[maxn];

//n为原字符长度,放于s[]的0-n-1位置,s[n]为0,build_sa传入的参数为(n+1,字符范围)

void build_sa(int n, int m) {//n为原串长度+1,字符值在0-m-1
int i, *x = t, *y = t2;
for (i = 0; i < m; i++)c[i] = 0;
for (i = 0; i < n; i++)c[x[i] = s[i]]++;
for (i = 1; i < m; i++)c[i] += c[i - 1];
for (i = n - 1; i >= 0; i--)sa[--c[x[i]]] = i;
for (int k = 1; k <= n; k <<= 1) {
int p = 0;
for (i = n - k; i < n; i++)y[p++] = i;
for (i = 0; i < n; i++)if (sa[i] >= k)y[p++] = sa[i] - k;
for (i = 0; i < m; i++)c[i] = 0;
for (i = 0; i < n; i++)c[x[y[i]]]++;
for (i = 0; i < m; i++)c[i] += c[i - 1];
for (i = n - 1; i >= 0; i--)sa[--c[x[y[i]]]] = y[i];
swap(x, y);
p = 1;
x[sa[0]] = 0;
for (i = 1; i < n; i++)
x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++;
if (p >= n)break;
m = p;
}
}

//最好从0开始,这样sa[n-1] = 0,可以把上一次的数据清空
void getheight() {
int i, j, k = 0;
for (i = 1; i <= n; i++) r[sa[i]] = i;
for (i = 0; i < n; h[r[i++]] = k)
for (k ? k-- : 0, j = sa[r[i] - 1]; s[i + k] == s[j + k]; k++);
}

struct node{
int l, r, id;
bool operator <(const node &x) const{
return r < x.r;
}
}q[60010];

int res[60010], tmp[maxn], pos[maxn], sum[maxn];

int lowbit(int x){
return x & (-x);
}

void add(int x, int d){
while(x < maxn){
sum[x] += d;
x += lowbit(x);
}
}

int query(int x){
int ret = 0;
while(x){
ret += sum[x];
x -= lowbit(x);
}
return ret;
}

void RMQ_init() {
for (int i = 1; i <= n; i++)d[i][0] = h[i];
for (int j = 1; (1 << j) <= n; j++)
for (int i = 1; i + (1 << j) - 1 <= n; i++)
d[i][j] = min(d[i][j - 1], d[i + (1 << (j - 1))][j - 1]);
}

int RMQ(int l, int r) {
if (l > r)swap(l, r);
int k = 0;
while (1 << (k + 1) <= r - l + 1)k++;
return min(d[l][k], d[r - (1 << k) + 1][k]);
}


int ask[60010];
int len[60010];
int pre[maxn], ne[maxn];

int main(){
scanf("%d %d", &tt, &m);
for(int i = 1; i <= tt; i++){
scanf("%s", ch);
for(int j = 0, len = strlen(ch); j < len; j++){
s[n] = ch[j] - 'a' + 1;
tmp[n++] = i;
}
s[n++] = 30 + i;
}
for(int i = 1; i <= m; i++){
scanf("%s", ch);
ask[i] = n;
len[i] = strlen(ch);
for(int j = 0; j < len[i]; j++){
s[n++] = ch[j] - 'a' + 1;
}
s[n++] = 30 + tt + i;
}
s[n] = 0;
build_sa(n + 1, tt + m + 35);
getheight();
RMQ_init();
for(int i = 1; i <= n; i++) pos[i] = tmp[sa[i]];

for(int i = 1; i <= m; i++){
int id = r[ask[i]];
int lb = 1, ub = id - 1, pl = id, pr = id;
if(h[id] < len[i]){
q[i].l = q[i].r = id;
q[i].id = i;
continue;
}
while(ub >= lb){
int mid = ub + lb >> 1;
if(RMQ(mid + 1, id) >= len[i]) ub = mid - 1, pl = mid;
else lb = mid + 1;
}
lb = id + 1, ub = n;

while(ub >= lb){
int mid = ub + lb >> 1;
if(RMQ(id + 1, mid) >= len[i]) lb = mid + 1, pr = mid;
else ub = mid - 1;
}
q[i].l = pl, q[i].r = pr, q[i].id = i;
}

sort(q + 1, q + m + 1);
for(int i = 1; i <= n; i++){
ne[i] = pre[pos[i]];
pre[pos[i]] = i;
}

int j = 1;
for(int i = 1; i <= m; i++){
while(j <= n && j <= q[i].r){
if(!pos[j]) {
j++;
continue;
}
if(ne[j]) add(ne[j], -1);
add(j, 1);
j++;
}
res[q[i].id] = query(q[i].r) - query(q[i].l - 1);
}
for(int i = 1; i <= m; i++) cout << res[i] << 'n';
return 0;
}

AC自动机 + dfs序 + 树状数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178

using namespace std;

const int maxnode = 6e5 + 100;
const int sigma_size = 26;
int n, m;

string str[10010];
string s;

typedef long long ll;
int ch[maxnode][sigma_size];
int val[maxnode];
int f[maxnode];
int last[maxnode];
int sz;
vector<int> pos[maxnode];
vector<int> sta[maxnode], ne[maxnode];
ll c[maxnode];
vector<int> G[maxnode];
int in[maxnode], out[maxnode], id;
int pre[maxnode], rk[maxnode];
bool vis[maxnode];

int lowbit(int x){
return x & (-x);
}

void add(int x, int d){
while(x < maxnode){
c[x] += d;
x += lowbit(x);
}
}

ll query(int x){
int ret = 0;
while(x){
ret += c[x];
x -= lowbit(x);
}
return ret;
}

void () {
sz = 1;
memset(ch[0], 0, sizeof(ch[0]));
}

int idx(char c) {
return c - 'a';
}

void insert(string s ,int v, int x) {
int u = 0;
for (int i = 0; i < s.size(); i++) {
int c = idx(s[i]);
if (!ch[u][c]) {
ch[u][c] = sz;
memset(ch[sz], 0, sizeof(ch[sz]));
val[sz++] = 0;
}
u = ch[u][c];
}
val[u] = v;
pos[u].push_back(x);
}

void getfail() {
queue<int> q;
f[0] = 0;
for (int i = 0; i < sigma_size; i++) {
int u = ch[0][i];
if (u) {
f[u] = last[u] = 0;
q.push(u);
}
}
while (!q.empty()) {
int r = q.front();
q.pop();
//val[r] |= val[f[r]]; //AC自动机 + dp时会用
for (int c = 0; c < sigma_size; c++) {
int u = ch[r][c];
if (!u) {
ch[r][c] = ch[f[r]][c];
continue;
}
q.push(u);
int v = f[r];
while (v && !ch[v][c])v = f[v];
f[u] = ch[v][c];
last[u] = val[f[u]] ? f[u] : last[f[u]];
}
}
}


void find(string s, int x) {
int j = 0;
for (int i = 0; i < s.size(); i++) {
int c = idx(s[i]);
j = ch[j][c];
sta[j].push_back(x); //跑过的所有点都要统计信息
}
}

void dfs(int u, int fa){
vis[u] = 1;
in[u] = ++id;
rk[id] = u;
for(int i = 0; i < G[u].size(); i++){
int v = G[u][i];
if(v == fa) continue;
dfs(v, u);
}
out[u] = id;
}

struct node{
int l, r, id;
bool operator < (const node &x) const{
return r < x.r;
}
}q[60010];
ll res[60010];

int main(){
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n >> m;
init();
for(int i = 1; i <= n; i++){
cin >> str[i];
}
for(int i = 1; i <= m; i++) {
cin >> s;
insert(s, 1, i);
}
getfail();
for(int i = 1; i <= n; i++){
find(str[i], i);
}

for(int i = 1; i < sz; i++) G[f[i]].push_back(i);
for(int i = 0; i < sz; i++) {
if(!vis[i]) dfs(i, 0);
}

for(int i = 1; i <= id; i++){
if(val[rk[i]]){
for(int j = 0; j < pos[rk[i]].size(); j++) {
int k = pos[rk[i]][j];
q[k].l = in[rk[i]], q[k].r = out[rk[i]], q[k].id = k;
}
}
}

sort(q + 1, q + m + 1);
for(int i = 1; i <= id; i++){
for(int j = 0; j < sta[rk[i]].size(); j++){
if(pre[sta[rk[i]][j]]) ne[i].push_back(pre[sta[rk[i]][j]]);
pre[sta[rk[i]][j]] = i;
}
}
int j = 1;
for(int i = 1; i <= m; i++){
while(j <= id && j <= q[i].r) {
for (int k = 0; k < ne[j].size(); k++) {
add(ne[j][k], -1);
}
add(j, sta[rk[j]].size());
j++;
}
res[q[i].id] = query(q[i].r) - query(q[i].l - 1);
}
for(int i = 1; i <= m; i++) cout << res[i] << 'n';
return 0;
}