分析
设第一个数为 x x x,则第二个数为 x + d 1 x+d_1 x+d1,第三个数为 x + d 1 + d 2 x+d_1+d_2 x+d1+d2 …。这里的 d 1 d_1 d1, d 2 d_2 d2表示 a a a或者 − b -b −b,所以这个数列为:
x x x, x + d 1 x + d_1 x+d1, x + d 1 + d 2 x + d_1 + d_2 x+d1+d2, x + d 1 + d 2 + d 3 x + d_1 + d_2 + d_3 x+d1+d2+d3, …, x + d 1 + d 2 + . . . + d n − 1 x + d_1 + d_2 + ... + d_{n-1} x+d1+d2+...+dn−1,又因为数列之和为s,所以转化成:
n ∗ x + ( n − 1 ) ∗ d 1 + ( n − 2 ) ∗ d 2 + ( n − 3 ) ∗ d 3 + . . . + d n − 1 = s n * x + (n-1) * d_1 + (n-2) * d_2 + (n-3) * d_3 + ... + d_{n-1} = s n∗x+(n−1)∗d1+(n−2)∗d2+(n−3)∗d3+...+dn−1=s,再在一步转化:
s − [ ( n − 1 ) ∗ d 1 + ( n − 2 ) ∗ d 2 + ( n − 3 ) ∗ d 3 + . . . + d n − 1 ] n = x \frac{s - [(n-1) * d_1 + (n-2) * d_2 + (n-3) * d_3 + ...+ d_{n-1}]}{n} = x ns−[(n−1)∗d1+(n−2)∗d2+(n−3)∗d3+...+dn−1]=x
因为x是任意整数,所以又转化成:
s s s与 ( n − 1 ) ∗ d 1 + ( n − 2 ) ∗ d 2 + ( n − 3 ) ∗ d 3 + . . . + d n − 1 (n-1) * d_1 + (n-2) * d_2 + (n-3) * d_3 + ...+ d_{n-1} (n−1)∗d1+(n−2)∗d2+(n−3)∗d3+...+dn−1 模 x x x的余数相同。
到这里就转化成了组合问题。
下面就可以用闫氏dp分析法了。
1.状态表示:
f[i][j]
表示要选i
个a
或者-b
且余数为j
的所有集合的数量。
2.状态计算:第i
个可以选a
或者-b
。第
i
个选a
: ( n − 1 ) ∗ d 1 + ( n − 2 ) ∗ d 2 + ( n − 3 ) ∗ d 3 + . . . + 2 ∗ d n − 2 + a (n-1) * d_1 + (n-2) * d_2 + (n-3) * d_3 +...+ 2 * d_{n-2}+ a (n−1)∗d1+(n−2)∗d2+(n−3)∗d3+...+2∗dn−2+a 模 x = j x = j x=j。则: ( n − 1 ) ∗ d 1 + ( n − 2 ) ∗ d 2 + ( n − 3 ) ∗ d 3 + . . . + 2 ∗ d n − 2 (n-1) * d_1 + (n-2) * d_2 + (n-3) * d_3 +...+ 2 * d_{n-2} (n−1)∗d1+(n−2)∗d2+(n−3)∗d3+...+2∗dn−2 模 x = j − a x = j - a x=j−a。
系数和下标之和为
n
,所以第i
项的的系数为n-i
。所以:
f[i][j] = f[i - 1][j - (n - i) * a]
第i个选b:同理:
f[i][j] = f[i - 1][j + (n - i) * b]
时间复杂度 O ( n 2 ) O(n^2) O(n2)
状态数量 * 状态转移时操作数 = (n - 1) * (n - 1) * 2
C++代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1010, MOD = 100000007;
int n, s, a, b;
int f[N][N];
int get_mod(int a, int b)
{
return (a % b + b) % b;
}
int main()
{
scanf("%d%d%d%d", &n, &s, &a, &b);
f[0][0] = 1;
for (int i = 1; i < n; i++)
for (int j = 0; j < n; j++)
f[i][j] = (f[i - 1][get_mod(j - (n - i) * a, n)] + f[i - 1][get_mod(j + (n - i) * b, n)]) % MOD;
printf("%d", f[n - 1][get_mod(s, n)]);
return 0;
}
Java代码
import java.util.Scanner;
public class Main {
static int N = 1010;
static int MOD = 100000007;
static int[][] f = new int[N][N];
public static void main(String[] args) {
Scanner scan = new Scanner(System.in);
int n, s, a, b;
n = scan.nextInt();
s = scan.nextInt();
a = scan.nextInt();
b = scan.nextInt();
f[0][0] = 1;
for (int i = 1; i < n; i++)
for (int j = 0; j < n; j++)
f[i][j] = (f[i - 1][getMod(j - (n - i) * a, n)] + f[i - 1][getMod(j + (n - i) * b, n)]) % MOD;
System.out.println(f[n - 1][getMod(s, n)]);
}
public static int getMod(int a, int b) {
return (a % b + b) % b;
}
}