Palindrome Mouse

题目链接

题意

给定一个字符串 $s$,问有多少对本质不同的回文子串 $(a, b)$,满足 $a \neq b$ 且 $a$ 是 $b$ 的子串。

思路

对这个串建回文树。对于每个节点所表示的回文串,它在回文树上的祖先(两端同时去掉若干字符得到)都是它的子串;另外,沿 fail 指针能走到的所有节点都是它的回文后缀,也是它的子串。因此一个节点的答案就是:它的所有祖先以及这些祖先(含自身)沿 fail 链能到达的节点,去重后的个数(不含自身与两个根)。
但要注意,这两部分可能有重合的点。例如 $aaaa$:它的祖先 $aa$ 同时也在它的 fail 链($aaaa \to aaa \to aa$)上。
因此在回文树上 DFS 时,用 vis 标记记录已经计入过的点:沿 fail 链往上统计时,遇到已标记的点就停止;回溯时再撤销这些标记。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111


#include <cstdio>
#include <cstring>

using namespace std;
// 64-bit integer alias used for the answer.
using ll = long long;

// Capacity of every per-character / per-node array (input length plus slack).
constexpr int maxn = 100000 + 5;

// Test-case counter read from input, and the 1-based case label to print.
int _;
int kase{1};

// Input string, stored 1-based (s[1..]) and NUL-terminated.
char s[maxn];

// Answer for the current test case; reset by the tree's init().
ll ans{0};

// Palindromic tree (eertree) built over the global string s[1..], together
// with the DFS that accumulates the answer into the global `ans`.
// Node 0 is the even root (len 0) and node 1 is the odd root (len -1).
struct {
    int next[maxn][26];  // next[x][c]: node for c + pal(x) + c, or 0 if absent
    int fail[maxn];      // suffix link: longest proper palindromic suffix
    int len[maxn];       // length of the palindrome represented by each node
    int S[maxn];         // inserted characters (0..25); S[0] = -1 is a sentinel
    int dp[maxn];        // per-node count computed in dfs()
    bool vis[maxn];      // nodes already counted on the current DFS path
    int last, n, p;      // node reached by the latest add(), length, node count

    // Allocates a node for a palindrome of length l and returns its index.
    int newNode(int l) {
        memset(next[p], 0, sizeof(next[p]));
        len[p] = l;
        dp[p] = 0;
        return p++;
    }

    // Resets the tree to its two roots and clears the global answer.
    void init() {
        ans = 0;
        n = last = p = 0;
        newNode(0);
        newNode(-1);
        S[n] = -1;
        fail[0] = 1;
        // fail[1] is read by add() (getFail(fail[cur]) with cur == 1) and by
        // jump(1). Set it explicitly instead of relying on this object
        // happening to live in zero-initialised static storage.
        fail[1] = 0;
    }

    // Follows fail links from x until pal(x) can be extended by S[n]
    // on both sides, and returns that node.
    int getFail(int x) {
        while (S[n - len[x] - 1] != S[n]) {
            x = fail[x];
        }
        return x;
    }

    // Appends character c (0..25) and creates its node if it is new.
    void add(int c) {
        S[++n] = c;
        int cur = getFail(last);
        if (!next[cur][c]) {
            int now = newNode(len[cur] + 2);
            // Must be computed before next[cur][c] is set: when cur is the
            // odd root, getFail(fail[cur]) can return cur itself, and the new
            // length-1 node must then get fail = 0.
            fail[now] = next[getFail(fail[cur])][c];
            next[cur][c] = now;
        }
        last = next[cur][c];
    }

    // Marks x, then walks its fail chain marking nodes. Stops before a root
    // (0/1) or at the first node that is already marked. Returns how many
    // nodes were newly marked besides x itself.
    int jump(int x) {
        int cnt = 0;
        vis[x] = 1;
        while (fail[x] != 0 && fail[x] != 1 && !vis[fail[x]]) {
            x = fail[x];
            vis[x] = 1, ++cnt;
        }
        return cnt;
    }

    // Undoes jump(x): unmarks x and the next cnt nodes on its fail chain.
    void clearJump(int x, int cnt) {
        vis[x] = 0;
        while (cnt--) {
            x = fail[x];
            vis[x] = 0;
        }
    }

    // DFS over the subtree of x. fa is x's tree parent; the roots are called
    // with fa == x. dp[x] becomes the number of non-root nodes marked along
    // the path (tree ancestors of x and their fail chains), excluding x
    // itself, and is added to the global ans.
    void dfs(int x, int fa) {
        int jp = jump(x);
        dp[x] = jp;
        // A child of a root has no counted parent; otherwise inherit the
        // parent's count, plus the parent itself, plus the newly marked nodes.
        if (x != 1 && x != 0 && fa != 0 && fa != 1) {
            dp[x] = dp[fa] + jp + 1;
        }
        ans += dp[x];
        for (int i = 0; i < 26; ++i) {
            if (next[x][i]) {
                dfs(next[x][i], x);
            }
        }
        // Remove this node's marks before returning to the parent.
        clearJump(x, jp);
    }

    // Rebuilds the tree from the global string s[1..] (lowercase letters).
    void build() {
        init();
        for (int i = 1; s[i]; i++) {
            add(s[i] - 'a');
        }
    }

} pam;


// Reads the number of test cases. For each case, reads a string into s[1..],
// builds the palindromic tree and prints "Case #k: <answer>" on its own line.
int main() {
    if (scanf("%d", &_) != 1) {
        return 0;
    }
    while (_--) {
        if (scanf("%s", s + 1) != 1) {
            break;
        }
        pam.build();
        printf("Case #%d: ", kase++);
        // Traverse both roots: children of node 1 (len -1) are the odd-length
        // palindromes, children of node 0 (len 0) the even-length ones.
        pam.dfs(1, 1);
        pam.dfs(0, 0);
        printf("%lld\n", ans);
    }
    return 0;
}

甚至还可以在回文树上建主席树(可持久化线段树)来做:每个新节点先继承其父节点的版本,再把 fail 链上尚未出现的节点和它自身插入,该版本中的元素个数减一就是这个节点的贡献。(详细题解待补)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e5 + 10, N = 1e5 + 10;

// Persistent (chairman) segment tree over the value range [1, N].
// rt[v] holds the root of the version attached to palindromic-tree node v.
// Node 0 is the shared empty tree (sum/ls/rs all 0). `cnt` is the number of
// allocated nodes; the caller resets it to 0 before each test case.
int rt[maxn], ls[maxn * 50], rs[maxn * 50], sum[maxn * 50], cnt;

// Makes `o` a new version equal to version `pre` with one more occurrence of
// value k. Only the nodes on the root-to-leaf path of k are copied. `pre` is
// taken by value, so `o` and `pre` may name the same variable.
void up(int &o, int pre, int l, int r, int k) {
    o = ++cnt;
    sum[o] = sum[pre] + 1;
    ls[o] = ls[pre];
    rs[o] = rs[pre];
    if (l == r)
        return;
    const int m = (l + r) / 2;
    if (k <= m)
        up(ls[o], ls[pre], l, m, k);
    else
        up(rs[o], rs[pre], m + 1, r, k);
}

// Returns how many times value k was inserted into version o.
int qu(int o, int l, int r, int k) {
    if (l == r)
        return sum[o];
    const int m = (l + r) / 2;
    if (k <= m)
        return qu(ls[o], l, m, k);
    return qu(rs[o], m + 1, r, k);
}
// Palindromic tree (eertree). Every new node `now` also gets a persistent
// segment-tree version rt[now] holding the node ids of the distinct
// palindromes recorded as substrings of pal(now), including now itself.
// `ans` accumulates, over all nodes, the size of that set minus the node.
struct ptree{
    char s[maxn];        // inserted characters as 0..25; s[0] = -1 sentinel
    // cnt[] is set to 1 per new node; d[] is unused. NOTE(review): kept
    // only for interface compatibility.
    int next[maxn][26],fail[maxn],cnt[maxn],len[maxn], d[maxn];
    int last,n,p;        // node reached by the latest add(), length, node count
    long long res;       // NOTE(review): unused
    ll ans = 0;          // answer for the current string

    // Allocates a node for a palindrome of length l and returns its index.
    inline int newnode(int l){
        cnt[p]=0;
        len[p]=l;
        memset(next[p], 0, sizeof(next[p]));
        return p++;
    }

    // Resets to the even root (node 0, len 0) and odd root (node 1, len -1).
    inline void init(){
        n = last = p = ans = 0;
        newnode(0),newnode(-1);
        s[n]=-1;
        fail[0]=1;
        // fail[1] is read via FL(fail[cur]) when cur is the odd root; set it
        // explicitly rather than relying on zero-initialised static storage.
        fail[1]=0;
    }

    // Follows fail links from x until pal(x) can be extended by s[n] on both sides.
    inline int FL(int x){
        while(s[n-len[x]-1]!=s[n]) x=fail[x];
        return x;
    }

    // Appends lowercase letter c, creating its node (and set version) if new.
    void add(char c){
        c-='a';
        s[++n]=c;
        int cur=FL(last);
        if(!next[cur][c]){
            int now=newnode(len[cur]+2);
            cnt[now] = 1;
            // pal(now) = c + pal(cur) + c, so start from cur's set.
            rt[now] = rt[cur];
            // Computed before next[cur][c] is set: for cur == 1 the lookup
            // may land on cur itself and must still see 0 there.
            int FF = next[FL(fail[cur])][c];
            fail[now] = FF;
            next[cur][c]=now;
            // Insert now's palindromic suffixes (its fail chain), stopping at
            // the first one already present (the rest of its chain is then
            // assumed present too). Roots (0/1) are never inserted.
            while (FF > 1) {
                if (qu(rt[now], 1, N, FF))
                    break;
                up(rt[now], rt[now], 1, N, FF);
                FF = fail[FF];
            }
            up(rt[now], rt[now], 1, N, now);
            // Every element except now itself forms a pair with now.
            ans += sum[rt[now]] - 1;
        }
        last=next[cur][c];
    }
} p;
char s[maxn];

// Reads T test cases. For each string, builds the palindromic tree and
// prints "Case #k: <answer>" on its own line.
int main(){
    int T, Case = 0;
    if (scanf("%d", &T) != 1)
        return 0;
    while (T--) {
        if (scanf("%s", s) != 1)
            break;
        p.init();
        cnt = 0;  // reset the persistent segment-tree node pool
        for (int i = 0; s[i]; i++)
            p.add(s[i]);
        printf("Case #%d: %lld\n", ++Case, p.ans);
    }
    return 0;
}