链接
感想
写了一下午终于通过了,用一句有点中二的话来表达我此刻的心情:
これが、我々の勝利!
好了下面正经点
题解
第一步找直径,我们要找的是那种支链最多的直径。我问了计蒜客管理员,他说这题数据不存在支链数相同的情况。
出题人的做法是找直径的时候把支链数作为第二关键字,这个做法固然很神,但是不好推广。
我的做法比较麻烦,但是容易推广,下面讲讲我的做法:
首先要明白一点,当直径上的点有奇数个时,所有直径都经过一个公共点,且这个公共点是所有直径的中点。当直径上的点有偶数个,那么所有直径都经过一条公共边,且这个公共边就是所有直径最中间那条边(此时直径的边数是奇数所以存在最中间的边)。
当直径上的点数是奇数时:
我找到中心点 t t t,然后以 t t t为根节点建树,此时我只需要在每棵子树选出深度最大的点,这些点当中再取到根节点支链数最多的那个。每棵子树这样的点我拿出来,然后以支链数为关键字排序,取支链数最大的两个,这就是满足条件的直径了。
当直径上的点数是偶数时:
这个时候中心边是 ( a , b ) (a,b) (a,b),这条边把树分成两个部分,一个部分以 a a a为根节点,另一个部分以 b b b为根节点
分别建立有根树,分别在每棵树深度最大的点中取出支链数最多的那个。两棵子树得到两个这样的点,这两个点所对应的路径就是答案。
其余部分想必很容易处理,这里就不多做解释了
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 1000010
#define maxe 1000010
#define cl(x) memset(x,0,sizeof(x))
#define rep(i,a,b) for(i=a;i<=b;i++)
#define drep(i,a,b) for(i=a;i>=b;i--)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
ll n, deg[maxn], cnt[maxn], belong[maxn];
vector<ll> tong[maxn];
struct Graph
{
int etot, head[maxn], to[maxe], next[maxe], w[maxe];
void clear(int N)
{
for(int i=1;i<=N;i++)head[i]=0;
etot=0;
}
void adde(int a, int b, int c=0){to[++etot]=b;w[etot]=c;next[etot]=head[a];head[a]=etot;}
#define forp(_,__) for(int p=__.head[_];p;p=__.next[p])
}G;
struct Easy_Tree
{
int depth[maxn], dist[maxn], tid[maxn], rtid[maxn], tim, size[maxn], rev[maxn], fa[maxn];
void dfs(int pos, int pre, Graph& G)
{
fa[pos]=pre;
tid[pos]=++tim;
rev[tid[pos]]=pos;
size[pos]=1;
forp(pos,G)if(G.to[p]!=pre)
{
depth[G.to[p]]=depth[pos]+1;
dist[G.to[p]]=dist[pos]+G.w[p];
dfs(G.to[p],pos,G);
size[pos]+=size[G.to[p]];
}
rtid[pos]=tim;
}
void run(Graph& G, int root)
{
tim=0;
depth[root]=1;
dfs(root,0,G);
}
}et;
void dfs(ll u, ll fa, ll now)
{
cnt[u] = cnt[fa] + max(0ll,deg[u]-2);
belong[u] = now;
forp(u,G)
{
ll v(G.to[p]); if(v==fa)continue;
if(now==0)dfs(v,u,v);
else dfs(v,u,now);
}
}
#define mod 998244353ll
vector<pll> lis;
bool cmp(ll a, ll b){return cnt[a]>cnt[b];}
int main()
{
freopen("chemistry.in","r",stdin);
freopen("chemistry.out","w",stdout);
ll mxi, i;
n = read();
rep(i,1,n-1)
{
ll u=read(), v=read();
G.adde(u,v), G.adde(v,u);
deg[u]++, deg[v]++;
}
et.run(G,1);
mxi=1;
rep(i,2,n)if(et.depth[i]>et.depth[mxi])mxi=i;
de(mxi);
et.run(G,mxi);
ll mx=-1, ed;
rep(i,1,n)mx=max(mx,(ll)et.depth[i]);
rep(i,1,n)if(et.depth[i]==mx)ed=i; //随便找一条直径
ll A, B;
if(mx%2==0) //所有直径有一条公共边
{
ll t=ed, res=mx/2-1;
while(res--)t=et.fa[t]; //(t,fa[t])就是这条公共边
ll a=t, b=et.fa[t];
et.depth[a] = et.depth[b] = 1;
et.dfs(a,b,G), et.dfs(b,a,G);
dfs(a,b,a), dfs(b,a,b);
vector<ll> v1, v2;
rep(i,1,n)if(et.depth[i]==mx/2)
{
if(belong[i]==a)v1.push_back(i);
else v2.push_back(i);
}
sort(v1.begin(),v1.end(),cmp);
sort(v2.begin(),v2.end(),cmp);
A = v1[0], B = v2[0];
}
else //所有直径有一个公共点
{
ll t=ed, res=mx/2;
vector<ll> v0;
while(res--)t=et.fa[t]; //t就是这个公共点
de(t);
dfs(t,0,0);
et.run(G,t);
rep(i,1,n)if(et.depth[i]==mx/2+1)tong[belong[i]].push_back(i);
rep(i,1,n)
{
sort(tong[i].begin(),tong[i].end(),cmp);
if(!tong[i].empty())v0.push_back(tong[i].at(0));
}
sort(v0.begin(),v0.end(),cmp);
A = v0[0], B = v0[1];
}
//经过以上过程,主链AB已求出
et.run(G,A);
ll t = B, L = et.depth[B]; //L是主链长度
ll ans=0, pre=0;
ll tot=0;
t = B;
while(t!=A)
{
forp(t,G)
{
ll v(G.to[p]);
if(v==et.fa[t] or v==pre)continue;
lis.push_back( pll(et.size[v],et.depth[t]) );
}
pre = t;
t = et.fa[t];
}
sort(lis.begin(),lis.end());
pll p1, p2;
p1=pll(0,linf), p2=pll(0,-linf);
rep(i,0,lis.size()-1ll)
{
if(lis[i].se<p1.se or lis[i].se==p1.se and lis[i].fi<p1.fi)p1=lis[i];
if(lis[i].se>p2.se or lis[i].se==p2.se and lis[i].fi<p2.fi)p2=lis[i];
}
if(p1.se<L-p2.se+1 or p1.se==L-p2.se+1 and p1.fi<p2.fi)
{
rep(i,0,lis.size()-1ll)
(ans+=(i+1)*lis[i].fi*lis[i].se)%=mod;
}
else
{
rep(i,0,lis.size()-1ll)lis[i].se = L-lis[i].se+1;
sort(lis.begin(),lis.end());
rep(i,0,lis.size()-1ll)(ans+=(i+1)*lis[i].fi*lis[i].se)%=mod;
}
printf("%lld\n%lld",ans,L);
return 0;
}