M-Mediocre String Problem
给定字符串S,T,求S的子串与T的前缀子串能够组成的回文串个数
题意
给定两个字符串$s ,t$,取$s$的子串$s’$和$t$的前缀子串$t’$,并使$|s’|>|t’|$.拼接$s’,t’$得到$str=s’+t’$,求能使$str$为回文串的总方案数。
思路
由于$|s’|>|t’|$,可令$s’=a+b,t’=c ,(|a|=|c|>0,|b|>0)$
因此$str=a+b+c$,由回文串性质可知,$b$为长度大于0的回文串,且$reverse(a)=c$
如,对于字符串$s=aabbcdedc,t=bbaa$,以$x=4$为例
$aabb|cdedc$
$aabb$
$;;abb$
$;;;;bb$
$;;;;;b$
$a,c$有以上4种取法,$b=c或b=cdedc$,共有2×4=8种情况
解法
对于$1≤i≤|s|$求出以$s$以第$i$位开头的回文串个数$CNT(i)$,可以采用Manacher,利用回文串性质差分求解;
翻转$s$,利用ex-KMP求解$reverse(s)$的后缀与$t$的最长公共前缀$LCP$;
对于原串$s$的第$x$位,能够组成的回文串个数为$LCP(x)·CNT(x+1)$,求和即为所求解.
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
|
#include <iostream> #include <cstring> #include <algorithm>
using namespace std;
const int maxn=1e6+10;
char s[maxn],t[maxn];
char tmp[maxn<<1]; int Len[maxn<<1],cnt[maxn];
int (char *str) { int i,len=strlen(str); tmp[0]='@'; for(int i=1;i<=2*len;i+=2) { tmp[i]='#'; tmp[i+1]=str[i/2]; } tmp[2*len+1]='#'; tmp[2*len+2]='$'; tmp[2*len+3]=0; return 2*len+1; }
int manacher(char *str) { int mx=0,ans=0,pos=0; int len=init(str); for(int i=1;i<=len;i++) { if(mx>i) Len[i]=min(mx-i,Len[2*pos-i]); else Len[i]=1; while(tmp[i-Len[i]]==tmp[i+Len[i]]) Len[i]++; if(Len[i]+i>mx) mx=Len[i]+i,pos=i; } for(int i=2;i<len;i++) { if(tmp[i]=='#'&&Len[i]==1) continue; int x=i/2-Len[i]/2,y=(Len[i]-1)/2; if((Len[i]-1)%2==0) y--; cnt[x]++; cnt[x+y+1]--; } }
int extend[maxn],nex[maxn];
void getNext(char *s) { int len=strlen(s); nex[0]=len; int pos=0; while(pos+1<len&&s[pos]==s[pos+1]) pos++; nex[1]=pos; int k=1,L; for(int i=2;i<len;i++) { pos=k+nex[k]-1; L=nex[i-k]; if(i+L<=pos) nex[i]=L; else { int j=pos-i+1; if(j<0) j=0; while(i+j<len&&s[i+j]==s[j]) j++; nex[i]=j; k=i; } } }
void getExtend(char *s,char *t) { int lens=strlen(s),lent=strlen(t); getNext(t); int pos=0; while(pos<lens&&pos<lent&&s[pos]==t[pos]) pos++; extend[0]=pos; int k=0,L; for(int i=1;i<lens;i++) { pos=k+extend[k]-1; L=nex[i-k]; if(i+L<=pos) extend[i]=L; else { int j=pos-i+1; if(j<0) j=0; while(i+j<lens&&j<lent&&s[i+j]==t[j]) j++; extend[i]=j; k=i; } } }
int main() { scanf("%s%s",s,t); memset(cnt,0,sizeof cnt); int lens=strlen(s),lent=strlen(t); manacher(s); for(int i=0;i<lens;i++) cnt[i]+=cnt[i-1]; reverse(s,s+lens); getExtend(s,t); long long ans=0; for(int i=1;i<lens;i++) ans+=1ll*cnt[lens-i]*extend[i]; printf("%lldn",ans); return 0; }
|
近期评论