题目大意: Ant和Bob要占领一个国家,这个国家有N个相互连接的城市(编号为1-N),任意两个城市间有且只有一条线路,占领城市需要花费一定的时间。占领顺序任意,但是若Ant占领了一个城市,则他占领与这个城市相连的城市时需要花费的时间会减半,Bob也是一样。现在给出Ant和Bob占领每个城市需要的时间以及城市之间的同路,求Ant和Bob占领整个国家需要的时间。
输入:第一行 N代表有几个城市,第二行和第三行每行都有N个数,代表Ant和Bob占领各个城市需要的时间,接下来N行每行都有两个数a,b,代表a和b城市间有一条路。
输出:占领所有城市需要的最少时间。
样例:
输入: 3
1 2 5
3 8 1
1 2
1 3
输出: 3
解法:DP,三维DP DP[s][i][j]
先将地图构造成一个有根树
j=0代表由Ant占领s城市,j=1代表由Bob占领s城市。
i=0代表s城市是在与他相连的城市里第一个被占领的;i=1代表s城市的父亲已经被相同的人占领
i=2代表在占领s之前s的孩子已经有一个被相同的人占领了。
状态转移:若s是叶子节点,则 dp[s][0][0]=Ant[s],dp[s][1][0]=Ant[s]/2,dp[s][2][0]=INF;(Bob也一样)
否则dp[s][0][0]=Sigma(min(dp[son[j]][1][0],min(dp[son[j]][0][1],dp[son[j]][2][1]))+Ant[s]
dp[s][1][0]=dp[s][0][0]-Ant[s]+Ant[s]/2;
dp[s][2][0]选一个孩子作为比s先占领的(最优情况下s的孩子中有且只有一个是最先被Ant占领)
然后枚举其他孩子 dp[son][1][0],因为s已经被占领
代码:
#include <iostream>
#include <vector>
#include <memory.h>
#include <cstdio>
using namespace std;
const int INF = 100000000;
vector<int> edge[103];
int dp[103][3][2],par[103],antVal[103],bobVal[103];
void init()
{
memset(dp,0,sizeof(dp));
memset(par,0,sizeof(par));
for(int i=0;i<102;i++) edge[i].clear();
}
void dfs(int curr)
{
int tmp,t2,minVal,val;
if(edge[curr].size()==1)
{
dp[curr][0][0] = antVal[curr];
dp[curr][0][1] = bobVal[curr];
dp[curr][1][0] = antVal[curr]/2;
dp[curr][1][1] = bobVal[curr]/2;
dp[curr][2][0] = INF;
dp[curr][2][1] = INF;
return;
}
dp[curr][0][0] = antVal[curr];
dp[curr][2][0] = antVal[curr]/2;
if(par[curr]!=0) dp[curr][1][0] = antVal[curr]/2;
else dp[curr][1][0] = INF;
dp[curr][0][1] = bobVal[curr];
if(par[curr]!=0) dp[curr][1][1] = bobVal[curr]/2;
else dp[curr][1][1] = INF;
dp[curr][2][1] = bobVal[curr]/2;
for(int i=0;i<edge[curr].size();i++)
{
if(edge[curr][i]!=par[curr])
{
par[edge[curr][i]] = curr;
dfs(edge[curr][i]);
}
}
for(int i=0;i<edge[curr].size();i++)
{
tmp = edge[curr][i];
if(tmp!=par[curr])
{
dp[curr][0][0] += min(dp[tmp][1][0],min(dp[tmp][0][1],dp[tmp][2][1]));
dp[curr][0][1] += min(dp[tmp][1][1],min(dp[tmp][0][0],dp[tmp][2][0]));
dp[curr][1][0] += min(dp[tmp][1][0],min(dp[tmp][0][1],dp[tmp][2][1]));
dp[curr][1][1] += min(dp[tmp][1][1],min(dp[tmp][0][0],dp[tmp][2][0]));
}
}
minVal = INF;
for(int i=0;i<edge[curr].size();i++)
{
if(edge[curr][i]!=par[curr])
{
t2 = edge[curr][i];
val = min(dp[t2][2][0],dp[t2][0][0]);
for(int j=0;j<edge[curr].size();j++)
{
if(edge[curr][j]!=par[curr]&&j!=i)
{
tmp = edge[curr][j];
val += min(dp[tmp][1][0],min(dp[tmp][0][1],dp[tmp][2][1]));
}
}
if(val<minVal) minVal=val;
}
}
dp[curr][2][0] += minVal;
minVal = INF;
for(int i=0;i<edge[curr].size();i++)
{
if(edge[curr][i]!=par[curr])
{
t2 = edge[curr][i];
val = min(dp[t2][0][1],dp[t2][2][1]);
for(int j=0;j<edge[curr].size();j++)
{
if(edge[curr][j]!=par[curr]&&j!=i)
{
tmp = edge[curr][j];
val += min(dp[tmp][1][1],min(dp[tmp][0][0],dp[tmp][2][0]));
}
}
if(val<minVal) minVal=val;
}
}
dp[curr][2][1] += minVal;
}
int main()
{
int n,a,b,res;
while(scanf("%d",&n)!=EOF)
{
init();
for(int i=1;i<=n;i++) scanf("%d",&antVal[i]);
for(int i=1;i<=n;i++) scanf("%d",&bobVal[i]);
for(int i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
edge[a].push_back(b);
edge[b].push_back(a);
}
edge[1].push_back(0);
dfs(1);
res = dp[1][0][0];
if(dp[1][0][1]<res) res=dp[1][0][1];
if(dp[1][2][1]<res) res=dp[1][2][1];
if(dp[1][2][0]<res) res=dp[1][2][0];
printf("%d\n",res);
}
return 0;
}