由于现在还没有入到主题库里,暂时没有题号和题目链接
题目大意
给出一个长度为 n n n 的序列 a a a,你需要找出它的一个子序列 b b b(设 b b b 的长度为 m m m),并满足以下条件
- 对于所有 ( 1 ≤ i ≤ m ) (1\le i \le m) (1≤i≤m), b i b_i bi 要么是 b b b 中最大的元素,要么存在一个 b j > b i b_j>b_i bj>bi,满足 b i + b j + gcd ( b i , b j ) = lcm ( b i , b j ) b_i+b_j+\gcd(b_i,b_j)=\operatorname{lcm}(b_i,b_j) bi+bj+gcd(bi,bj)=lcm(bi,bj)
-
∑
i
=
1
m
b
i
\sum\limits_{i=1}^m b_i
i=1∑mbi 尽可能大
n ≤ 3 × 1 0 5 n\le 3 \times 10^5 n≤3×105
前置芝士
解题思路
首先,我们先将
b
i
+
b
j
+
gcd
(
b
i
,
b
j
)
=
lcm
(
b
i
,
b
j
)
b_i+b_j+\gcd(b_i,b_j)=\operatorname{lcm}(b_i,b_j)
bi+bj+gcd(bi,bj)=lcm(bi,bj) 化简
b
i
+
b
j
+
gcd
=
lcm
gcd
×
gcd
b_i+b_j+\gcd=\frac{\operatorname{lcm}}{\gcd} \times \gcd
bi+bj+gcd=gcdlcm×gcd
b
i
+
b
j
=
(
lcm
gcd
−
1
)
×
gcd
b_i+b_j=(\frac{\operatorname{lcm}}{\gcd}-1) \times \gcd
bi+bj=(gcdlcm−1)×gcd
b
i
gcd
+
b
j
gcd
=
(
lcm
gcd
−
1
)
\frac{b_i}{\gcd}+\frac{b_j}{\gcd}=(\frac{\operatorname{lcm}}{\gcd}-1)
gcdbi+gcdbj=(gcdlcm−1)
到了这,你应该会发现
b
i
gcd
×
b
j
gcd
=
b
i
⋅
b
j
gcd
2
=
lcm
gcd
\frac{b_i}{\gcd}\times\frac{b_j}{\gcd}=\frac{b_i\cdot b_j}{\gcd^2}=\frac{\operatorname{lcm}}{\gcd}
gcdbi×gcdbj=gcd2bi⋅bj=gcdlcm,原因是
lcm
⋅
gcd
=
b
i
⋅
b
j
\operatorname{lcm} \cdot \gcd=b_i \cdot b_j
lcm⋅gcd=bi⋅bj
那么,我们就得到了
b
i
gcd
+
b
j
gcd
=
(
b
i
gcd
⋅
b
j
gcd
−
1
)
\frac{b_i}{\gcd}+\frac{b_j}{\gcd}=(\frac{b_i}{\gcd} \cdot \frac{b_j}{\gcd}-1)
gcdbi+gcdbj=(gcdbi⋅gcdbj−1)
这时候我就赶紧跑去打了个表QwQ,发现只有
(
2
,
3
)
(2,3)
(2,3) 这个组合满足两数之和等于两数之积减一。也就是说,当且仅当
b
i
b
j
=
2
3
\frac{b_i}{b_j}=\frac{2}{3}
bjbi=32 的时候,原式才成立。
由此,题目就被我们化成了这样:
b
i
b_i
bi 要么是
b
b
b 中最大的,要么存在一个
b
j
b_j
bj,满足
b
i
b
j
=
2
3
\frac{b_i}{b_j}=\frac{2}{3}
bjbi=32
剩下的就很简单了,直接计算最长链就好了
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<map>
using namespace std;
const int Maxn=300000+10;
map <int,int> c;
map <int,bool> vis;
int a[Maxn],n;
long long ans;
inline int read()
{
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0' && ch<='9')s=(s<<3)+(s<<1)+(ch^48),ch=getchar();
return s*w;
}
int main()
{
// freopen("in.txt","r",stdin);
n=read();
for(int i=1;i<=n;++i)
{
a[i]=read();
c[a[i]]++;
}
sort(a+1,a+1+n);
for(int i=1;i<=n;++i)
{
if(vis[a[i]])continue;
register int x=a[i];
register long long tot=0;
while(c[x])
{
tot+=(long long)c[x]*x;
vis[x]=1;
if(x & 1)break;
x=(x>>1)*3;
}
ans=max(ans,tot);
}
printf("%lld\n",ans);
return 0;
}