题目
https://acm.bnu.edu.cn/v3/problem_show.php?pid=51637
http://www.lydsy.com/JudgeOnline/problem.php?id=4587
题意
有一个
(n+1)×(m+1)
(
n
+
1
)
×
(
m
+
1
)
的网格,初始在点
(0,0)
(
0
,
0
)
,每次可以从
(x,y)
(
x
,
y
)
走到
(x+1,y)
(
x
+
1
,
y
)
或
(x,y+1)
(
x
,
y
+
1
)
或
(x+1,y+1)
(
x
+
1
,
y
+
1
)
。有
k
k
个点不能到达,问从初始点走到横坐标为
n
n
或纵坐标为的方案数模
p
p
,不一定是质数。注意当横坐标为
n
n
或纵坐标为时不能再走动。
1≤n≤109,1≤m≤104,0≤k≤10,1≤p<231,∑m≤105
1
≤
n
≤
10
9
,
1
≤
m
≤
10
4
,
0
≤
k
≤
10
,
1
≤
p
<
2
31
,
∑
m
≤
10
5
题解
1.一个暴力
定义
dp[i][j]
d
p
[
i
]
[
j
]
表示从
(0,0)
(
0
,
0
)
走到
(i,j)
(
i
,
j
)
的方案数,定义横坐标为
n
n
或纵坐标为的
(i,j)
(
i
,
j
)
为终止状态。
初始状态
dp[0][0]=1
d
p
[
0
]
[
0
]
=
1
。
如果
(i,j)
(
i
,
j
)
不能到达则
dp[i][j]=0
d
p
[
i
]
[
j
]
=
0
。
如果
(i,j)
(
i
,
j
)
能到达,
dp[i][j]
d
p
[
i
]
[
j
]
包含
dp[i−1][j−1]
d
p
[
i
−
1
]
[
j
−
1
]
,如果
(i−1,j)
(
i
−
1
,
j
)
不是终止状态则
dp[i][j]
d
p
[
i
]
[
j
]
也包含
dp[i−1][j]
d
p
[
i
−
1
]
[
j
]
,如果
(i,j−1)
(
i
,
j
−
1
)
不是终止状态则
dp[i][j]
d
p
[
i
]
[
j
]
也包含
dp[i][j−1]
d
p
[
i
]
[
j
−
1
]
。
答案是所有终止状态的
dp[i][j]
d
p
[
i
]
[
j
]
之和。
状态转移,只涉及加减,模运算很好套用,时间复杂度
O(nm)
O
(
n
m
)
。
int ans = 0, f[n + 1][m + 1];
for(int i = 0; i <= n; ++i)
for(int j = 0; j <= m; ++j)
{
if(canreach[i][j] == -1)
{
f[i][j] = 0;
continue;
}
if(i == 0 && j == 0)
{
f[i][j] = 1;
continue;
}
long long tmp = 0;
if(i > 0 && j < m)
tmp += f[i - 1][j];
if(i < n && j > 0)
tmp += f[i][j - 1];
if(i > 0 && j > 0)
tmp += f[i - 1][j - 1];
f[i][j] = tmp % mod;
if(i == n || j == m)
ans += f[i][j];
}
用于对拍
2. k=0 k = 0 的情况
考虑从
(0,0)
(
0
,
0
)
走到
(i,j)
(
i
,
j
)
的方案数,然后直接计算
i=n
i
=
n
和
j=m
j
=
m
的情况。
设从
(0,0)
(
0
,
0
)
走到
(i,j)
(
i
,
j
)
的一种方案使用了
a
a
次,
b
b
次,
c
c
次同时
y+1
y
+
1
,则有
对应的方案数是
考虑 i=n i = n 的情况,则 0≤j≤m 0 ≤ j ≤ m 的情况为
那么枚举 c c 就可以了,不用枚举。对于 0≤i≤n,j=m 0 ≤ i ≤ n , j = m 的情况也这么分析一下。
不过这个题不是求无阻碍到 (n,∗) ( n , ∗ ) 或 (∗,m) ( ∗ , m ) 的方案数,因为 (n,i) ( n , i ) 走不到 (n,i+1) ( n , i + 1 ) 。
但是到终止状态前的一个点是无阻碍的,枚举前一个点算贡献即可。
考虑怎么计算 (nc)modp ( n c ) mod p 。一种方法是将 p p 分解成质因子的幂次, O(c) O ( c ) 枚举分子分母,通过消 p p 因子的方法算出模值,再用中国剩余定理合并。还有一种方法是不分解成质因子的幂次,直接在模意义下做,记录与 p p 互质的部分的值,记录不与互质的部分每个质因子的幂次,两者复杂度相近,但编程复杂度低。
由于还要枚举 c c ,看上去做法是的,实际上可以由 (nc) ( n c ) 在 O(logp) O ( log p ) 的时间内推到 (nc+1) ( n c + 1 ) (乘一个数除一个数),所以复杂度是 O(mlogp) O ( m log p ) 的。
3. k>0 k > 0 的情况
3.1 我会容斥?
”不经过任意一个坏点的方案数“可以通过计算”至少经过其中
k
k
个已知坏点的方案数“容斥得到。
枚举坏点集合的子集
T
T
,计算从走过
T
T
中每个坏点后走到终止状态的方案数,该方案数对答案的贡献是
(−1)|T|⋅f(T)
(
−
1
)
|
T
|
⋅
f
(
T
)
,这个方案不为
0
0
的情况一定是坏点可以按坐标排序后顺次经过,需要计算从一个点走到
(i′,j′)
(
i
′
,
j
′
)
的方案数,以及从最后一个坏点走到终止状态的方案数,前者依然可以枚举
c
c
来算。
看上去时间复杂度是的,容易构造数据使其超时。
3.2 我会容斥!
对于每个坏点
i
i
,定义表示走到这个点且之前没碰到坏点的方案数,这个就好容斥了,枚举在这个坏点左上角的坏点
j
j
,不合法的方案是乘上从
(aj,bj)
(
a
j
,
b
j
)
到
(ai,bi)
(
a
i
,
b
i
)
的方案数,这些方案里遇到的第一个坏点是
j
j
,所以不重不漏,最后用从到
(ai,bi)
(
a
i
,
b
i
)
的方案数减一下即可。
答案也可以这么统计,算从
(0,0)
(
0
,
0
)
到终止状态的方案,减去所有
f[i]
f
[
i
]
乘上从
(ai,bi)
(
a
i
,
b
i
)
到终止状态的方案即可。
看上去时间复杂度是
O(k2mlogp)
O
(
k
2
m
log
p
)
的,好好写就不会超时了。
注意坏点在终止状态上的情况。写个暴力对拍一下就可以发现问题
代码
#include <cstdio>
#include <algorithm>
typedef long long LL;
const int maxp = 46341, maxc = 11, maxk = 11, maxe = 51;
void exgcd(int a, int b, int &x, int &y)
{
if(!b)
{
x = 1;
y = 0;
return;
}
exgcd(b, a % b, y, x);
y -= a / b * x;
}
int mod_inv(int x, int p)
{
int s, t;
exgcd(x, p, s, t);
return s < 0 ? s + p : s;
}
int tot, prime[maxp], fir[maxp], mod, cnt, fact[maxc], Exp[maxc], Coeff, Lim[maxc], Pw[maxc][maxe];
inline void mod_dec(int &x, int y, int c = 1)
{
while(c--)
if((x -= y) < 0)
x += mod;
}
void init()
{
Coeff = 1;
for(int i = 0; i < cnt; ++i)
{
Exp[i] = Lim[i] = 0;
Pw[i][0] = 1;
}
}
void update(int val, int flag)
{
for(int i = 0; i < cnt; ++i)
for( ; val % fact[i] == 0; Exp[i] += flag, val /= fact[i]);
Coeff = (LL)Coeff * (flag == 1 ? val : mod_inv(val, mod)) % mod;
}
int query()
{
int ret = Coeff;
for(int j = 0; j < cnt && ret; ++j)
{
if(!Exp[j])
continue;
for( ; Lim[j] < Exp[j]; ++Lim[j])
Pw[j][Lim[j] + 1] = (LL)Pw[j][Lim[j]] * fact[j] % mod;
ret = (LL)ret * Pw[j][Exp[j]] % mod;
}
return ret;
}
int calc_1(int n, int m)
{
if(n > m)
std::swap(n, m);
if(n < 0)
return 0;
else if(!n)
return 1;
int ret = 0;
init();
for(int i = 1; i <= n; ++i)
{
update(m + i, 1);
update(i, -1);
}
for(int i = 0; i <= n; ++i)
{
mod_dec(ret, mod - query());
if(i == n)
break;
update(n - i, 1);
update(m - i, 1);
update(n + m - i, -1);
update(i + 1, -1);
}
return ret;
}
int calc_2(int n, int m)
{
if(n > m)
std::swap(n, m);
if(n < 0)
return 0;
else if(!n)
return 1;
int ret = 0;
--n;
--m;
init();
for(int i = 1; i <= n; ++i)
{
update(m + i, 1);
update(i, -1);
}
update(n + m + 1, 1);
for(int i = 0; i <= n; ++i)
{
update(n + 1, -1);
mod_dec(ret, mod - query(), 2);
update(n + 1, 1);
update(m + 1, -1);
mod_dec(ret, mod - query(), 2);
update(m + 1, 1);
update(n + m + 1 - i, -1);
mod_dec(ret, query());
if(i == n)
break;
update(n - i, 1);
update(m - i, 1);
update(i + 1, -1);
}
return ret;
}
int t, n, m, k, f[maxk], ans;
std::pair<int, int> lim[maxk];
int main()
{
for(int i = 2; i < maxp; ++i)
{
if(!fir[i])
prime[tot++] = fir[i] = i;
for(int j = 0, k; (k = i * prime[j]) < maxp; ++j)
{
fir[k] = prime[j];
if(fir[i] == prime[j])
break;
}
}
scanf("%d", &t);
while(t--)
{
scanf("%d%d%d%d", &n, &m, &k, &mod);
int tmp = mod;
cnt = 0;
for(int i = 0; i < tot && prime[i] * prime[i] <= tmp; ++i)
if(tmp % prime[i] == 0)
for(fact[cnt++] = prime[i]; tmp % prime[i] == 0; tmp /= prime[i]);
if(tmp > 1)
fact[cnt++] = tmp;
for(int i = 0; i < k; ++i)
scanf("%d%d", &lim[i].first, &lim[i].second);
std::sort(lim, lim + k);
k = std::unique(lim, lim + k) - lim;
for(int i = 0; i < k; ++i)
{
int &ai = lim[i].first, &bi = lim[i].second;
if(ai < n && bi < m)
{
f[i] = calc_1(ai, bi);
for(int j = 0; j < i; ++j)
{
int &aj = lim[j].first, &bj = lim[j].second;
int coeff = calc_1(ai - aj, bi - bj);
if(coeff)
f[i] = (f[i] - (LL)coeff * f[j]) % mod;
}
}
else
{
f[i] = calc_1(ai - 1, bi - 1);
if(ai < n)
mod_dec(f[i], mod - calc_1(ai, bi - 1));
if(bi < m)
mod_dec(f[i], mod - calc_1(ai - 1, bi));
for(int j = 0; j < i; ++j)
{
int &aj = lim[j].first, &bj = lim[j].second;
int coeff = calc_1(ai - 1 - aj, bi - 1 - bj);
if(ai < n)
mod_dec(coeff, mod - calc_1(ai - aj, bi - 1 - bj));
if(bi < m)
mod_dec(coeff, mod - calc_1(ai - 1 - aj, bi - bj));
if(coeff)
f[i] = (f[i] - (LL)coeff * f[j]) % mod;
}
}
}
ans = calc_2(n, m);
for(int i = 0; i < k; ++i)
{
int coeff = calc_2(n - lim[i].first, m - lim[i].second);
if(coeff)
ans = (ans - (LL)coeff * f[i]) % mod;
}
if(ans < 0)
ans += mod;
printf("%d\n", ans);
}
return 0;
}