「SDOI2015」序列统计 - 生成函数+快速幂+NTT | Bill Yang's Blog

「SDOI2015」序列统计 - 生成函数+快速幂+NTT

题目大意

    小C有一个集合$S$,里面的元素都是小于$M$的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为$N$的数列,数列中的每个数都属于集合$S$。
    小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数$x$,求所有可以生成出的,且满足数列中所有数的乘积$\bmod M$的值等于$x$的不同的数列的有多少个。小C认为,两个数列$\lbrace A_i\rbrace$和$\lbrace B_i\rbrace$不同,当且仅当至少存在一个整数$i$,满足$A_i\neq B_i$。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案$\bmod\,1004535809$的值就可以了。


题目分析

首先,所有数的乘积不好处理,因为题目满足一些特别的限制:

  • $1\le x\le M-1$
  • $M$为质数

因此我们可以通过取离散对数将问题数的乘积转化为数的和。
具体的方法是:
找到$M$的原根$g$,由原根性质可知:$g^i(0\le i\le M-2)$取遍$[1,M-1]$所有数。
因此我们可以找到一个从$[1,M-1]$到$[0,M-2]$的离散对数映射,通过映射可以将乘积转化为和,我们将映射记为$idx[]$。
数论中也将这称为指标。

接下来的问题变为了:有$\left|S\right|$个数,取出$n$个数(可以重复取),使得取出数的和为$idx[x]$的方案数是多少。
注意原题数据有错,$S$中有值为$0$的元素,因为$x\neq0$,故可以直接无视。

对于初学者:
这是一个类似完全背包的问题,但$n$范围很大,不可能使用动态规划。
对于数论高手:
这是一个简单的生成函数的问题。

因为每个数$num$只出现一次,故我们将$A(X)$的系数表示中所有$idx[num]$置为$1$。
本题是生成函数自乘,只需要将其自乘$N$次,系数表示中第$idx[x]$即为答案。
自乘可以使用快速幂来解决,快速幂每次乘法使用NTT计算。
当然也可以使用多项式$exp+lnp$的方法计算。

注意因为指标的原因,每次NTT后生成函数中$[M-1,2M-1]$可能存在值,但这是不满足映射关系的,我们将其重新映射到$[0,M-2]$中。(相当于对下标取模)


代码

代码中的快速幂没有进行常数优化,没有必要每次都执行idft,因此可能会很慢。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include<algorithm>
#include<iostream>
#include<iomanip>
#include<cstring>
#include<cstdlib>
#include<climits>
#include<vector>
#include<cstdio>
#include<cmath>
#include<queue>
using namespace std;

typedef long long LL;

inline const int Get_Int() {
int num=0,bj=1;
char x=getchar();
while(x<'0'||x>'9') {
if(x=='-')bj=-1;
x=getchar();
}
while(x>='0'&&x<='9') {
num=num*10+x-'0';
x=getchar();
}
return num*bj;
}

const int maxn=32768+5;
const LL mod=1004535809;
LL g=3;

LL Quick_Pow(LL a,LL b,LL p=mod) {
LL sum=1;
for(; b; b>>=1,a=a*a%p)if(b&1)sum=sum*a%p;
return sum;
}

LL inv(LL x) {
return Quick_Pow(x,mod-2);
}

vector<int> ppt;

int find_root(LL x) {
LL tmp=x-1;
for(int i=2; i<=sqrt(tmp); i++)
if(tmp%i==0) {
ppt.push_back(i);
while(tmp%i==0)tmp/=i;
}
for(int g=2; g<x; g++) {
bool bj=1;
for(int t:ppt)
if(Quick_Pow(g,(x-1)/t,x)==1) {
bj=0;
break;
}
if(bj)return g;
}
}

struct NumberTheoreticTransform {
int n;
LL omega[maxn],iomega[maxn];

void init(int n) {
this->n=n;
int x=Quick_Pow(g,(mod-1)/n);
omega[0]=iomega[0]=1;
for(int i=1; i<n; i++) {
omega[i]=omega[i-1]*x%mod;
iomega[i]=inv(omega[i]);
}
}

void transform(LL* a,LL* omega) {
int k=log2(n);
for(int i=0; i<n; i++) {
int t=0;
for(int j=0; j<k; j++)if(i&(1<<j))t|=(1<<(k-j-1));
if(i<t)swap(a[i],a[t]);
}
for(int len=2; len<=n; len*=2) {
int mid=len>>1;
for(LL* p=a; p!=a+n; p+=len)
for(int i=0; i<mid; i++) {
LL t=omega[n/len*i]*p[mid+i]%mod;
p[mid+i]=(p[i]-t+mod)%mod;
p[i]=(p[i]+t)%mod;
}
}
}

void dft(LL* a) {
transform(a,omega);
}

void idft(LL* a) {
transform(a,iomega);
LL x=inv(n);
for(int i=0; i<n; i++)a[i]=a[i]*x%mod;
}
} ntt;

void Multiply(const LL* a1,const int n1,const LL* a2,const int n2,LL* ans) {
int n=1;
while(n<n1+n2)n<<=1;
ntt.init(n);
static LL b1[maxn],b2[maxn];
fill(b1,b1+n,0);
fill(b2,b2+n,0);
copy(a1,a1+n,b1);
copy(a2,a2+n,b2);
ntt.dft(b1);
ntt.dft(b2);
for(int i=0; i<n; i++)b1[i]=b1[i]*b2[i]%mod;
ntt.idft(b1);
for(int i=n1; i<=n1+n2; i++)b1[i-n1]=(b1[i-n1]+b1[i])%mod;
for(int i=0; i<n1; i++)ans[i]=b1[i];
}

LL n,m,x,S,idx[maxn],f[maxn],ans[maxn];

int main() {
// g=find_root(mod);
n=Get_Int();
m=Get_Int();
x=Get_Int();
S=Get_Int();
LL g=find_root(m),tmp=1;
for(int i=0; i<m-1; i++) {
idx[tmp]=i;
tmp=tmp*g%m;
}
for(int i=1; i<=S; i++) {
int x=Get_Int();
if(x==0)continue;
f[idx[x]]=1;
}
ans[0]=1;
for(; n; Multiply(f,m-1,f,m-1,f),n>>=1)
if(n&1)Multiply(ans,m-1,f,m-1,ans);
printf("%lld\n",ans[idx[x]]);
return 0;
}

姥爷们赏瓶冰阔落吧~
0%