题意:有 n n n 个学校隶属于 c c c 个城市,每个学校有 s i s_i si 个人。把它们放入一个 2 × 2 2\times 2 2×2 的格子中,要求同一学校的必须放在同一个格子,同一城市的必须放在同一行,并给出两行两列分别最多能放的人数 C 0 , C 1 , D 0 , D 1 C_0,C_1,D_0,D_1 C0,C1,D0,D1。此外还有 k k k 条限制,形如某个学校的人不能放入某个特定的格子中,每个学校最多只有一种限制。求方案数 模 998244353 998244353 998244353。
n , c ≤ 1 0 3 , C 0 , C 1 , D 0 , D 1 ≤ 2500 , k ≤ 30 , s i ≤ 10 n,c\leq 10^3,C_0,C_1,D_0,D_1\leq 2500,k\leq 30,s_i\leq 10 n,c≤103,C0,C1,D0,D1≤2500,k≤30,si≤10,数据组数 T = 5 T=5 T=5
首先考虑暴力 dp,发现确定第一行、第一列后完整方案就确定了。所以可以设 f ( x , y ) f(x,y) f(x,y) 表示第一行放了 x x x 个数,第一列放了 y y y 个数,直接转移即可。计算答案时 x x x 和 y y y 的上下界都可以确定出来,这样就有 50pts 的好成绩。
然后发现对于没有限制的情况,行和列是分别独立的,所以分别设 A ( x ) , B ( x ) A(x),B(x) A(x),B(x) 为第一行、第一列选了 x x x 个的方案数,前者用城市总人数转移,后者用学校人数转移,然后乘起来就可以了。单独写就有 50pts,结合前面可以拿到 70pts。
考虑这个方法不好扩展,但发现有限制的城市不超过 k k k 个,所以可以从“降低了数据范围的方向”来考虑。
开始想的是 dp 的时候记录两列受限制的人数,然而状态数爆炸。
发现城市数很少并没什么用,因为一个城市里面即使只有一个受限制的学校,你整个城市都只能大暴力,
智商分割线
我们称一个城市受限当且仅当至少一个该城市的学校受限。考虑继续挖掘受限的城市中不受限的学校的性质。
冷静分析可以发现,对于一个受限的城市,确定了受限的学校选择了哪一行后,不受限的学校选哪一行就确定了,只需要决定选哪一列然后根据最终方案随机应变就可以了。也就是说这部分拿去更新 B ( x ) B(x) B(x) 就可以不管了。
对于最终受限的学校只有 30 30 30 个,直接用第一个方法大暴力就可以了。要注意的是需要先决定每一个城市选了哪一行,可以对每个城市把状态分为两部分,分别为选择第一行和第二行的 dp 值。然后把两部分直接加起来,在这里把第一行该城市的人数减掉。
注意 dp 顺序,必须依次枚举城市更新所有信息,否则会算重。
然后把这个暴力 dp 求二维前缀和,枚举前面的 A ( x ) , B ( x ) A(x),B(x) A(x),B(x) 算贡献即可。
显然复杂度瓶颈在第三步,复杂度为 O ( k M 2 ) O(kM^2) O(kM2),非常卡。
注意到这个 dp 是从头开始算的,第二维不会超过 k s i ks_i ksi,就可以优化到 O ( k 2 s i M ) O(k^2s_iM) O(k2siM)
总复杂度 O ( T M ( N + k 2 s i ) ) O(TM(N+k^2s_i)) O(TM(N+k2si))
有点卡常,不过不用 memset 应该问题不大
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#define re register
using namespace std;
typedef long long ll;
const int MOD=998244353;
inline int mod(int x){return x+=(x>>31)&MOD;}
inline int add(const int& x,const int& y){return mod(x+y-MOD);}
inline int dec(const int& x,const int& y){return mod(x-y);}
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
vector<int> lis[1005],fr,hfr;
int n,c,k,C[2],D[2],b[1005],s[1005],p[1005],cnt[1005],M;
int A[2505],B[2505],g[2505],f[2505][2505],t0[2505][2505],t1[2505][2505],tmp0[2505][2505],tmp1[2505][2505],siz[1005];
inline void clear()
{
for (int i=1;i<=c;i++) lis[i].clear();
memset(p,-1,sizeof(p));
memset(cnt,0,sizeof(cnt));
fr.clear(),hfr.clear();
memset(A,0,sizeof(A));
memset(B,0,sizeof(B));
memset(g,0,sizeof(g));
memset(f,0,sizeof(f));
memset(siz,0,sizeof(siz));
A[0]=B[0]=g[0]=f[0][0]=1;
}
int main()
{
freopen("test.in","r",stdin);
for (int T=read();T;T--)
{
n=read(),c=read(),C[0]=read(),C[1]=read(),D[0]=read(),D[1]=read();
M=max(max(C[0],C[1]),max(D[0],D[1]));
clear();
int tot=0;
for (int i=1;i<=n;i++) b[i]=read(),s[i]=read(),siz[b[i]]+=s[i],tot+=s[i];
k=read();
int MAX=min(M,k*10);
while (k--)
{
int i=read();
p[i]=read();
++cnt[b[i]];
}
if (C[0]+C[1]<tot||D[0]+D[1]<tot)
{
puts("0");
continue;
}
for (int i=1;i<=n;i++)
if (!cnt[b[i]]) fr.push_back(i);
else if (p[i]==-1) hfr.push_back(i);
else lis[b[i]].push_back(i);
for (int i=1;i<=c;i++)
if (!cnt[i]&&siz[i])
for (int j=M;j>=siz[i];j--)
A[j]=add(A[j],A[j-siz[i]]);
for (int i=0;i<(int)fr.size();i++)
{
int v=s[fr[i]];
for (int j=M;j>=v;j--) B[j]=add(B[j],B[j-v]);
}
for (int i=0;i<(int)hfr.size();i++)
{
int v=s[hfr[i]];
for (int j=M;j>=v;j--) B[j]=add(B[j],B[j-v]);
}
for (re int i=1;i<=c;i++)
{
if (!cnt[i]) continue;
for (re int j=0;j<=M;j++)
for (re int k=0;k<=MAX;k++)
t0[j][k]=t1[j][k]=f[j][k];
for (re int l=0;l<cnt[i];l++)
{
re int v=s[lis[i][l]],lim=p[lis[i][l]];
for (re int j=0;j<=M;j++)
for (re int k=0;k<=MAX;k++)
tmp0[j][k]=tmp1[j][k]=0;
for (re int j=0;j<=M;j++)
for (re int k=0;k<=MAX;k++)
{
if (lim!=0&&k>=v) tmp0[j][k]=add(tmp0[j][k],t0[j][k-v]);
if (lim!=1) tmp0[j][k]=add(tmp0[j][k],t0[j][k]);
if (lim!=2&&k>=v) tmp1[j][k]=add(tmp1[j][k],t1[j][k-v]);
if (lim!=3) tmp1[j][k]=add(tmp1[j][k],t1[j][k]);
}
for (re int j=0;j<=M;j++)
for (re int k=0;k<=MAX;k++)
t0[j][k]=tmp0[j][k],t1[j][k]=tmp1[j][k];
}
for (int j=0;j<=M;j++)
for (int k=0;k<=MAX;k++)
f[j][k]=add(j>=siz[i]? t0[j-siz[i]][k]:0,t1[j][k]);
}
for (int i=0;i<=M;i++)
for (int j=0;j<=M;j++)
{
if (i) f[i][j]=add(f[i][j],f[i-1][j]);
if (j) f[i][j]=add(f[i][j],f[i][j-1]);
if (i&&j) f[i][j]=dec(f[i][j],f[i-1][j-1]);
}
int res=0;
for (int i=0;i<=C[0];i++)
for (int j=0;j<=D[0];j++)
{
if (!A[i]||!B[j]) continue;
int xl=max(0,tot-C[1]-i),xr=C[0]-i;
int yl=max(0,tot-D[1]-j),yr=D[0]-j;
int ans=f[xr][yr];
if (xl) ans=dec(ans,f[xl-1][yr]);
if (yl) ans=dec(ans,f[xr][yl-1]);
if (xl&&yl) ans=add(ans,f[xl-1][yl-1]);
res=(res+(ll)ans*A[i]%MOD*B[j])%MOD;
}
printf("%d\n",res);
}
return 0;
}