spoj 1825
题目:http://www.spoj.com/problems/FTOUR2/
题目大意:给你一棵节点数为 N 的树,每条树枝有权值,点有黑白两色,问你找一条路径使其进过的黑色的节点数不超过 K 且权值和最大,然后输出这个权值。
思路:继上题的 Tree 之后,漆子超论文的下一道题目,表示看论文、题解和别人代码做了很久。。 = =
用G[ i ][ j ] 表示根节点 的第 i 个儿子经过的黑点数为 j 的最优值,但是 i、j 范围太大了,空间开不下。但是我们不需要保存所有的儿子对应的所有 j ,我们只关心已经算过的节点中每个 j 对应的最大值,所以这里需要优化一下:用f [ i ] 表示已处理的节点中黑点数不超过 i 的最优值,显然,f 具有单调递增性。然后对于当前要处理的节点,算出 g[ i ] ,,然后用g[ i ] 去和 f[ i ]组合更新 ans ,再更新 f[ i ] 就行了,注意:f[ i ] 更新好以后,也要用符合要求的 i 更新 ans ,因为 g[ i ] 有可能不和 f[ i ] 结合,即,根节点为路径的起点或终点。对于每一个根节点,需要对每个儿子按照 dep[ i ] 先进行排序,然后每次只要更新到dep[ i ] 就行了,时间复杂度为排序的复杂度O(NlogN),如果不排序,则最坏情况下,时间复杂度会达到 O(N^2 )。
基本上是照着别人代码来的,先开始一直TLE,然后WA,然后RE,最后才 AC 的。。 = = WA的原因在于,我把 solve()里的 getted[ root ] =0, 写成 getted[ x ] =0 了。。 T T,而TLE是在于找 root 时,我是按照上一题的方法来的,代码是这样的:
int num[MAXN],maxv[MAXN];
vector <int> node;
void dfs(int u,int fa)
{
node.push_back(u);
num[u]=1;
for(int e = head[u];e!=-1;e= edge[e].next)
{
int v = edge[e].t;
if(getted[v]||v==fa) continue;
dfs(v,u);
num[u]+=num[v];
maxv[u] = max(maxv[u],num[v]);
}
}
int get_root(int x)
{
node.clear();
dfs(x,0);
int minn = INF;
int sum_node = num[x];
int root;
for(int i=0;i<node.size();i++)
{
int cur = node[i];
maxv[cur] = max(maxv[cur],sum_node-num[cur]);
if(maxv[cur]<minn)
{
minn = maxv[cur];
root = cur;
}
}
return root;
}
然后看了别人的,改成下面这个就过了:
void dfs1(int u,int fa)
{
father[u]=fa;
num[u]=1;
maxv[u]=0;
for(int e = head[u];e!=-1;e= edge[e].next)
{
int v = edge[e].t;
if(getted[v]||v==fa) continue;
dfs1(v,u);
num[u]+=num[v];
maxv[u] = max(maxv[u],num[v]);
}
}
int minn;
void dfs2(int u,int sum,int& root)
{
for(int e = head[u] ; e!=-1 ; e =edge[e].next)
{
int v = edge[e].t;
if(getted[v]||v==father[u]) continue;
dfs2(v,sum,root);
}
int tmp = max(sum-num[u],maxv[u]);
if(tmp<minn)
{
minn = tmp;
root = u;
}
}
int get_root(int x)
{
dfs1(x,0);
minn = INF;
int sum_node = num[x];
int root;
dfs2(x,sum_node,root);
return root;
}
这复杂度不是一样的嘛,想不清楚。。= =
还有,如果把上面那段代码里 dfs2()里和 dfs1()一样加个 fa ,把 dfs1()里的 father 数组去掉,交上去,竟然是RE。。 想不明白啊,想不明白。。
好吧,改来改去,终于AC了,代码如下:
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
const int INF = 0x0fffffff ;
const int MAXN = 400022 ;
int n,m,k;
struct Edge
{
int t,next,len;
} edge[MAXN<<1];
int head[MAXN],tot;
void add_edge(int s,int t,int len)
{
edge[tot].len=len;
edge[tot].t=t;
edge[tot].next = head[s];
head[s] = tot++;
}
bool is_black[MAXN],getted[MAXN];
int num[MAXN],maxv[MAXN];
int father[MAXN];
void dfs1(int u,int fa)
{
father[u]=fa;
num[u]=1;
maxv[u]=0;
for(int e = head[u];e!=-1;e= edge[e].next)
{
int v = edge[e].t;
if(getted[v]||v==fa) continue;
dfs1(v,u);
num[u]+=num[v];
maxv[u] = max(maxv[u],num[v]);
}
}
int minn;
void dfs2(int u,int sum,int& root)
{
for(int e = head[u] ; e!=-1 ; e =edge[e].next)
{
int v = edge[e].t;
if(getted[v]||v==father[u]) continue;
dfs2(v,sum,root);
}
int tmp = max(sum-num[u],maxv[u]);
if(tmp<minn)
{
minn = tmp;
root = u;
}
}
int get_root(int x)
{
dfs1(x,0);
minn = INF;
int sum_node = num[x];
int root;
dfs2(x,sum_node,root);
return root;
}
int dep[MAXN];
void get_dep(int u,int fa)
{
dep[u] = is_black[u];
for(int e = head[u];e!=-1;e = edge[e].next)
{
int v = edge[e].t;
if(getted[v]||v==fa) continue;
get_dep(v,u);
dep[u] = max(dep[u],is_black[u]+dep[v]);
}
}
int g[MAXN];
void get_g(int u,int fa,int s,int c)
{
g[c] = max(g[c],s);
for(int e = head[u] ; e!=-1 ;e=edge[e].next)
{
int v = edge[e].t;
int len = edge[e].len;
if(getted[v]||v==fa) continue;
get_g(v,u,s+len,c+is_black[v]);
}
}
int id[MAXN];
bool cmp(int a,int b)
{
return dep[edge[a].t]<dep[edge[b].t];
}
int ans;
int f[MAXN];
void solve(int x)
{
int root = get_root(x);
//printf("root = %d\n",root);
getted[root]=1;
for(int e = head[root]; e!=-1 ;e=edge[e].next)
{
int v = edge[e].t;
if(getted[v]) continue;
solve(v);
}
int cc=0;
for(int e = head[root];e!=-1; e =edge[e].next)
{
int v = edge[e].t;
if(getted[v]) continue;
get_dep(v,root);
id[cc++]=e;
}
sort(id,id+cc,cmp);
for(int i=0;i<=dep[edge[id[cc-1]].t];i++)
f[i]=-INF;
//printf("root = %d\n",root);
for(int i=0;i<cc;i++)
{
int cur = edge[id[i]].t;
int len = edge[id[i]].len;
for(int j=0;j<=dep[cur];j++)
g[j]=-INF;
get_g(cur,root,len,is_black[cur]);
//printf("cur = %d\n",cur);
if(i>0)
{
int end = min(k - is_black[root],dep[cur]);
for(int j=0;j<=end;j++)
{
int p = min(dep[edge[id[i-1]].t],k-j-is_black[root]);
//printf("g[%d] = %d,f[%d] = %d\n",j,g[j],p,f[p]);
if(f[p]==-INF) break;
if(g[j]!=-INF)
ans=max(ans,g[j]+f[p]);
}
}
for(int j=0;j<=dep[cur];j++)
{
f[j]=max(f[j],g[j]);
if(j>0) f[j]=max(f[j],f[j-1]);
if(i==0&&j+is_black[root]<=k) ans = max(ans,f[j]);
}
}
//printf("ans = %d\n",ans);
getted[root]=0;
}
int main()
{
while(~scanf("%d%d%d",&n,&k,&m))
{
memset(is_black,0,sizeof(is_black));
int pos;
for(int i=0;i<m;i++)
{
scanf("%d",&pos);
is_black[pos]=1;
}
memset(head,-1,sizeof(head));
tot=0;
int a,b,c;
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&a,&b,&c);
add_edge(a,b,c);
add_edge(b,a,c);
}
ans=0;
memset(getted,0,sizeof(getted));
solve(1);
printf("%d\n",ans);
}
return 0;
}
/*
7 5 6
2
3
4
5
6
7
1 7 100
1 5 100
5 6 100
1 2 1
2 3 1
3 4 1
*/