链接
http://www.lydsy.com/JudgeOnline/problem.php?id=2152
题解
我作大死把取模全都换成减法,结果
WA
了一发。
可以直接点分治,一棵子树一棵子树合并,然后统计下答案就行了。
用点分治做就太水了,其实还可以用树形
dp
。
设
f[i][j]
表示一段在
i
另一端在
转移比较容易,直接算完自己然后去转移父亲就行了。
问题是答案怎么统计?
可以容斥,对于每个点,先把
f[i][0]2+2×f[i][1]f[i][2]
计入答案,然后把出现在一棵子树里的减去。
代码
//点分治
#include <cstdio>
#include <algorithm>
#define maxn 20010
#define forp for(int p=head[pos];p;p=nex[p])if(to[p]^pre and !grey[to[p]])
using namespace std;
int N, head[maxn], to[maxn<<1], nex[maxn<<1], w[maxn<<1], cnt[5], tot, dist[maxn],
list[maxn], ans, deep[maxn], size[maxn], G, sumG, grey[maxn];
inline void adde(int a, int b, int v)
{to[++tot]=b;w[tot]=v;nex[tot]=head[a];head[a]=tot;}
inline void ins(int x){cnt[dist[x]]++;}
inline void erase(int x){cnt[dist[x]]--;}
inline void calc()
{for(int i=1;i<=*list;i++)
if(dist[list[i]]!=0)ans+=cnt[3-dist[list[i]]];
else ans+=cnt[0];}
int dfs(int pos, int pre)
{
size[pos]=1;
if(dist[pos]>=3)dist[pos]-=3;
list[++*list]=pos;
forp deep[to[p]]=deep[pos]+1, dist[to[p]]=dist[pos]+w[p],
size[pos]+=dfs(to[p],pos);
return size[pos];
}
void findG(int pos, int pre, int sum)
{
if(sum<sumG)G=pos, sumG=sum;
forp findG(to[p],pos,sum+*size-(size[to[p]]<<1));
}
inline void solve(int pos)
{
int i;
*list=0, deep[pos]=dist[pos]=0, dfs(pos,-1);
for(i=1,sumG=0;i<=*list;i++)sumG+=deep[list[i]];
*size=size[pos], findG(G=pos,-1,sumG);
dist[G]=0, ins(G);
grey[G]=1;
for(int p=head[G];p;p=nex[p])
if(!grey[to[p]])
{
*list=0, dist[to[p]]=w[p], dfs(to[p],G);
calc();
for(i=1;i<=*list;i++)ins(list[i]);
}
*list=0, dist[G]=0, dfs(G,-1);
for(i=1;i<=*list;i++)erase(list[i]);
for(int p=head[G];p;p=nex[p])if(!grey[to[p]])solve(to[p]);
}
int gcd(int a, int b){return !b?a:gcd(b,a%b);}
void input()
{
int a, b, v, i;
scanf("%d",&N);
for(i=1;i<N;i++)scanf("%d%d%d",&a,&b,&v),adde(a,b,v%=3),adde(b,a,v);
}
int main()
{
input();
solve(1);
int x=ans*2+N, y=N*N;
printf("%d/%d",x/gcd(x,y),y/gcd(x,y));
return 0;
}
//树形dp
#include <cstdio>
#include <algorithm>
#define maxn 200010
using namespace std;
int N, head[maxn], to[maxn<<1], w[maxn<<1], nex[maxn<<1], f[maxn][3], tot, ans;
inline void adde(int a, int b, int v)
{to[++tot]=b;w[tot]=v;nex[tot]=head[a];head[a]=tot;}
inline int mod(int x)
{
if(x<0)return x+3;
if(x>2)return x-3;
return x;
}
inline void dp(int pos, int pre)
{
int p, i;
f[pos][0]=1;
for(p=head[pos];p;p=nex[p])if(to[p]^pre)dp(to[p],pos);
for(p=head[pos];p;p=nex[p])if(to[p]==pre)break;
for(i=0;i<3;i++)f[to[p]][mod(i+w[p])]+=f[pos][i];
ans+=f[pos][0]*f[pos][0]+f[pos][1]*f[pos][2]+f[pos][2]*f[pos][1];
for(p=head[pos];p;p=nex[p])
if(to[p]^pre)
ans-=f[to[p]][mod(-w[p])]*f[to[p]][mod(-w[p])]
+f[to[p]][mod(1-w[p])]*f[to[p]][mod(2-w[p])]
+f[to[p]][mod(2-w[p])]*f[to[p]][mod(1-w[p])];
}
void input()
{
int i, a, b, v;
scanf("%d",&N);
for(i=1;i<N;i++)scanf("%d%d%d",&a,&b,&v),adde(a,b,v%3),adde(b,a,v%3);
}
int gcd(int a, int b){return !b?a:gcd(b,a%b);}
int main()
{
input();
dp(1,-1);
printf("%d/%d",ans/gcd(ans,N*N),N*N/gcd(ans,N*N));
return 0;
}