题面
题解
设
f
n
f_n
fn 表示叶子数为
n
n
n 的答案,容易得出以下式子:
f
0
=
0
,
f
1
=
1
f
n
=
∑
i
=
1
n
c
i
f
i
f
n
−
i
\begin{aligned} &f_0=0,f_1=1\\ &f_n=\sum_{i=1}^n c_if_if_{n-i} \end{aligned}
f0=0,f1=1fn=i=1∑ncififn−i
其中
c
s
i
=
v
i
c_{s_i}=v_i
csi=vi,其余的
c
i
c_i
ci 均为
A
A
A。
注意到
k
≤
10
k\leq 10
k≤10 很小,
c
i
c_i
ci 中只有个别值不同,所以考虑将所有的
v
i
v_i
vi 减去
A
A
A,然后得到下面这个递推式:
f
n
=
∑
i
=
1
n
A
f
i
f
n
−
i
+
∑
i
=
1
k
v
i
f
s
i
f
n
−
s
i
f_n=\sum_{i=1}^nAf_if_{n-i}+\sum_{i=1}^kv_if_{s_i}f_{n-s_i}
fn=i=1∑nAfifn−i+i=1∑kvifsifn−si
设
f
i
f_i
fi 的生成函数为
F
(
x
)
F(x)
F(x),设
P
(
x
)
=
∑
i
=
1
k
v
i
f
s
i
x
s
i
P(x)=\sum\limits_{i=1}^kv_if_{s_i}x^{s_i}
P(x)=i=1∑kvifsixsi,那么有:
F
(
x
)
≡
A
F
(
x
)
2
+
P
(
x
)
F
(
x
)
+
x
(
m
o
d
x
n
+
1
)
F(x)\equiv AF(x)^2+P(x)F(x)+x\pmod {x^{n+1}}
F(x)≡AF(x)2+P(x)F(x)+x(modxn+1)
(
+
x
+x
+x 是因为
A
F
(
x
)
2
+
P
(
x
)
F
(
x
)
AF(x)^2+P(x)F(x)
AF(x)2+P(x)F(x) 只计算了
n
≥
2
n\geq 2
n≥2 的系数,需要初始化
f
1
=
1
f_1=1
f1=1)
注意到 F ( x ) F(x) F(x) 的零次项为 0 0 0,所以 P ( x ) P(x) P(x) 的 n n n 次项系数对 F ( x ) F(x) F(x) 的 n n n 次项系数没有贡献。所以我们只保留 P ( x ) P(x) P(x) 的 n − 1 n-1 n−1 次项系数再代入等号右边的 P ( x ) P(x) P(x) 其实是等价的。其本质就是递推。
所以我们记
P
′
(
x
)
=
P
(
x
)
m
o
d
x
n
P'(x)=P(x) \bmod x^n
P′(x)=P(x)modxn,也会有:
F
(
x
)
≡
A
F
(
x
)
2
+
P
′
(
x
)
F
(
x
)
+
x
(
m
o
d
x
n
+
1
)
F(x)\equiv AF(x)^2+P'(x)F(x)+x\pmod{x^{n+1}}
F(x)≡AF(x)2+P′(x)F(x)+x(modxn+1)
解得:
F
(
x
)
≡
1
−
P
′
(
x
)
±
(
1
−
P
′
(
x
)
)
2
−
4
A
x
2
A
(
m
o
d
x
n
+
1
)
F(x)\equiv \dfrac{1-P'(x)\pm\sqrt{\big(1-P'(x)\big)^2-4Ax}}{2A}\pmod {x^{n+1}}
F(x)≡2A1−P′(x)±(1−P′(x))2−4Ax(modxn+1)
令
Q
(
x
)
=
(
1
−
P
′
(
x
)
)
2
−
4
A
x
Q(x)=\big(1-P'(x)\big)^2-4Ax
Q(x)=(1−P′(x))2−4Ax,
G
(
x
)
=
Q
(
x
)
G(x)=\sqrt{Q(x)}
G(x)=Q(x)。注意到
q
0
=
1
q_0=1
q0=1,那么
g
0
=
1
g_0=1
g0=1。又由于
F
(
x
)
F(x)
F(x) 常数项为
0
0
0,所以应该取负号,故:
F
(
x
)
≡
1
−
P
′
(
x
)
−
(
1
−
P
′
(
x
)
)
2
−
4
A
x
2
A
(
m
o
d
x
n
+
1
)
F(x)\equiv \dfrac{1-P'(x)-\sqrt{\big(1-P'(x)\big)^2-4Ax}}{2A}\pmod {x^{n+1}}
F(x)≡2A1−P′(x)−(1−P′(x))2−4Ax(modxn+1)
注意这条式子里
p
n
p_n
pn 看似会对
f
n
f_n
fn 的取值有影响,但我们推的式子是正确的,说明
p
n
p_n
pn 实际上被抵消掉了,它对
f
n
f_n
fn 的取值没有影响。
我们只需要得到
F
(
x
)
F(x)
F(x) 的
n
n
n 次项,那我们就需要
g
n
g_n
gn,考虑推导
G
(
x
)
=
Q
(
x
)
G(x)=\sqrt{Q(x)}
G(x)=Q(x) 实现快速算
g
n
g_n
gn,两边求导得:
G
′
(
x
)
=
Q
′
(
x
)
2
Q
(
x
)
G
′
(
x
)
Q
(
x
)
=
Q
′
(
x
)
Q
(
x
)
2
=
Q
′
(
x
)
G
(
x
)
2
\begin{aligned} G'(x)&=\dfrac{Q'(x)}{2\sqrt{Q(x)}}\\ G'(x)Q(x)&=\dfrac{Q'(x)\sqrt{Q(x)}}{2}=\dfrac{Q'(x)G(x)}{2} \end{aligned}
G′(x)G′(x)Q(x)=2Q(x)Q′(x)=2Q′(x)Q(x)=2Q′(x)G(x)
提取
x
n
x^n
xn 的系数:
∑
i
=
0
n
(
n
−
i
+
1
)
g
n
−
i
+
1
q
i
=
1
2
∑
i
=
0
n
(
i
+
1
)
q
i
+
1
g
n
−
i
(
n
+
1
)
g
n
+
1
q
0
+
∑
i
=
1
n
(
n
−
i
+
1
)
g
n
−
i
+
1
q
i
=
1
2
∑
i
=
1
n
+
1
i
q
i
g
n
−
i
+
1
(
n
+
1
)
g
n
+
1
=
∑
i
=
1
n
+
1
1
2
i
q
i
g
n
−
i
+
1
−
(
n
−
i
+
1
)
g
n
−
i
+
1
q
i
=
∑
i
=
1
n
+
1
q
i
g
n
−
i
+
1
(
3
2
i
−
n
−
1
)
\begin{aligned} \sum_{i=0}^n(n-i+1)g_{n-i+1}q_i&=\dfrac{1}{2}\sum_{i=0}^n(i+1)q_{i+1}g_{n-i}\\ (n+1)g_{n+1}q_0+\sum_{i=1}^n(n-i+1)g_{n-i+1}q_i&=\dfrac{1}{2}\sum_{i=1}^{n+1}iq_ig_{n-i+1}\\ (n+1)g_{n+1}&=\sum_{i=1}^{n+1}\dfrac{1}{2}iq_ig_{n-i+1}-(n-i+1)g_{n-i+1}q_i\\ &=\sum_{i=1}^{n+1}q_ig_{n-i+1}\left(\dfrac{3}{2}i-n-1\right) \end{aligned}
i=0∑n(n−i+1)gn−i+1qi(n+1)gn+1q0+i=1∑n(n−i+1)gn−i+1qi(n+1)gn+1=21i=0∑n(i+1)qi+1gn−i=21i=1∑n+1iqign−i+1=i=1∑n+121iqign−i+1−(n−i+1)gn−i+1qi=i=1∑n+1qign−i+1(23i−n−1)
所以
n
g
n
=
∑
i
=
1
n
q
i
g
n
−
i
(
3
2
i
−
n
)
ng_n=\sum\limits_{i=1}^nq_ig_{n-i}\left(\dfrac{3}{2}i-n\right)
ngn=i=1∑nqign−i(23i−n)。
注意到 Q ( x ) Q(x) Q(x) 只有 O ( k 2 ) O(k^2) O(k2) 项有值,所以如果知道 g 1 ∼ g n − 1 g_1\sim g_{n-1} g1∼gn−1, g n g_n gn 就可以暴力算。
那么我们考虑递推:
- 假设我们已经知道了 P ′ ( x ) = P ( x ) m o d x n P'(x)=P(x) \bmod {x^{n}} P′(x)=P(x)modxn,即已经知道了 p 1 ∼ p n − 1 p_1\sim p_{n-1} p1∼pn−1。
- 我们用 P ′ ( x ) P'(x) P′(x) 暴力计算出 Q ( x ) Q(x) Q(x),那么我们就知道了 q 1 ∼ q n − 1 q_1\sim q_{n-1} q1∼qn−1 和 q n ′ q'_n qn′。(单次时间复杂度 O ( k 2 log k 2 ) O(k^2\log k^2) O(k2logk2))
- 利用 q 1 ∼ q n − 1 q_1\sim q_{n-1} q1∼qn−1 和 q n ′ q'_n qn′ 计算出 g n ′ g'_n gn′,再通过 g n ′ g'_n gn′ 得到 f n f_n fn。(单次时间复杂度 O ( k 2 ) O(k^2) O(k2))
- 通过 f n f_n fn 更新 P ( x ) P(x) P(x),然后得到 p 1 ∼ p n p_1\sim p_n p1∼pn,注意记得更新 q n q_n qn 和 g n g_n gn。(单次时间复杂度 O ( k 2 log k 2 ) O(k^2\log k^2) O(k2logk2))
q n ′ q'_n qn′ 和 g n ′ g_n' gn′ 的意思是它们并不是真正的 q n q_n qn 和 g n g_n gn,但是通过 q n ′ q'_n qn′ 和 g n ′ g'_n gn′ 也能算出 f n f_n fn,记得最后要用 f n f_n fn 重新得到真正的 q n q_n qn 和 g n g_n gn。
注意 P ( x ) P(x) P(x) 只有 O ( k ) O(k) O(k) 次更新,所以上述 2,4 步骤实际上只会执行 O ( k ) O(k) O(k) 次。
所以总时间复杂度为 O ( n k 2 + k 3 log k 2 ) O(nk^2+k^3\log k^2) O(nk2+k3logk2)。
感觉这道题还是有点绕的,需要自己手推。
代码如下:
#include<bits/stdc++.h>
#define K 15
#define N 1000010
#define re register
using namespace std;
namespace modular
{
const int mod=1000000007,inv2=500000004;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline void Add(int &x,int y){x=(x+y>=mod?x+y-mod:x+y);}
const int cc=mul(3,inv2);
}using namespace modular;
inline int poww(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
struct data
{
int p,v;
data(){};
data(int a,int b){p=a,v=b;}
};
typedef vector<data> poly;
int n,k,A,val[N];
int f[N],g[N];
poly p,q;
inline void work(const poly &a,poly &ans)
{
static map<int,int>mp;
mp.clear();
mp[1]=dec(0,mul(4,A));
for(int i=0,sa=a.size();i<sa;i++)
for(int j=0;j<sa;j++)
Add(mp[a[i].p+a[j].p],mul(a[i].v,a[j].v));
ans.clear();
for(map<int,int>::iterator it=mp.begin();it!=mp.end();it++)
ans.push_back(data(it->first,it->second));
}
inline void getg(int n)
{
int ans=0;
for(re int i=0,s=q.size();i<s;i++)
{
if(q[i].p<1) continue;
if(q[i].p>n) break;
ans=add(ans,mul(dec(mul(q[i].p,cc),n),mul(q[i].v,g[n-q[i].p])));
}
g[n]=mul(ans,poww(n,mod-2));
}
int main()
{
n=read(),k=read(),A=read();
memset(val,-1,sizeof(val));
for(int i=1;i<=k;i++)
{
int s=read(),v=read();
val[s]=dec(v,A);
}
f[1]=g[0]=1;
p.push_back(data(0,dec(0,1)));
if(~val[1]) p.push_back(data(1,val[1]));
work(p,q);
getg(1);
const int c3=poww(mul(2,A),mod-2);
for(re int now=2;now<=n;now++)
{
getg(now);
f[now]=mul(dec(0,g[now]),c3);
if(~val[now])
{
p.push_back(data(now,mul(val[now],f[now])));
work(p,q);
getg(now);
}
}
printf("%d\n",f[n]);
return 0;
}
/*
5 1 1
2 2
*/