真的是诸事不顺,先被k题卡了题意。真水题被卡题意可还行。。。
然后,,就没有然后了。
上手这个题的时候还有半个小时,看起来挺水的一个树形dp,两遍dfs搞定。
可能是因为当时感觉时间不够了,有点紧张,在第一遍dfs的时候忘了加子节点的距离和。
加上就过了。
题目链接:点击这里
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<set>
#include<deque>
#include<map>
#include<vector>
#include<cmath>
#define ll long long
#define llu unsigned ll
#define pr pair<int,int>
using namespace std;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const int inf = 0x3f3f3f3f;
const int maxn=20100;
const int mod=1e9+7;
int head[maxn],ver[maxn],nt[maxn];
ll edge[maxn];
ll dp1[maxn][3],dp2[maxn][3];
ll cnt1[maxn][3],cnt2[maxn][3];
int tot=1;
void add(int x,int y,ll z)
{
ver[++tot]=y,edge[tot]=z;
nt[tot]=head[x],head[x]=tot;
}
void dfs1(int x,int fa)
{
cnt1[x][0]=1;
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
ll z=edge[i];
if(y==fa) continue;
dfs1(y,x);
for(int j=0;j<3;j++)
{
cnt1[x][(j+z)%3]=(cnt1[x][(j+z)%3]+cnt1[y][j])%mod;
dp1[x][(j+z)%3]=(dp1[x][(j+z)%3]+dp1[y][j]+z*(cnt1[y][j]))%mod;
}
}
}
void dfs2(int x,int fa)
{
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
ll z=edge[i];
if(y==fa) continue;
for(int j=0;j<3;j++)
{
cnt2[y][(j+z)%3]=((cnt1[x][j]-cnt1[y][((j-z)%3+3)%3]+cnt2[x][j])%mod+mod)%mod;
dp2[y][(j+z)%3]=(((dp1[x][j]-dp1[y][((j-z)%3+3)%3]-z*cnt1[y][((j-z)%3+3)%3])%mod+mod)%mod+dp2[x][j]+z*cnt2[y][(z+j)%3])%mod;
}
dfs2(y,x);
}
}
int main(void)
{
int n;
while(scanf("%d",&n)!=EOF)
{
memset(head,0,sizeof(head));
memset(dp1,0,sizeof(dp1));
memset(dp2,0,sizeof(dp2));
memset(cnt1,0,sizeof(cnt1));
memset(cnt2,0,sizeof(cnt2));
tot=1;
int x,y;
ll z;
for(int i=1;i<n;i++)
{
scanf("%d%d%lld",&x,&y,&z);
x++,y++;
add(x,y,z);
add(y,x,z);
}
dfs1(1,0);
dfs2(1,0);
ll ans[3];
ans[0]=ans[1]=ans[2]=0;
for(int i=1;i<=n;i++)
{
for(int j=0;j<3;j++)
ans[j]=(ans[j]+dp1[i][j]+dp2[i][j])%mod;
}
printf("%lld %lld %lld\n",ans[0],ans[1],ans[2]);
}
return 0;
}
现在想一下,其实点分治也可以写,当时就觉得这个题的数据量挺像点分治的。
不过后来被否了。于是写了dp却没有过。话说一上来就点分的话说不定就过了。
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#include<queue>
#include<map>
#include<vector>
#define ll long long
#define llu unsigned ll
using namespace std;
const int maxn=21000;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
int head[maxn],ver[maxn],nt[maxn];
ll edge[maxn];
int max_part[maxn],_size[maxn];
bool ha[maxn];
int tot=1,root,max_size,n;
ll ans0,ans1,ans2;
ll cnt[3],dis[3],ans[3];
void init(void)
{
memset(head,0,sizeof(head));
memset(ha,0,sizeof(ha));
tot=1,ans0=ans1=ans2=0;
}
void add(int x,int y,ll z)
{
ver[++tot]=y,edge[tot]=z;
nt[tot]=head[x],head[x]=tot;
}
void dfs_size(int x,int fa)
{
_size[x]=1,max_part[x]=0;
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(ha[y]||y==fa) continue;
dfs_size(y,x);
_size[x]+=_size[y];
max_part[x]=max(max_part[x],_size[y]);
}
}
void dfs_root(int now_root,int x,int fa)
{
max_part[x]=max(max_part[x],_size[now_root]-_size[x]);
if(max_size>max_part[x])
{
max_size=max_part[x];
root=x;
}
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(ha[y]||y==fa) continue;
dfs_root(now_root,y,x);
}
}
void dfs_dis(int x,int fa,ll diss)
{
cnt[diss%3]++,cnt[diss%3]%=mod;
dis[diss%3]+=diss,dis[diss%3]%=mod;
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
ll z=edge[i];
if(ha[y]||y==fa) continue;
dfs_dis(y,x,diss+z);
}
}
void get_num(int x,ll d)
{
dfs_dis(x,0,d);
ans[0]=ans[1]=ans[2]=0;
cnt[0]=cnt[1]=cnt[2]=0;
dis[0]=dis[1]=dis[2]=0;
dfs_dis(x,0,d);
for(int i=0;i<3;i++)
{
for(int j=0;j<3;j++)
ans[(i+j)%3]=(ans[(i+j)%3]+cnt[i]*dis[j]+cnt[j]*dis[i])%mod;
}
}
void dfs(int x)
{
max_size=inf;
dfs_size(x,-1);
dfs_root(x,x,-1);
get_num(root,0);
ans0=(ans0+ans[0])%mod;
ans1=(ans1+ans[1])%mod;
ans2=(ans2+ans[2])%mod;
ha[root]=1;
for(int i=head[root];i;i=nt[i])
{
int y=ver[i];
ll z=edge[i];
if(ha[y]) continue;
get_num(y,z);
ans0=((ans0-ans[0])%mod+mod)%mod;
ans1=((ans1-ans[1])%mod+mod)%mod;
ans2=((ans2-ans[2])%mod+mod)%mod;
dfs(y);
}
}
int main(void)
{
while(scanf("%d",&n)!=EOF)
{
init();
int x,y;
ll z;
for(int i=1;i<n;i++)
{
scanf("%d%d%lld",&x,&y,&z);
x++,y++;
add(x,y,z);add(y,x,z);
}
dfs(1);
printf("%lld %lld %lld\n",ans0,ans1,ans2);
}
return 0;
}
不过网上大多数的点分治代码,求重心的时候就用了一遍dfs。实际上效率差不多。
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#include<queue>
#include<map>
#include<vector>
#define ll long long
#define llu unsigned ll
using namespace std;
const int maxn=21000;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
int head[maxn],ver[maxn],nt[maxn],_size[maxn];;
ll edge[maxn];
bool ha[maxn];
int tot=1,root,max_size,n,si;
ll ans0,ans1,ans2;
ll cnt[3],dis[3],ans[3];
void init(void)
{
memset(head,0,sizeof(head));
memset(ha,0,sizeof(ha));
tot=1,ans0=ans1=ans2=0;
}
void add(int x,int y,ll z)
{
ver[++tot]=y,edge[tot]=z;
nt[tot]=head[x],head[x]=tot;
}
void dfs_root(int x,int fa)
{
_size[x]=1;
int max_part=0;
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(ha[y]||y==fa) continue;
dfs_root(y,x);
_size[x]+=_size[y];
max_part=max(max_part,_size[y]);
}
max_part=max(max_part,si-_size[x]);
if(max_part<max_size) max_size=max_part,root=x;
}
void dfs_dis(int x,int fa,ll diss)
{
cnt[diss%3]++,cnt[diss%3]%=mod;
dis[diss%3]+=diss,dis[diss%3]%=mod;
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
ll z=edge[i];
if(ha[y]||y==fa) continue;
dfs_dis(y,x,diss+z);
}
}
void get_num(int x,ll d)
{
dfs_dis(x,0,d);
ans[0]=ans[1]=ans[2]=0;
cnt[0]=cnt[1]=cnt[2]=0;
dis[0]=dis[1]=dis[2]=0;
dfs_dis(x,0,d);
for(int i=0;i<3;i++)
{
for(int j=0;j<3;j++)
ans[(i+j)%3]=(ans[(i+j)%3]+cnt[i]*dis[j]+cnt[j]*dis[i])%mod;
}
}
void dfs(int x)
{
get_num(x,0);
ans0=(ans0+ans[0])%mod;
ans1=(ans1+ans[1])%mod;
ans2=(ans2+ans[2])%mod;
ha[x]=1;
int totsi=si;
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
ll z=edge[i];
if(ha[y]) continue;
get_num(y,z);
ans0=((ans0-ans[0])%mod+mod)%mod;
ans1=((ans1-ans[1])%mod+mod)%mod;
ans2=((ans2-ans[2])%mod+mod)%mod;
max_size=inf;
si=_size[y]>_size[x]?totsi-_size[x]:_size[y];
dfs_root(y,0);
dfs(root);
}
}
int main(void)
{
while(scanf("%d",&n)!=EOF)
{
init();
int x,y;
ll z;
for(int i=1;i<n;i++)
{
scanf("%d%d%lld",&x,&y,&z);
x++,y++;
add(x,y,z);add(y,x,z);
}
si=n;
max_size=inf;
dfs_root(1,0);
dfs(root);
printf("%lld %lld %lld\n",ans0,ans1,ans2);
}
return 0;
}