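题意(简述):给定一棵 $n$ 个点的树、一个长度为 $m$ 的点序列 $a_1,\dots,a_m$ 和整数 $k$。把每条边染成红色或蓝色,然后依次沿树上路径走 $a_1\to a_2\to\dots\to a_m$,要求经过红边的总次数减去经过蓝边的总次数恰好为 $k$,求染色方案数(模 $998244353$)。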
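思路:对每一对相邻的 $a_i, a_{i+1}$,以 $a_i$ 为根做一次 DFS,记录每个点的入边 $pre$;再从 $a_{i+1}$ 沿 $pre$ 回溯到 $a_i$,把路径上每条边的经过次数 $val$(代码中为数组 `v`)加一。之后对所有边做 0/1 背包 dp。由于所有边的经过次数之和 $sum$ 固定,红边的总次数确定后蓝边的总次数也随之确定,因此可以去掉一维;最后再用滚动数组优化空间。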
原转移方程($f[i][j][k]$ 表示考虑前 $i$ 条边时,红边总经过次数为 $j$、蓝边总经过次数为 $k$ 的方案数;这里的 $k$ 只是下标,与输入的 $k$ 无关):
$$\text{若 } f[i][j][k] \neq 0:$$
$$f[i+1][j+val_i][k] \mathrel{+}= f[i][j][k]$$
$$f[i+1][j][k+val_i] \mathrel{+}= f[i][j][k]$$
去掉蓝边这一维后的结果如下($f[i][j]$ 表示考虑前 $i$ 条边时,红边总经过次数为 $j$ 的方案数):
$$\text{若 } f[i][j] \neq 0:$$
$$f[i+1][j+val_i] \mathrel{+}= f[i][j]$$
$$f[i+1][j] \mathrel{+}= f[i][j]$$
由于使用滚动数组,处理完第 $i$ 层后要把这一层清空,以便下一次转移时作为目标层复用(注意下标 $0$ 也要清空):
$$f[i \,\&\, 1][0..sum] = 0$$
初始化:$f[0][0] = 1$。
目标:
$$ans = f[n-1][x],\quad 2x = sum + k$$
(滚动数组中即 $f[(n-1) \,\&\, 1][x]$。)若不存在满足 $0 \le x \le sum$ 的整数 $x$(例如 $sum + k$ 为奇数或为负),答案为 $0$。
#pragma GCC optimize(2)
#pragma GCC optimize(3,"Ofast","inline")
#include <bits/stdc++.h>
#define inf 0x7fffffff
//#define ll long long
#define int long long
//#define double long double
#define re register int
#define void inline void
#define eps 1e-8
//#define mod 1e9+7
#define ls(p) p<<1
#define rs(p) p<<1|1
#define pi acos(-1.0)
#define pb push_back
#define P pair < int , int >
#define mk make_pair
using namespace std;
const int mod=998244353;     // modulus of the answer
const int M=5e6;             // not used in the visible code
const int N=3e6+5;           // capacity of the vertex / edge arrays below

// Edge storage: e[id] is one directed edge. head[x] is the first edge leaving x,
// and e[id].next is the next edge leaving the same vertex (0 terminates the list).
struct node
{
    int ver;    // vertex the edge points to
    int next;   // id of the next edge from the same source vertex
};
node e[N];

int tot=1;      // id of the last edge added; starting at 1 makes the first pair (2,3), so id^1 is the reverse edge
int head[N];    // head[x]: first edge id leaving x (0 if none)
int pre[N];     // pre[y]: id of the edge used to enter y during dfs1
int d[N];       // not used in the visible code
int v[N];       // v[id]: how many times undirected edge id (= directed id / 2) is traversed
int n,m,k;      // vertex count, length of the sequence a, target value (answer uses 2x = sum + k)
int sum;        // sum of v[1..n-1]
int a[N];       // the vertex sequence a[1..m]
int ans;        // printed answer
int f[3][2000005];  // DP rows; only f[0] and f[1] are used (indexed by i&1)
void add(int x,int y)
{
e[++tot].ver=y;
e[tot].next=head[x];
head[x]=tot;
}
// Stores the undirected edge {x, y} as two consecutive directed edges.
// Because tot starts at 1, the pair gets ids (2t, 2t+1), so edge id and id^1
// are reverses of each other and both map to undirected index id/2.
void addedge(int x,int y)
{
    add(x,y);   // forward direction, even id
    add(y,x);   // reverse direction, odd id
}
void dfs1(int x,int fa)
{
for(re i=head[x];i;i=e[i].next)
{
int y=e[i].ver;
if(y==fa) continue;
pre[y]=i;
dfs1(y,x);
}
}
// Walks the tree path upward, starting from directed edge i and ending at
// vertex `end`. It repeatedly moves to the edge's source vertex and follows
// pre[], adding 1 to v[] for every undirected edge passed. Expects pre[] to
// have been filled by dfs1 rooted at `end`.
void print(int i,int end)
{
    for(;;)
    {
        v[i/2]++;                   // directed ids i and i^1 share undirected index i/2
        int tail=e[i^1].ver;        // source vertex of directed edge i
        if(tail==end) break;        // reached the root of this walk
        i=pre[tail];                // edge that entered `tail`
    }
}
void solve()
{
cin>>n>>m>>k;
for(re i=1;i<=m;i++) scanf("%lld",&a[i]);
for(re i=1;i<n;i++)
{
int x,y;
scanf("%lld%lld",&x,&y);
addedge(x,y);
}
for(re i=1;i<m;i++)
{
for(re i=1;i<=n;i++) pre[i]=0;
if(a[i]==a[i+1]) continue;
dfs1(a[i],a[i]);
print(pre[a[i+1]],a[i]);
}
for(re i=1;i<n;i++) sum+=v[i];
f[0][0]=1;
for(re i=0;i<n-1;i++)
{
for(re j=0;j<=sum;j++) if(f[i&1][j])
{
f[(i+1)&1][j]=(f[(i+1)&1][j]+f[i&1][j])%mod;
f[(i+1)&1][j+v[i+1]]=(f[(i+1)&1][j+v[i+1]]+f[i&1][j])%mod;
}
for(re j=0;j<=sum;j++) f[i&1][j]=0;
// memset(f[i&1],0,sizeof(f[i&1]));
}
for(re i=0;i<=sum;i++) if(2*i==sum+k) ans=f[(n-1)&1][i]%mod;
cout<<ans<<endl;
}
// Program entry point. It is declared `signed` because `int` is macro-redefined
// as `long long` earlier in the file.
signed main()
{
    int T=1;                // number of test cases; read it from input here for multi-test judges
    while(T--) solve();
    return 0;
}
/*
#((((#
*/