传送门
社论(题解):
首先长剖重剖都考虑过了,并没有办法支持快速合并,边分更不用说了,权值在边上怎么边分怎么蛋疼。
考虑点分,我们知道如果一个回文串过了重心,他要么就是重心延伸出去的回文前缀,要么它被重心分成两段,短的一定是长的后缀。
很显然我们考虑用AC自动机求出这种后缀关系。
那么现在问题变成了,当前分治重心到这个点的串(设长度为 l l l)如果存在长度为 d d d的回文前缀,我们需要求它在fail树上长度为 l − d l-d l−d的祖先出现了多少次。
窒息的是,回文前缀可能有一大堆。
但是根据broder理论我们知道,这些回文前缀可以被表示为不超过 log \log log个等差数列。
假设我们有某一个公差为 d d d,首项为 a 0 a_0 a0,末项为 a t a_t at的等差数列,那么相当于我们需要求fail树上,长度 % d \%d %d与 a 0 a_0 a0相同,且长度在 l − a 0 l-a_0 l−a0和 l − a t l-a_t l−at之间的祖先个数。
对于 d ≤ n d\leq \sqrt n d≤n的我们可以考虑用一个桶来记录,也就是记录 % i = j \%i=j %i=j的长度在祖先上出现了多少次,对于公差大的,直接暴力询问就行了。
但是,实际上由于在计算不同重心的时候本身回文串的切割会非常鬼畜,所以实际上我们只对公差小于等于 4 4 4的记录就能够达到最优效率了。
回文前缀直接用哈希求就行了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define re register
#define gc get_char
#define cs const
namespace IO{
inline char get_char(){
static cs int Rlen=1<<22|1;
static char buf[Rlen],*p1,*p2;
return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
}
template<typename T>
inline T get(){
char c;
while(!isdigit(c=gc()));T num=c^48;
while(isdigit(c=gc()))num=(num+(num<<2)<<1)+(c^48);
return num;
}
inline int getint(){return get<int>();}
}
using namespace IO;
using std::cerr;
using std::cout;
cs int N=50004;
cs int SqrtN=4;
int n;
cs ll mod=1e9+7;
cs ll B=47;
ll bas[N];
inline void init_bas(){
bas[0]=1;
for(int re i=1;i<N;++i)bas[i]=bas[i-1]*B%mod;
}
namespace AC{
int son[N][2],fail[N],now;
int len[N],cnt[N];
inline int newnode(int x){
++now;
len[now]=x;
return now;
}
std::vector<int> G[N];
inline void build_fail(){
std::queue<int> q;
for(int re i=0;i<2;++i)if(son[1][i])q.push(son[1][i]),fail[son[1][i]]=1;
else son[1][i]=1;
while(!q.empty()){
int u=q.front();q.pop();
for(int re i=0;i<2;++i)
if(son[u][i])fail[son[u][i]]=son[fail[u]][i],q.push(son[u][i]);
else son[u][i]=son[fail[u]][i];
}
for(int re i=2;i<=now;++i)G[fail[i]].push_back(i);
}
struct prefix{
int l,r,d;
prefix(){}
prefix(int _l,int _r,int _d):l(_l),r(_r),d(_d){}
};
std::vector<prefix> vec[N];
void dfs_pal(int u,ll h1,ll h2,cs std::vector<prefix> &cur){
vec[u]=cur;
if(len[u]>0&&h1==h2){
if(cur.empty())vec[u].push_back(prefix(len[u],len[u],len[u]));
else {
auto &p=vec[u].back();
if(p.d==len[u]-p.r)p.r=len[u];
else vec[u].push_back(prefix(len[u],len[u],len[u]-p.r));
}
}
for(int re i=0;i<2;++i)if(son[u][i])
dfs_pal(son[u][i],(h1*B+i)%mod,(h2+bas[len[u]]*i)%mod,vec[u]);
}
int ans[N];
int st[N],top;
inline int find(int x){
int l=1,r=top,mid;
while(l<=r)len[st[mid=l+r>>1]]<=x?l=mid+1:r=mid-1;
return st[r];
}
struct Query{
int a,b,id,tag;
Query(int _a,int _b,int _id,int _tag):a(_a),b(_b),id(_id),tag(_tag){}
};
std::vector<Query> q[N];
int c[SqrtN+3][SqrtN+3];
int d[N];
void dfs(int u){
for(int re i=1;i<=SqrtN;++i)c[i][len[u]%i]+=cnt[u];
d[len[u]]+=cnt[u];
st[++top]=u;
for(cs auto &p:vec[u]){
if(p.d<=SqrtN){
int x=find(len[u]-p.r-1);
int y=find(len[u]-p.l);
int k=(len[u]-p.l)%p.d;
if(x!=y){
if(x>0)q[x].push_back(Query(p.d,k,u,-1));
if(y>0)q[y].push_back(Query(p.d,k,u,1));
}
}
else {
for(int a=p.l;a<=p.r;a+=p.d)
q[u].push_back(Query(0,len[u]-a,u,+1));
}
}
for(int re v:G[u])dfs(v);
for(auto &q:(AC::q[u])){
if(q.a)ans[q.id]+=q.tag*c[q.a][q.b];
else ans[q.id]+=q.tag*d[q.b];
}
for(int re i=1;i<=SqrtN;++i)c[i][len[u]%i]-=cnt[u];
d[len[u]]-=cnt[u];
st[top--]=0;
}
ll solve(){
dfs(1);ll res=0;
for(int re i=2;i<=now;++i)res+=(ll)cnt[i]*ans[i]+(ll)cnt[i]*(cnt[i]-1)/2;
return res;
}
ll calc(){
dfs_pal(1,0,0,std::vector<prefix>(0));
build_fail();
return solve();
}
inline void clear(){
while(now){
son[now][0]=son[now][1]=fail[now]=len[now]=cnt[now]=ans[now]=0;
vec[now].clear();
G[now].clear();
q[now].clear();
--now;
}
newnode(0);
}
}
namespace TDC{
struct edge{
int to,w;
};
std::vector<edge> G[N];
inline void addedge(int u,int v,int w){
G[u].push_back((edge){v,w});
G[v].push_back((edge){u,w});
}
bool ban[N];
int siz[N];
int tot,mx,g;
ll ans;
void get_siz(int u,int fa){
siz[u]=1;
for(edge &e:G[u])if(!ban[e.to]&&e.to!=fa)get_siz(e.to,u),siz[u]+=siz[e.to];
}
void find_G(int u,int fa){
int mx_u=tot-siz[u];
for(edge &e:G[u])if(!ban[e.to]&&e.to!=fa)find_G(e.to,u),mx_u=std::max(mx_u,siz[e.to]);
if(mx_u<mx)mx=mx_u,g=u;
}
int get_G(int u){
get_siz(u,0);
tot=siz[u];
g=-1,mx=0x3f3f3f3f;
find_G(u,0);
assert(~g);
return g;
}
void dfs(int u,int p,int nd){
AC::cnt[nd]++;
for(edge &e:G[u])if(e.to!=p&&!ban[e.to]){
if(!AC::son[nd][e.w])AC::son[nd][e.w]=AC::newnode(AC::len[nd]+1);
dfs(e.to,u,AC::son[nd][e.w]);
}
}
inline ll calc(int u,int w=-1){
AC::clear();
if(~w){
AC::son[1][w]=AC::newnode(1);
dfs(u,0,AC::son[1][w]);
}
else dfs(u,0,1);
return AC::calc();
}
inline void solve_G(int u){
ban[u]=true;
ans+=calc(u);
for(edge &e:G[u])if(!ban[e.to]){
ans-=calc(e.to,e.w);
int t=get_G(e.to);
solve_G(t);
}
}
inline void solve(){
int u=get_G(1);
solve_G(u);
cout<<ans<<"\n";
}
}
signed main(){
init_bas();
// freopen("pal.in","r",stdin);freopen("pal.out","w",stdout);
n=getint();
for(int re i=1;i<n;++i){
int u=getint(),v=getint(),w=getint();
TDC::addedge(u,v,w);
}
TDC::solve();
return 0;
}