原题传送门
期望dp
d
p
n
,
m
=
1
+
n
n
+
m
d
p
n
+
m
−
1
,
1
+
m
n
+
m
d
p
n
,
m
−
1
dp_{n,m}=1+\frac{n}{n+m}dp_{n+m-1,1}+\frac{m}{n+m}dp_{n,m-1}
dpn,m=1+n+mndpn+m−1,1+n+mmdpn,m−1
然后发现无法实现
经过部分分
m
=
0
m=0
m=0启发
又看见了
n
<
=
1
0
14
,
m
<
=
1
0
6
n<=10^{14},m<=10^6
n<=1014,m<=106
猜想时间复杂度可能只与
m
m
m有关
d
p
n
,
1
=
1
+
n
n
+
1
d
p
n
,
1
+
1
n
+
1
d
p
n
,
0
dp_{n,1}=1+\frac{n}{n+1}dp_{n,1}+\frac{1}{n+1}dp_{n,0}
dpn,1=1+n+1ndpn,1+n+11dpn,0
−
−
>
d
p
n
,
1
=
(
n
+
1
)
+
d
p
n
,
0
(
1
式
)
-->dp_{n,1}=(n+1)+dp_{n,0}(1式)
−−>dpn,1=(n+1)+dpn,0(1式)
d
p
n
,
0
=
1
+
d
p
n
−
1
,
1
(
2
式
)
dp_{n,0}=1+dp_{n-1,1}(2式)
dpn,0=1+dpn−1,1(2式)
将1式代到2式里面,得到
d
p
n
,
0
=
(
n
+
1
)
+
d
p
n
−
1
,
0
dp_{n,0}=(n+1)+dp_{n-1,0}
dpn,0=(n+1)+dpn−1,0
将2式代到1式里面,得到
d
p
n
,
1
=
(
n
+
2
)
+
d
p
n
−
1
,
1
dp_{n,1}=(n+2)+dp_{n-1,1}
dpn,1=(n+2)+dpn−1,1
这样我把
d
p
n
,
0
,
d
p
n
,
1
dp_{n,0},dp_{n,1}
dpn,0,dpn,1的通项都求出来了
然后求通项
d
p
n
,
0
=
2
+
3
+
.
.
.
+
(
n
+
1
)
=
n
(
n
+
3
)
2
=
n
3
+
3
n
2
dp_{n,0}=2+3+...+(n+1)=\frac{n(n+3)}{2}=\frac{n^3+3n}{2}
dpn,0=2+3+...+(n+1)=2n(n+3)=2n3+3n
d
p
n
,
1
=
2
+
3
+
.
.
.
+
(
n
+
2
)
−
1
=
(
n
+
1
)
(
n
+
4
)
2
−
1
=
n
2
+
5
n
+
2
2
dp_{n,1}=2+3+...+(n+2)-1=\frac{(n+1)(n+4)}{2}-1=\frac{n^2+5n+2}{2}
dpn,1=2+3+...+(n+2)−1=2(n+1)(n+4)−1=2n2+5n+2
这样一开始那么转移可以写成
d
p
n
,
m
=
n
n
+
m
(
n
+
m
−
1
)
2
+
5
(
n
+
m
−
1
)
+
2
2
+
m
n
+
m
d
p
n
,
m
−
1
+
1
dp_{n,m}=\frac{n}{n+m}\frac{(n+m-1)^2+5(n+m-1)+2}{2}+\frac{m}{n+m}dp_{n,m-1}+1
dpn,m=n+mn2(n+m−1)2+5(n+m−1)+2+n+mmdpn,m−1+1
可以把第一维弄掉
变成
d
p
i
=
n
n
+
i
(
n
+
i
−
1
)
2
+
5
(
n
+
i
−
1
)
+
2
2
+
n
n
+
i
d
p
i
−
1
+
1
dp_{i}=\frac{n}{n+i}\frac{(n+i-1)^2+5(n+i-1)+2}{2}+\frac{n}{n+i}dp_{i-1}+1
dpi=n+in2(n+i−1)2+5(n+i−1)+2+n+indpi−1+1
复杂度
O
(
m
)
O(m)
O(m)
Code:
#include <bits/stdc++.h>
#define LL long long
using namespace std;
const LL qy = 998244353;
LL n, m;
inline LL read(){
LL s = 0, w = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
return s * w;
}
LL ksm(LL n, LL k){
LL s = 1;
for (; k; k >>= 1, n = n * n % qy) if (k & 1) s = s * n % qy;
return s;
}
int main(){
n = read() % qy, m = read();
LL inv2 = ksm(2, qy - 2);
LL ans = n % qy * (n + 3) % qy * inv2 % qy;
for (LL i = 1; i <= m; ++i){
ans = ans * i % qy * ksm((i + n) % qy, qy - 2) % qy;
ans = (ans + 1) % qy;
ans = (ans + n * ksm((n + i) % qy, qy - 2) % qy * inv2 % qy * ((i + n - 1) % qy * ((i + n - 1) % qy) % qy + 5LL * (i + n - 1) % qy + 2LL) % qy) % qy;
}
printf("%lld\n", ans);
return 0;
}