题意
给定一个长度为n的数组a[1..n],有一幅完全图,满足(u,v)的边权为a[u] xor a[v]
求边权和最小的生成树,你需要输出边权和还有方案数对1e9+7取模的值
1<=n<=10^5
0<=a[i]<2^30
分析
根据异或的性质,我们可以从高位到低位贪心。对于当前位,我们可以把点集分为当前位是1和是0两部分,显然最优的连边方案是两个集合自成一个连通块,然后在这两个集合之间连一条边。
那么先把所有数的字典树建出来,然后在字典树上面分治。
对于当前节点,先递归下去求左右子树分别做生成树的代价。然后通过类似线段树合并的方式来求在左右子树中分别取一个点的异或最小值以及方案。
注意当递归到叶节点时,若该叶节点的size大于2,表明该叶节点的生成树不止一种。具体来讲应该是有
sizesize−2
s
i
z
e
s
i
z
e
−
2
种。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define mp(x,y) make_pair(x,y)
using namespace std;
typedef long long LL;
typedef pair<int,int> pi;
const int N=100005;
const int inf=0x7fffffff;
const int MOD=1000000007;
int n,sz,sum,bin[35],rt;
LL ans;
struct tree{int l,r,s;}t[N*30];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int ksm(int x,int y)
{
int ans=1;
while (y)
{
if (y&1) ans=(LL)ans*x%MOD;
x=(LL)x*x%MOD;y>>=1;
}
return ans;
}
void mod(int &x) {x-=x>=MOD?MOD:0;}
void ins(int &d,int dep,int v)
{
if (!d) d=++sz;
t[d].s++;
if (dep<0) return;
if (v&bin[dep]) ins(t[d].l,dep-1,v);
else ins(t[d].r,dep-1,v);
}
pi get(int x,int y,int dep)
{
if (dep<0) return mp(0,(LL)t[x].s*t[y].s%MOD);
pi ans=mp(inf,0);
if (t[x].l&&t[y].l||t[x].r&&t[y].r)
{
if (t[x].l&&t[y].l)
{
pi u=get(t[x].l,t[y].l,dep-1);
if (u.first<ans.first) ans=u;
else if (u.first==ans.first) mod(ans.second+=u.second);
}
if (t[x].r&&t[y].r)
{
pi u=get(t[x].r,t[y].r,dep-1);
if (u.first<ans.first) ans=u;
else if (u.first==ans.first) mod(ans.second+=u.second);
}
}
else
{
if (t[x].l&&t[y].r)
{
pi u=get(t[x].l,t[y].r,dep-1);u.first+=bin[dep];
if (u.first<ans.first) ans=u;
else if (u.first==ans.first) mod(ans.second+=u.second);
}
if (t[x].r&&t[y].l)
{
pi u=get(t[x].r,t[y].l,dep-1);u.first+=bin[dep];
if (u.first<ans.first) ans=u;
else if (u.first==ans.first) mod(ans.second+=u.second);
}
}
return ans;
}
void solve(int d,int dep)
{
if (dep<0) {if (t[d].s>2) sum=(LL)sum*ksm(t[d].s,t[d].s-2)%MOD;return;}
if (t[d].l) solve(t[d].l,dep-1);
if (t[d].r) solve(t[d].r,dep-1);
if (t[d].l&&t[d].r)
{
pi u=get(t[d].l,t[d].r,dep-1);
ans+=u.first+bin[dep];sum=(LL)sum*u.second%MOD;
}
}
int main()
{
n=read();
bin[0]=1;
for (int i=1;i<=30;i++) bin[i]=bin[i-1]*2;
for (int i=1,x;i<=n;i++) x=read(),ins(rt,30,x);
sum=1;
solve(rt,30);
printf("%lld\n%d",ans,sum);
return 0;
}