POJ1741
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
点分治裸题。
点分治的主要思想就是:对这个题来讲的话
我们先随便找一个点 1.既然是分治。可以分为两种情况。
1.点对在1的不同子树内。
2.点对在1的相同子树x内。
同时情况2又是 像这样可以分为两种情况。
我们递归下去每次只计算情况1就好了。
如何计算:可以将以1为根的所有节点的 dep[ ] dfs处理出来。dep[ ]是路径而非深度。
cla(x,now)函数:将这些dep[ ]排序之后 用尺取法两个指针l,r分别从前后扫描。条件是dep[l]+dep[r]。。。我们这样会多计算一些东西。同一颗子树里面dep[u]+dep[v]<=K的点对。但我们这个函数本身计算的就是dep[x]+dep[y]<=K的点对。所以只需要再减去每个子树不合法的情况即可。也就是 cal(edge[i].to,egde[i].len)。看代码就好理解了。
证明可知,随便找着一个点记为root,当root为当前树的中心。时间复杂度较低。getroot(x,fa)函数用来找一个树的重心。
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<iostream>
#define foru(i,a,b) for(int i=a;i<=b;++i)
#define m(a,b) memset(a,b,sizeof a)
#define en '\n'
using namespace std;
typedef long long ll;
const int N=1e4+5,M=1e5+5,INF=0x3f3f3f3f;
template<class T>void rd(T &x)
{
x=0;int f=0;char ch=getchar();
while(ch<'0'||ch>'9') {f|=(ch=='-');ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
x=f?-x:x;
return;
}
struct Edge{int to,len,nex;}edge[N<<1];
int head[N],tot;
void add(int from,int to,int len)
{
edge[++tot]=(Edge){to,len,head[from]};head[from]=tot;
edge[++tot]=(Edge){from,len,head[to]};head[to]=tot;
}
ll ans=0;
int root,sum;
int sz[N],mx[N],vis[N];//sz[i]以i为根的子树节点总个数 mx[i]是以i为根的几叉树最大的儿子节点数目 vis[]标记这个节点是否当过root
void getroot(int x,int fa)//查找树的重心
{
sz[x]=1,mx[x]=0;
for(int i=head[x];i;i=edge[i].nex)
{
int y=edge[i].to;
if(vis[y]||y==fa) continue;
getroot(y,x);
sz[x]+=sz[y];
mx[x]=max(mx[x],sz[y]);
}
mx[x]=max(mx[x],sum-sz[x]);//我觉着这个应该是x通往fa的那一条子树的大小.
if(mx[x]<mx[root]) root=x;//更新root是谁
}
int d[N],dep[N];
void getd(int x,int fa)
{
d[++d[0]]=dep[x];
for(int i=head[x];i;i=edge[i].nex)
{
int y=edge[i].to;
if(vis[y]||y==fa) continue;
dep[y]=dep[x]+edge[i].len;
getd(y,x);
}
}
int K;
int cal(int x,int now)
{
dep[x]=now,d[0]=0;//d[0]是用来记数的,看这个子树内共有多少节点.
getd(x,-1);
sort(d+1,d+d[0]+1);
int res=0,l=1,r=d[0];
while(l<r)//尺取法
{
if(d[l]+d[r]>K) r--;
else res+=r-l,l++;
}
return res;
}
void work(int x)
{
ans+=cal(x,0);//cal(x,now)会处理以x为根的树 横跨根的点对.(但是我们会多算上不合法的点对)
//比如同在x的一颗子树内,但是计算的路径并不是按LCA来的 而是为dep[u]+dep[v].就需要把这种情况舍去.
vis[x]=1;//标记这个节点已经用过.
for(int i=head[x];i;i=edge[i].nex)
{
int y=edge[i].to;
if(vis[y]) continue;
ans-=cal(y,edge[i].len);//就是这样舍去.
sum=sz[y],root=0;//sum=sz[y]就是为了查找y无法通往他的fa那个子树
getroot(y,-1);
work(root);//对每个子树进行分治
}
}
int main()
{
int n;
while(rd(n),rd(K),n)
{
m(head,0),tot=0;
foru(i,1,n-1)
{
int u,v,w;rd(u),rd(v),rd(w);
add(u,v,w);
}
m(vis,0),root=0,ans=0,sum=n,mx[0]=INF;
getroot(1,-1);
work(root);
printf("%lld\n",ans);
}
}
P3806 【模板】点分治1
输入样例#1:
2 1
1 2 2
2
输出样例#1:
AYE
对于100%的数据n<=10000,m<=100,c<=10000,K<=10000000
**注意:**这里用尺取法就需要注意了。。。。。我是傻逼。。。
直接 ans[d[i]+d[j]]+=val.这样也行。每次i,j循环是从[1,d[0]]。每次d[0]需要重新算 不能按sz[x]。因为work函数里面第二个cal时 会被别的getrroot改变。
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<iostream>
#define foru(i,a,b) for(int i=a;i<=b;++i)
#define m(a,b) memset(a,b,sizeof a)
#define en '\n'
using namespace std;
typedef long long ll;
const int N=1e4+5,M=1e7+5,INF=0x3f3f3f3f;
template<class T>void rd(T &x)
{
x=0;int f=0;char ch=getchar();
while(ch<'0'||ch>'9') {f|=(ch=='-');ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
x=f?-x:x;
return;
}
struct Edge{int to,len,nex;}edge[N<<1];
int head[N],tot;
void add(int from,int to,int len)
{
edge[++tot]=(Edge){to,len,head[from]};head[from]=tot;
edge[++tot]=(Edge){from,len,head[to]};head[to]=tot;
}
int root,sum;
int sz[N],mx[N],vis[N];
void getroot(int x,int fa)
{
sz[x]=1,mx[x]=0;
for(int i=head[x];i;i=edge[i].nex)
{
int y=edge[i].to;
if(vis[y]||y==fa) continue;
getroot(y,x);
sz[x]+=sz[y];
mx[x]=max(mx[x],sz[y]);
}
mx[x]=max(mx[x],sum-sz[x]);
if(mx[x]<mx[root]) root=x;
}
int d[N],dep[N];
void getd(int x,int fa)
{
d[++d[0]]=dep[x];
for(int i=head[x];i;i=edge[i].nex)
{
int y=edge[i].to;
if(vis[y]||y==fa) continue;
dep[y]=dep[x]+edge[i].len;
getd(y,x);
}
}
int ans[M];
int cal(int x,int now,int val)
{
dep[x]=now,d[0]=0;
getd(x,-1);
foru(i,1,d[0])
foru(j,1,d[0])
if(i!=j) ans[d[i]+d[j]]+=val;
}
void work(int x)
{
cal(x,0,1);
vis[x]=1;
for(int i=head[x];i;i=edge[i].nex)
{
int y=edge[i].to;
if(vis[y]) continue;
cal(y,edge[i].len,-1);
sum=sz[y],root=0;
getroot(y,-1);
work(root);
}
}
int main()
{
int n,m;rd(n),rd(m);
foru(i,1,n-1)
{
int u,v,w;rd(u),rd(v),rd(w);
add(u,v,w);
}
root=0,mx[0]=INF,sum=n;
getroot(1,-1);
work(root);
foru(i,1,m)
{
int x;rd(x);
puts(ans[x]?"AYE":"NAY");
}
}
BZOJ 2152
Sample Input
5
1 2 1
1 3 2
1 4 1
2 5 3
Sample Output
13/25
【样例说明】13组点对分别是(1,1) (2,2) (2,3) (2,5) (3,2) (3,3) (3,4) (3,5) (4,3) (4,4) (5,2) (5,3) (5,5)。
【数据规模】对于100%的数据,n<=20000。
思路:统计是否为3的倍数时。我竟然要两次for循环。。我是吃屎了吗
直接统计t[0],t[1],t[2]。代表所求子树中dep[x]%3的节点数。
getd(x,fa)函数dfs时:t[dep[x]]++即可。
然后每次cal的答案就是 t[0]* t[0]+2* t[1]* [2];
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#define m(a,b) memset(a,b,sizeof a)
using namespace std;
typedef long long ll;
const int N=2e4+5,INF=0x3f3f3f3f;
struct Edge{int to,len,nex;}edge[N<<1];
int head[N],tot;
inline void add(int from,int to,int len){
edge[++tot]=(Edge){to,len,head[from]};head[from]=tot;
edge[++tot]=(Edge){from,len,head[to]};head[to]=tot;
}
int root,sum;
int sz[N],mx[N],vis[N];
inline void getroot(int x,int fa){
sz[x]=1,mx[x]=0;
for(int i=head[x];i;i=edge[i].nex){
int y=edge[i].to;
if(vis[y]||y==fa) continue;
getroot(y,x),sz[x]+=sz[y];
mx[x]=max(mx[x],sz[y]);
}
mx[x]=max(mx[x],sum-sz[x]);
if(mx[x]<mx[root]) root=x;
}
int dep[N],t[3];
inline void getd(int x,int fa){
t[dep[x]]++;
for(int i=head[x];i;i=edge[i].nex){
int y=edge[i].to;
if(vis[y]||y==fa) continue;
dep[y]=(dep[x]+edge[i].len)%3,getd(y,x);
}
}
inline int cal(int x,int now){
dep[x]=now,t[0]=t[1]=t[2]=0;
getd(x,-1);
return t[0]*t[0]+t[1]*t[2]*2;
}
ll ans;
inline void work(int x){
ans+=cal(x,0),vis[x]=1;
for(int i=head[x];i;i=edge[i].nex)
{
int y=edge[i].to;
if(vis[y]) continue;
ans-=cal(y,edge[i].len);
sum=sz[y],root=0;
getroot(y,-1),work(root);
}
}
ll gcd(ll a,ll b){
return (!b)?a:gcd(b,a%b);
}
int main()
{
int n;scanf("%d",&n);
for(int i=1,u,v,w;i<=n-1;++i)
scanf("%d%d%d",&u,&v,&w),w%=3,add(u,v,w);
root=0,mx[0]=INF,sum=n;
getroot(1,-1),work(root);
ll c=gcd(ans,(ll)n*n);
printf("%lld/%lld\n",ans/c,(ll)n*n/c);
}