[编程题] 最长树链
时间限制:1秒
空间限制:32768K
树链是指树里的一条路径。美团外卖的形象代言人袋鼠先生最近在研究一个特殊的最长树链问题。现在树中的每个点都有一个正整数值,他想在树中找出最长的树链,使得这条树链上所有对应点的值的最大公约数大于1。请求出这条树链的长度。
输入描述:
第1行:整数n(1 ≤ n ≤ 100000),表示点的个数。 第2~n行:每行两个整数x,y表示xy之间有边,数据保证给出的是一棵树。 第n+1行:n个整数,依次表示点1~n对应的权值(1 ≤ 权值 ≤ 1,000,000,000)。
输出描述:
满足最长路径的长度
输入例子:
4 1 2 1 3 2 4 6 4 5 2
输出例子:
3
1具体就是枚举所以质因子,然后dfs的时候只走有当前这个质因子的点。。。这样的话正面看来复杂度貌似很大,最多有10W个质因子,每次dfs复杂度上限也是10W,,貌似会超时? 这个时候要从另外一个角度计算复杂度啦。我们对一个数来考虑,他会被几次dfs遍历到?。很显然就是它的质因子的数量次,,一个数的质因子数比logn还要小很多。所以所有数只会被遍历到最多 nlogn次,复杂度也差不多nlogn级别的,就不会超时了
#include<iostream>
#include<cstdio>
#include<math.h>
#include<algorithm>
#include<map>
#include<set>
#include<bitset>
#include<stack>
#include<queue>
#include<string.h>
#include<cstring>
#include<vector>
#include<time.h>
#include<stdlib.h>
using namespace std;
#define INF 0x3f3f3f3f
#define INFLL 0x3f3f3f3f3f3f3f3f
#define FIN freopen("input.txt","r",stdin)
#define mem(x,y) memset(x,y,sizeof(x))
typedef unsigned long long ULL;
typedef long long LL;
#define fuck(x) cout<<"q"<<endl;
#define MX 111111
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
typedef pair<pair<int,int>,int> PIII;
typedef pair<int,int> PII;
const double eps=1e-8;
int n;
int val[MX];
bool isprime[MX];
int prime[MX],prime_cnt;
void prime_init()
{
prime_cnt=0;
mem(isprime,1);
for(int i=2; i<MX; i++)
{
if(isprime[i])prime[prime_cnt++]=i;
for(int j=0; j<prime_cnt&&prime[j]*i<MX; j++)
{
isprime[prime[j]*i]=0;
if(i%prime[j]==0)break;
}
}
}
int head[MX],edge_cnt;
struct Edge
{
int nxt,to;
} E[MX*2];
void edge_init()
{
mem(head,-1);
edge_cnt=0;
}
void edge_add(int u,int v)
{
E[edge_cnt].nxt=head[u];
E[edge_cnt].to=v;
head[u]=edge_cnt++;
}
int vis[MX];
int num[MX];
map<int,vector<int> >mp;
int w;
int dfs(int u,int fa)
{
num[u]=1;
int MM=0,M=0;
int ans=0;
for(int i=head[u]; ~i; i=E[i].nxt)
{
int v=E[i].to;
if(v==fa||val[v]%w)continue;
ans=max(ans,dfs(v,u));
if(num[v]>MM)M=MM,MM=num[v];
else if(num[v]>M)M=num[v];
}
num[u]=M+MM+1;
return max(ans,num[u]);
}
int solve(int u)
{
w=u;
int ans=0;
for(auto i:mp[u])
{
if(vis[i]==0)ans=max(ans,dfs(i,-1));
}
for(auto i:mp[u])vis[i]=0;
return ans;
}
int main()
{
prime_init();
FIN;
while(cin>>n)
{
edge_init();
mp.clear();
for(int i=1; i<n; i++)
{
int u,v;
scanf("%d%d",&u,&v);
edge_add(u,v);
edge_add(v,u);
}
for(int i=1; i<=n; i++)
{
scanf("%d",&val[i]);
int x=val[i];
for(int j=0; prime[j]<=sqrt(x+0.5); j++)
{
if(x%prime[j]==0)mp[prime[j]].push_back(i);
while(x%prime[j]==0)x/=prime[j];
}
if(x!=1)mp[x].push_back(i);
}
int ans=0;
for(auto i:mp)
{
ans=max(ans,solve(i.first));
}
cout<<ans<<endl;
}
return 0;
}