题目描述:
n个人,k个相同的人不能连续坐一起,n个人只能是A或者B.旋转不算数.求方案数.
题解:
首先旋转我们可以用波利亚计数.之后变成了:安排x个人,x人内部不超过k,而且x的头和尾如果一样,之和也不能超过k.对于这种,其实是n^2的空间的dp.强制A结尾,并且记录A结尾的A的个数.dp_a[i][j],A开头,i为长度,结尾j个A.dp_b[i][j],B开头,i为长度,结尾j个B.先dp出来这个,然后再导出来dp.注意很多细节,比如dp_a中的AAAA的情况.
重点:
1.波利亚计数
2.头尾也需要考虑的方法.
代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <ctype.h>
#include <limits.h>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <queue>
#include <map>
#include <stack>
#include <set>
#include <bitset>
#define CLR(a) memset(a, 0, sizeof(a))
#define REP(i, a, b) for(ll i = a;i < b;i++)
#define REP_D(i, a, b) for(ll i = a;i <= b;i++)
typedef long long ll;
using namespace std;
const ll maxn = 1000 + 10;
const ll MOD = 1000003;
ll dp[maxn], dp_a[maxn][maxn], dp_b[maxn][maxn];
ll sum_a[maxn][maxn], sum_b[maxn][maxn];
ll n, k;
ll pow_mod(ll x, ll n)
{
if(n==0)
{
ll t = 1;
return t;
}
x %= MOD;
ll xx = (x*x)%MOD;
ll nn = n/2;
ll res = pow_mod(xx, nn);
if(n%2==1)
{
res = (res*x)%MOD;
}
return res;
}
ll gcd(ll x, ll y)
{
if(y==0)
{
return x;
}
return gcd(y, x%y);
}
void getDp()
{
CLR(dp_a);
CLR(dp_b);
CLR(sum_a);
CLR(sum_b);
dp_a[1][1] = 1;
sum_a[1][1] = 1;
for(ll i = 2; i <= n; i++)
{
if(i<=k)
{
dp_a[i][i] = 1;
}
ll key = min(i-2, k);
for(ll j = 1; j<=key; j++)
{
ll limit = i - k - 1;
limit = max(0LL, limit);
dp_a[i][j] = ((sum_b[i-1][j] - sum_b[limit][j])%MOD + MOD)%MOD;
//dp_b[i][j] = ((sum_a[i-1][j] - sum_a[limit][j])%MOD + MOD)%MOD;
//printf("i is %lld j is %lld dp_b is %lld\n", i, j, dp_b[i][j]);
}
key = min(i-1, k);
for(ll j = 1; j<=key; j++)
{
ll limit = i - k - 1;
limit = max(0LL, limit);
dp_b[i][j] = ((sum_a[i-1][j] - sum_a[limit][j])%MOD + MOD)%MOD;
//printf("i is %lld j is %lld dp_b is %lld\n", i, j, dp_b[i][j]);
}
for(ll j = 1; j<=i; j++)
{
sum_a[i][j] = (sum_a[i-1][j] + dp_a[i][j])%MOD;
sum_b[i][j] = (sum_b[i-1][j] + dp_b[i][j])%MOD;
}
}
CLR(sum_b);
for(ll i = 2;i <= n;i++)
{
//sum_b[i][0] = 0;
for(ll j = 1;j <= n;j++)
{
sum_b[i][j] = (sum_b[i][j-1]+dp_b[i][j])%MOD;
}
}
CLR(dp);
for(ll i = 1; i <= n; i++)
{
for(ll j= 1;j<=min(k, i-1);j++)
dp[i] = (dp[i]+dp_b[i][j])%MOD;
// if(i<=k)
// {
// dp[i] = (dp[i]+1)%MOD;
// }
for(ll j = 1; j <= min(k,i - 2); j++)
{
ll lft = k - j;
lft = min(lft, n);
dp[i] = (dp[i]+sum_b[i-j][lft])%MOD;
}
// if(i <= k)
// {
// dp[i] = (dp[i]+1)%MOD;
// }
//printf(" i is %lld %lld\n", i, dp[i]);
dp[i] = (2*dp[i])%MOD;
}
}
void solve()
{
int all = 0;
if(k >= n)
{
all = 2;
//k = n - 1;
}
getDp();
ll ans = dp[n];
for(ll i = 1;i<n;i++)
{
ll t = gcd(i, n);
ans = (ans + dp[t])%MOD;
}
ans = (ans*pow_mod(n, MOD-2))%MOD;
ans = (ans + all)%MOD;
printf("%lld\n", ans);
}
int main()
{
//freopen("8Hin.txt", "r", stdin);
//freopen("8Hout.txt", "w", stdout);
while(scanf("%lld%lld", &n, &k) != EOF)
{
if(!n && !k)
break;
solve();
}
return 0;
}