隐藏
「bzoj4175」小G的电话本 - 后缀自动机+NTT | Bill Yang's Blog

路终会有尽头,但视野总能看到更远的地方。

0%

「bzoj4175」小G的电话本 - 后缀自动机+NTT

题目大意

    小G是一个商人,他有一个电话本。电话本上记下了许多联系人,如timesqrorzyhb等等。不过Tony对其中的某个联系人的名字$S$特别感兴趣,他从中提取出了这个联系人的名字中的所有片段,如提取出orzo,r,z,or,rz,orz等等。现在他想请你统计有多少个长度为$k$的片段对$(P[1],P[2],P[3],…,P[k])$,使得在该片段对中所有片段在$S$中出现次数之和为他的幸运数$m$?注意两个片段对不同当且仅当两个片段对的某一位的片段不同,两个片段不同当且仅当这两个片段在$S$中的位置不同。


题目分析

首先我们建立后缀自动机,找出每个串的出现次数。
接着我们发现问题可以转化为一个背包问题。
设$f[i,j]$表示选了$i$个片段,占用的次数和为$j$的总数。

接着我们发现这个式子是卷积形式,可以使用FFT优化。
构造生成函数,第$i$位表示占用次数和为$i$的数目,对其做$k$次自乘,第$m$位即为答案。

$k$次自乘上快速幂。
时间复杂度为$O(n\log n\log k)$。
还可以使用多项式$exp+lnp$,时间复杂度$O(n\log n)$。

实测,快速幂要快一些。

注意卡常数。
生成函数大小设成$m$即可,$n$太大了会T。
不要开long long,用int就好。
模数特殊,可以上NTT。


代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include<algorithm>
#include<iostream>
#include<iomanip>
#include<cstring>
#include<cstdlib>
#include<climits>
#include<vector>
#include<cstdio>
#include<cmath>
#include<queue>
using namespace std;

inline const int Get_Int() {
int num=0,bj=1;
char x=getchar();
while(x<'0'||x>'9') {
if(x=='-')bj=-1;
x=getchar();
}
while(x>='0'&&x<='9') {
num=num*10+x-'0';
x=getchar();
}
return num*bj;
}

const int maxn=262144+5,maxc=26;
const int mod=1005060097;
int g=5;

void check(int& x) {
if(x<0)x+=mod;
if(x>=mod)x-=mod;
}

void add(int &x,int v) {
x+=v;
check(x);
}

struct SuffixAutomaton {
const static int maxn=100005;
int cnt,root,last;
int next[maxn<<1],Max[maxn<<1];
int end_pos[maxn<<1];
int child[maxn<<1][maxc],Bucket[maxn<<1],top[maxn<<1];
SuffixAutomaton() {
cnt=0;
root=last=newnode(0);
}
int newnode(int val) {
cnt++;
Max[cnt]=val;
return cnt;
}
void insert(int data) {
int p=last,u=newnode(Max[last]+1);
last=u;
end_pos[u]=1;
for(; p&&!child[p][data]; p=next[p])child[p][data]=u;
if(!p)next[u]=root;
else {
int old=child[p][data];
if(Max[old]==Max[p]+1)next[u]=old;
else {
int New=newnode(Max[p]+1);
copy(child[old],child[old]+maxc,child[New]);
next[New]=next[old];
next[u]=next[old]=New;
for(; child[p][data]==old; p=next[p])child[p][data]=New;
}
}
}
void build(string s) {
for(auto x:s)insert(x-'a');
}
void topsort() {
for(int i=1; i<=cnt; i++)Bucket[Max[i]]++;
for(int i=1; i<=cnt; i++)Bucket[i]+=Bucket[i-1];
for(int i=1; i<=cnt; i++)top[Bucket[Max[i]]--]=i;
}
void get_end_pos() {
topsort();
for(int i=cnt; i>=1; i--)add(end_pos[next[top[i]]],end_pos[top[i]]);
}
} sam;

int Quick_Pow(int a,int b) {
int sum=1;
for(; b; b>>=1,a=1ll*a*a%mod)if(b&1)sum=1ll*sum*a%mod;
return sum;
}

int inv(int x) {
return Quick_Pow(x,mod-2);
}

struct NumberTheoreticTransform {
int n,rev[maxn];
int omega[maxn],iomega[maxn];

void init(int n) {
this->n=n;
int x=Quick_Pow(g,(mod-1)/n);
omega[0]=iomega[0]=1;
for(int i=1; i<n; i++) {
omega[i]=1ll*omega[i-1]*x%mod;
iomega[i]=inv(omega[i]);
}
int k=log2(n);
for(int i=0; i<n; i++) {
int t=0;
for(int j=0; j<k; j++)if(i&(1<<j))t|=(1<<(k-j-1));
rev[i]=t;
}
}

void transform(int* a,int* omega) {
for(int i=0; i<n; i++)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int len=2; len<=n; len<<=1) {
int mid=len>>1;
for(int* p=a; p!=a+n; p+=len)
for(int i=0; i<mid; i++) {
int t=1ll*omega[n/len*i]*p[mid+i]%mod;
p[mid+i]=p[i]-t,check(p[mid+i]);
add(p[i],t);
}
}
}

void dft(int* a) {
transform(a,omega);
}

void idft(int* a) {
transform(a,iomega);
int x=inv(n);
for(int i=0; i<n; i++)a[i]=1ll*a[i]*x%mod;
}
} ntt;

int k,m;
int a[maxn],b[maxn];
char s[maxn];

int main() {
k=Get_Int();
m=Get_Int();
scanf("%s",s);
sam.build(s);
sam.get_end_pos();
for(int i=2; i<=sam.cnt; i++)if(sam.end_pos[i]<=m)b[sam.end_pos[i]]=(b[sam.end_pos[i]]+1ll*sam.end_pos[i]*(sam.Max[i]-sam.Max[sam.next[i]]))%mod;
int N=1;
while(N<((m+1)<<1))N<<=1;
ntt.init(N);
ntt.dft(b);
a[0]=1;
for(; k; k>>=1) {
if(k&1) {
ntt.dft(a);
for(int i=0; i<N; i++)a[i]=1ll*a[i]*b[i]%mod;
ntt.idft(a);
fill(a+m+1,a+N,0);
}
for(int i=0; i<N; i++)b[i]=1ll*b[i]*b[i]%mod;
ntt.idft(b);
fill(b+m+1,b+N,0);
ntt.dft(b);
}
printf("%d\n",a[m]);
return 0;
}
姥爷们赏瓶冰阔落吧~