题解:
考虑枚举哪几个水果是甜的。
发现如果枚举的数量一样,那么对答案的贡献也一样。那么可以根据枚举数量来一起处理。
枚举完数量 k k 之后,甜与半甜的不连边,然后跑矩阵树,得到小等于的甜的水果的生成树数量。 因为之前的已经处理过,直接减掉就好。 注意减去的时候要减去组合数,因为这k个中任意组合都会被统计。
之后问题变为了数量 k k , 值小等于的个数,这个直接折半就行了。
const int mod=1e9+7;
int n,c0,mxv,ans,sweet[50],tr[50],cnt[50],A[50][50],c[50][50];
vector <int> a0;
vector <int> qry[50];
vector <int> chk[50];
inline int power(int a,int b) {
int rs=1;
for(;b;b>>=1,a=(long long)a*a%mod) if(b&1) rs=(long long) rs*a%mod;
return rs;
}
class SweetFruits {
public:
inline void dfs1(int now,int LIM,int cnt,int sum) {
if(sum>mxv) return;
if(now==LIM+1) {qry[cnt].push_back(sum); return;}
dfs1(now+1,LIM,cnt,sum);
dfs1(now+1,LIM,cnt+1,sum+a0[now]);
}
inline void dfs2(int now,int LIM,int cnt,int sum) {
if(sum>mxv) return;
if(now==LIM+1) {chk[cnt].push_back(sum); return;}
dfs2(now+1,LIM,cnt,sum);
dfs2(now+1,LIM,cnt+1,sum+a0[now]);
}
inline int calc2(int k) {
if(!c0) return k?0:1;
if(!k) {
int mid=a0.size()/2;
dfs1(0,mid,0,0);
for(int i=0;i<=n;i++) sort(qry[i].begin(),qry[i].end());
dfs2(mid+1,a0.size()-1,0,0);
for(int i=0;i<=n;i++) sort(chk[i].begin(),chk[i].end());
}
int ans=0;
for(int i=0;i<=k;i++) {
if(!chk[i].size()) continue;
int p2=qry[k-i].size()-1;
for(int p1=0;p1<chk[i].size();++p1) {
while(~p2 && chk[i][p1]+qry[k-i][p2]>mxv) --p2;
if(p2==-1) break;
ans=(ans+p2+1)%mod;
}
}
return ans;
}
inline int det() {
int sgn=1;
for(int i=1;i<n;i++)
for(int j=1;j<n;j++)
if(A[i][j]<0) A[i][j]=mod+A[i][j];
for(int i=1;i<n;i++) {
int l=i;
for(int j=i+1;j<n;j++)
if(A[j][i]>A[i][i]) l=j;
if(l!=i) {
for(int j=i;j<n;j++) swap(A[l][j],A[i][j]);
sgn=-sgn;
}
if(A[i][i]==0) return 0;
for(int j=i+1;j<n;j++) {
int t=(long long)A[j][i]*power(A[i][i],mod-2)%mod;
for(int k=i;k<n;k++)
A[j][k]=(A[j][k]-(long long)A[i][k]*t%mod+mod)%mod;
}
}
int rs=1;
for(int i=1;i<n;i++)
rs=(long long)rs*A[i][i]%mod;
return sgn>0?rs:mod-rs;
}
inline int calc1(int nn) {
memset(A,0,sizeof(A));
for(int i=1;i<=nn;i++) {
for(int j=i+1;j<=nn;j++)
--A[i][j],--A[j][i],++A[i][i],++A[j][j];
for(int j=c0+1;j<=n;j++)
--A[i][j],--A[j][i],++A[i][i],++A[j][j];
}
for(int i=nn+1;i<=c0;i++)
for(int j=c0+1;j<=n;j++)
--A[i][j],--A[j][i],++A[i][i],++A[j][j];
for(int i=c0+1;i<=n;i++)
for(int j=i+1;j<=n;j++)
--A[i][j],--A[j][i],++A[i][i],++A[j][j];
int k=det();
for(int i=0;i<nn;i++) k=(k-(long long)c[nn][i]*cnt[i]%mod+mod)%mod;
return cnt[nn]=k;
}
int countTrees(vector<int> st, int mx) {
n=st.size(); mxv=mx;
for(int i=0;i<=n;i++) c[i][0]=1;
for(int i=1;i<=n;i++)
for(int j=1;j<=i;j++)
c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
for(int i=0;i<n;i++)
if(st[i]!=-1) ++c0,a0.push_back(st[i]);
for(int i=0;i<=c0;i++)
ans=(ans+(long long)calc1(i)*calc2(i)%mod)%mod;
return ans;
}
};