[bzoj 3926][zjoi 2015] 诸神眷顾的幻想乡

题目大意

给定一棵含 \(n\) 个节点的树,每个节点有一个权值 \(c_i\)。在树上选定两个点,沿两点之间的路径依次记下各节点的权值,得到一个数列,求一共能得到多少个不同的数列。

只与一个节点相连的节点不超过 \(20\) 个。

\(1 \leqslant n \leqslant 100{,}000\)

\(0 \leqslant c_i < c \leqslant 10\),其中 \(c\) 为权值的种类数(输入第一行在 \(n\) 之后给出)

题目链接

BZOJ 3926

LibreOJ 2137

题解

只与一个节点相连的节点不超过 \(20\) 个,意味着叶子最多有 \(20\) 个。

树上任意一条路径,都是以某个叶子为根时、某条从根向下的链上的连续一段。因此以每个叶子为起点遍历整棵树,沿途建立广义 SAM。用 SAM 求本质不同子串数就是裸题了。

代码

LibreOJ 使用 64 位评测机,内存限制又只有 BZOJ 的一半,我的指针式 SAM 写法在那里被卡掉了……

BZOJ(指针式的 SAM):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

// <cstdio> supplies scanf/printf used by main(); <new> declares the placement
// form of operator new used to construct objects inside the static pools.
#include <cstdio>
#include <new>
#include <vector>
#include <algorithm>

// Upper bound on the number of tree vertices (vertices are numbered from 1).
const int MAXN = 100005;
// Alphabet size: vertex values index the transition arrays directly,
// so every value must lie in [0, CHAR_SET).
const int CHAR_SET = 10;
struct {
struct Node {
Node *c[CHAR_SET], *next;
int max;
Node(int max = 0) : max(max), c(), next(NULL) {}
int getMin() {
return next->max + 1;
}
} *start, _pool[MAXN * 40], *_curr;
SuffixAutomaton() {
init();
}
void init() {
_curr = _pool;
start = new (_curr++) Node();
}
Node *extend(Node *v, int c) {
if (v->c[c] && v->c[c]->max == v->max + 1) return v->c[c];
Node *u = new (_curr++) Node(v->max + 1);
while (v && !v->c[c]) {
v->c[c] = u;
v = v->next;
}
if (!v) {
u->next = start;
} else if (v->c[c]->max == v->max + 1) {
u->next = v->c[c];
} else {
Node *n = new (_curr++) Node(v->max + 1), *o = v->c[c];
std::copy(o->c, o->c + CHAR_SET, n->c);
n->next = o->next;
o->next = u->next = n;
for (; v && v->c[c] == o; v = v->next) v->c[c] = n;
}
return u;
}
long long calc() {
long long res = 0;
for (Node *p = _pool + 1; p != _curr; p++) res += p->max - p->getMin() + 1;
return res;
}
} sam;
struct Edge;

// One tree vertex; vertex i is stored in N[i].
struct Node {
    Edge *e;  // head of this vertex's adjacency list (null while it has no edges)
    int c;    // value of the vertex, fed to the automaton as a character
    int deg;  // number of incident edges, incremented by addEdge()
} N[MAXN];
// Directed half of a tree edge, stored as a node of the source vertex's
// singly linked adjacency list. Edges live in the static pool `_pool`;
// `_curr` points at the next unused slot.
struct Edge {
    Node *u, *v;  // source and destination vertex
    Edge *next;   // next edge leaving the same source

    // Used for the pool array; fields are left unset.
    Edge() {}

    // Prepends: the new edge's successor is the current list head of `from`.
    Edge(Node *from, Node *to) : u(from), v(to), next(from->e) {}
} _pool[MAXN * 2], *_curr = _pool;
void addEdge(int u, int v) {
N[u].e = new (_curr++) Edge(&N[u], &N[v]);
N[v].e = new (_curr++) Edge(&N[v], &N[u]);
N[u].deg++;
N[v].deg++;
}
// Walks the tree away from `fa`, inserting u's value after automaton state
// `last` and passing the resulting state on to every other neighbour of u.
void dfs(Node *u, Node *fa, SuffixAutomaton::Node *last) {
    SuffixAutomaton::Node *state = sam.extend(last, u->c);

    Edge *edge = u->e;
    while (edge) {
        if (edge->v != fa) {
            dfs(edge->v, u, state);
        }
        edge = edge->next;
    }
}
// Reads the tree, inserts the traversal rooted at every leaf into the
// automaton, and prints the number of distinct value sequences.
int main() {
    int n = 0;
    // The second integer of the header line is read and discarded (%*d).
    scanf("%d %*d", &n);
    // Values are used directly as transition indices by sam.extend().
    for (int i = 1; i <= n; i++) scanf("%d", &N[i].c);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        addEdge(u, v);
    }
    // deg <= 1 instead of == 1: with n == 1 the only vertex has degree 0 and
    // must still be inserted; for n >= 2 every tree vertex has degree >= 1,
    // so the condition selects exactly the leaves.
    for (int i = 1; i <= n; i++) if (N[i].deg <= 1) dfs(&N[i], NULL, sam.start);
    // "\n", not a literal 'n', terminates the output line.
    printf("%lld\n", sam.calc());
    return 0;
}

LibreOJ(非指针的 SAM):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

// <cstdio> supplies scanf/printf used by main(); <new> declares the placement
// form of operator new used by addEdge() to construct edges in the pool.
#include <cstdio>
#include <new>
#include <vector>
#include <algorithm>

// Upper bound on the number of tree vertices (vertices are numbered from 1).
const int MAXN = 100005;
// Alphabet size: vertex values index the transition arrays directly,
// so every value must lie in [0, CHAR_SET).
const int CHAR_SET = 10;
// Generalized suffix automaton over the alphabet [0, CHAR_SET) that refers
// to states by index into a static array instead of by pointer; -1 plays
// the role of "no state".
//
// The struct has to be named because it declares a constructor called
// SuffixAutomaton, which an unnamed struct cannot have.
struct SuffixAutomaton {
    struct Node {
        int c[CHAR_SET];  // outgoing transitions, -1 when absent
        int next;         // suffix link, -1 only for the start state
        int max;          // length of the longest string of this state

        // Initializers are written in declaration order (next, max).
        Node() : next(-1), max(0) {
            std::fill(c, c + CHAR_SET, -1);
        }
    } N[MAXN * 40];
    // NOTE(review): the array size presumably covers up to 20 traversals of
    // MAXN vertices with at most 2 states created per extend() - confirm.

    int start;    // index of the start state
    int nodeCnt;  // number of states handed out so far

    SuffixAutomaton() {
        init();
    }

    // Makes state 0 the start state and marks it as the only one in use.
    // Existing states are not re-initialised.
    void init() {
        start = 0;
        nodeCnt = 1;
    }

    // Length of the shortest string of state u. Reads N[N[u].next], so it
    // must not be called on the start state (whose link is -1).
    int getMin(int u) {
        return N[N[u].next].max + 1;
    }

    // Appends character c to the strings ending in state v and returns the
    // index of the state to continue from. v may be any previously returned
    // state (or start).
    int extend(int v, int c) {
        // The required state already exists: reuse it, allocate nothing.
        if (N[v].c[c] != -1 && N[N[v].c[c]].max == N[v].max + 1) return N[v].c[c];

        int u = nodeCnt++;
        N[u].max = N[v].max + 1;
        // Walk the suffix links, adding the missing transition on c.
        while (v != -1 && N[v].c[c] == -1) {
            N[v].c[c] = u;
            v = N[v].next;
        }

        if (v == -1) {
            // Ran off the top of the link chain.
            N[u].next = start;
        } else if (N[N[v].c[c]].max == N[v].max + 1) {
            // Existing target is exactly one longer than v: link to it as is.
            N[u].next = N[v].c[c];
        } else {
            // Split: n is a copy of o truncated to length N[v].max + 1.
            int n = nodeCnt++, o = N[v].c[c];
            N[n].max = N[v].max + 1;
            std::copy(N[o].c, N[o].c + CHAR_SET, N[n].c);
            N[n].next = N[o].next;
            N[o].next = N[u].next = n;
            for (; v != -1 && N[v].c[c] == o; v = N[v].next) N[v].c[c] = n;
            // When the while loop above did not run at all (the transition
            // existed on the very first v), u ends up with the same max as
            // its link n; such a state adds 0 in calc().
        }
        return u;
    }

    // Sums (max - min + 1) over states 1 .. nodeCnt - 1 (the start state is
    // skipped), i.e. the number of distinct non-empty strings.
    long long calc() {
        long long res = 0;
        for (int p = 1; p != nodeCnt; p++) res += N[p].max - getMin(p) + 1;
        return res;
    }
} sam;
struct Edge;

// A vertex of the input tree, stored at N[index].
struct Node {
    Edge *e;  // first edge of the adjacency list; null until addEdge() sets it
    int c;    // vertex value, used as the character passed to sam.extend()
    int deg;  // count of incident edges, maintained by addEdge()
} N[MAXN];
// One direction of a tree edge, chained into the adjacency list of its
// source vertex. Instances are carved out of `_pool`; `_curr` is the next
// free slot.
struct Edge {
    Node *u, *v;  // source and destination vertex
    Edge *next;   // following edge in the source's list

    // For the pool array only; members stay unset.
    Edge() {}

    // Links the new edge in front of the current head of `src`'s list.
    Edge(Node *src, Node *dst) : u(src), v(dst), next(src->e) {}
} _pool[MAXN * 2], *_curr = _pool;
void addEdge(int u, int v) {
N[u].e = new (_curr++) Edge(&N[u], &N[v]);
N[v].e = new (_curr++) Edge(&N[v], &N[u]);
N[u].deg++;
N[v].deg++;
}
// Extends automaton state `last` with u's value, then recurses into every
// neighbour of u except `fa`, handing down the state just obtained.
void dfs(Node *u, Node *fa, int last) {
    const int state = sam.extend(last, u->c);

    for (Edge *edge = u->e; edge != NULL; edge = edge->next) {
        if (edge->v == fa) {
            continue;
        }
        dfs(edge->v, u, state);
    }
}
// Reads the tree, inserts the traversal rooted at every leaf into the
// automaton, and prints the number of distinct value sequences.
int main() {
    int n = 0;
    // The second integer of the header line is read and discarded (%*d).
    scanf("%d %*d", &n);
    // Values are used directly as transition indices by sam.extend().
    for (int i = 1; i <= n; i++) scanf("%d", &N[i].c);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        addEdge(u, v);
    }
    // deg <= 1 instead of == 1: with n == 1 the only vertex has degree 0 and
    // must still be inserted; for n >= 2 every tree vertex has degree >= 1,
    // so the condition selects exactly the leaves.
    for (int i = 1; i <= n; i++) if (N[i].deg <= 1) dfs(&N[i], NULL, sam.start);
    // "\n", not a literal 'n', terminates the output line.
    printf("%lld\n", sam.calc());
    return 0;
}