隐藏
「JZOJ3872」圣诞树 - 同色三角形+点分治 | Bill Yang's Blog

路终会有尽头,但视野总能看到更远的地方。

0%

「JZOJ3872」圣诞树 - 同色三角形+点分治

题目大意

    圣诞节到了,小可可送给小薰一棵圣诞树。这棵圣诞树很奇怪,它是一棵多叉树,有$n$个点,$n-1$条边。它的每个结点都有一个权值。小可可和小薰想用这棵树玩一个游戏。
    定义$(s,e)$为树上从$s$到$e$的简单路径,我们可以记下在这条路径上经过的结点,定义这个结点序列为$S(s,e)$。
    我们按照如下方法定义这个序列$S(s,e)$的权值$G(S(s,e))$:假设这个序列中结点的权值为$Z_0,Z_1,\ldots,Z_{L-1}$,其中$L$为序列的长度,我们定义$G(S(s,e))=Z_0\times k^0+Z_1\times k^1+\cdots+ Z_{L-1}\times k^{L-1}$。
如果路径$(s,e)$满足$G(S(s,e))\equiv x\pmod y$,那么这条路径属于小可可,否则这条路径属于小薰。小可可和小薰很显然不希望这个游戏变得那么简单。小薰认为如果路径$(p_1,p_2)$和$(p_2,p_3)$都属于他,那么路径$(p_1,p_3)$也属于他,反之如果路径$(p_1,p_2)$和$(p_2,p_3)$都属于小可可,那么路径$(p_1,p_3)$也属于小可可。然而这个性质并不总是正确的。所以小薰想知道到底有多少三元组$(p_1,p_2,p_3)$满足这个性质。
小薰表示她看一眼就知道这道题怎么做了。你会吗?


题目分析

以前在计蒜客看到过这种同色三角形模型。
首先转化问题:
设$x\rightarrow y$的路径满足以上同余条件,则边权为$1$,否则边权为$0$。
我们要计算的是边权全为$0$和边权全为$1$的三角形数目。

因为全部的条件有三处限制,因此考虑容斥原理,用总数减去异色三角形的方案数即为答案。

因此对于所有的异色三角形,有以下$6$中情况(借用$HOWARLI$的图)。

不难发现,若我们只统计一个点引出去的异色条件,所得到的方案数刚好是异色三角形方案数的两倍(上图数一数)。

因此我们将统计三元组的任务转化为了统计二元组的数目,设一个点边权$1$的边入度为$in1$,出度为$out1$,其余与此类似。
则异色三元组个数为:

因此我们可以枚举根用树形动规预处理每个结点的$in1[],out1[]$,$in0[i]=n-in1[i],out0[i]=n-out1[i]$。

就可以用$O(n^2)$的算法解决问题了。

但是这样的复杂度是不能接受的,我们还要快速的求出$in1[]$与$out1[]$。

树上路径信息统计,考虑点分治:

我们将$x\rightarrow y$的合法路径进行分治,成为过根结点$u$的两条路径$[x,u]$与$(u,y]$,分别用$Dist1[x]$与$Dist2[y]$表示它们的路径权值和,用$Len[x]$表示路径长度。

则若$Dist1[x]+Dist2[y]\times K^{Len[x]}\equiv x\pmod y$,则在模意义下有:

这样成功将$x$与$y$分离,可以使用Hash表进行统计。

这样我们就可以用点分治统计了:
先从前往后加入儿子,先统计答案再加入,然后从后往前加入儿子,同样先统计答案再加入。接着递归子树即可。

注意有很多细节,比如根结点权值重复统计等问题。
反正我调了一下午。

听说Map常数大,大家都没有卡过去,不知道为什么我就卡过去了(所以大家都改写了其他数据结构跑得飞快)。


代码

50分树形动规

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include<algorithm>
#include<iostream>
#include<iomanip>
#include<cstring>
#include<cstdlib>
#include<vector>
#include<cstdio>
#include<cmath>
#include<queue>
using namespace std;

typedef long long LL;

inline const LL Get_Int() {
LL num=0,bj=1;
char x=getchar();
while(x<'0'||x>'9') {
if(x=='-')bj=-1;
x=getchar();
}
while(x>='0'&&x<='9') {
num=num*10+x-'0';
x=getchar();
}
return num*bj;
}

const int maxn=100005;
int root;
LL n,mod,x,K,ans=0,a[maxn],Val[maxn],In[maxn],Out[maxn];
vector<int>edges[maxn];

void AddEdge(int x,int y) {
edges[x].push_back(y);
}

void TreeDp(int Now,int fa,int depth,LL in,LL out) {
if(in==x)In[root]++;
if(out==x)Out[root]++;
for(int i=0; i<edges[Now].size(); i++) {
int Next=edges[Now][i];
if(Next==fa)continue;
TreeDp(Next,Now,depth+1,(in*K%mod+a[Next])%mod,(out+a[Next]*Val[depth+1]%mod)%mod);
}
}

int main() {
n=Get_Int();
mod=Get_Int();
K=Get_Int();
x=Get_Int();
for(int i=1; i<=n; i++)a[i]=Get_Int()%mod;
for(int i=1; i<n; i++) {
int x=Get_Int(),y=Get_Int();
AddEdge(x,y);
AddEdge(y,x);
}
Val[0]=1;
for(int i=1; i<=n; i++)Val[i]=Val[i-1]*K%mod;
for(int i=1; i<=n; i++) {
root=i;
TreeDp(i,-1,0,a[i],a[i]);
}
for(int i=1; i<=n; i++)ans+=In[i]*(n-Out[i])+(n-In[i])*Out[i]+2*(Out[i]*(n-Out[i])+(n-In[i])*In[i]);
printf("%lld\n",n*n*n-ans/2);
return 0;
}

100分点分治

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include<algorithm>
#include<iostream>
#include<iomanip>
#include<cstring>
#include<cstdlib>
#include<vector>
#include<cstdio>
#include<cmath>
#include<queue>
#include<map>
using namespace std;

typedef long long LL;

inline const LL Get_Int() {
LL num=0,bj=1;
char x=getchar();
while(x<'0'||x>'9') {
if(x=='-')bj=-1;
x=getchar();
}
while(x>='0'&&x<='9') {
num=num*10+x-'0';
x=getchar();
}
return num*bj;
}

int max(int a,int b) {
if(a>b)return a;
return b;
}

const int maxn=100005;
int root,Size[maxn],Maxson[maxn],Min,Core,vst[maxn],Len[maxn];
LL n,mod,X,K,ans=0,a[maxn],Val[maxn],inv[maxn],In[maxn],Out[maxn],Dist1[maxn],Dist2[maxn];
vector<int>edges[maxn];

void AddEdge(int x,int y) {
edges[x].push_back(y);
}

void Get_Size(int Now,int father) {
Size[Now]=1;
Maxson[Now]=0;
for(int Next:edges[Now]) {
if(Next==father||vst[Next])continue;
Get_Size(Next,Now);
Size[Now]+=Size[Next];
Maxson[Now]=max(Maxson[Now],Size[Next]);
}
}

void Get_Core(int Now,int father,int num) {
Maxson[Now]=max(Maxson[Now],Size[num]-Size[Now]);
if(Maxson[Now]<Min) {
Min=Maxson[Now];
Core=Now;
}
for(int Next:edges[Now]) {
if(Next==father||vst[Next])continue;
Get_Core(Next,Now,num);
}
}

void TreeDp(int Now,int fa,int depth,LL in,LL out) {
if(in==X)In[root]++,Out[Now]+=fa!=-1;
if(out==X)Out[root]++,In[Now]+=fa!=-1;
Dist1[Now]=in;
Dist2[Now]=(out-a[root]+mod)%mod;
Len[Now]=depth;
for(int Next:edges[Now]) {
if(Next==fa||vst[Next])continue;
TreeDp(Next,Now,depth+1,(in*K%mod+a[Next])%mod,(out+a[Next]*Val[depth+1]%mod)%mod);
}
}

map<int,int>Hash1,Hash2;

void Get_Ans(int Now,int fa) {
Out[Now]+=Hash1[(X-Dist1[Now]+mod)%mod*inv[Len[Now]]%mod];
In[Now]+=Hash2[Dist2[Now]];
for(int Next:edges[Now]) {
if(Next==fa||vst[Next])continue;
Get_Ans(Next,Now);
}
}

void Add(int Now,int fa) {
Hash1[Dist2[Now]]++;
Hash2[(X-Dist1[Now]+mod)%mod*inv[Len[Now]]%mod]++;
for(int Next:edges[Now]) {
if(Next==fa||vst[Next])continue;
Add(Next,Now);
}
}

void Dfs(int Now) {
Min=n;
Get_Size(Now,0);
Get_Core(Now,0,Now);
root=Now=Core;
vst[Now]=1;
Hash1.clear();
Hash2.clear();
TreeDp(Now,-1,0,a[Now],a[Now]);
for(int Next:edges[Now]) {
if(vst[Next])continue;
Get_Ans(Next,-1);
Add(Next,-1);
}
Hash1.clear();
Hash2.clear();
for(vector<int>::reverse_iterator it=edges[Now].rbegin(); it!=edges[Now].rend(); it++) {
int Next=*it;
if(vst[Next])continue;
Get_Ans(Next,-1);
Add(Next,-1);
}
for(int Next:edges[Now]) {
if(vst[Next])continue;
Dfs(Next);
}
}

LL Quick_Pow(LL a,LL b) {
LL ans=1;
while(b) {
if(b&1)ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}

int main() {
n=Get_Int();
mod=Get_Int();
K=Get_Int();
X=Get_Int();
for(int i=1; i<=n; i++) {
a[i]=Get_Int();
if(a[i]>=mod)a[i]%=mod;
}
for(int i=1; i<n; i++) {
int x=Get_Int(),y=Get_Int();
AddEdge(x,y);
AddEdge(y,x);
}
Val[0]=1;
inv[0]=1;
for(int i=1; i<=n; i++)Val[i]=Val[i-1]*K%mod,inv[i]=Quick_Pow(Val[i],mod-2);
Dfs(1);
for(int i=1; i<=n; i++)ans+=In[i]*(n-Out[i])+(n-In[i])*Out[i]+2*(Out[i]*(n-Out[i])+(n-In[i])*In[i]);
printf("%lld\n",n*n*n-ans/2);
return 0;
}

姥爷们赏瓶冰阔落吧~