A. 给出n个m维空间中的点,对于点 A(x0,x1....xm−1),B(y0,y1....ym−1) A ( x 0 , x 1 . . . . x m − 1 ) , B ( y 0 , y 1 . . . . y m − 1 ) (0<=xi,yi<4) ( 0 <= x i , y i < 4 ) ,两点的距离定义为 ∑m−1i=0|xi−yi| ∑ i = 0 m − 1 | x i − y i | ,求距离为0~3m中每个值的点对数量
粗略的写一下题解,具体的可以参见 2017国家集训队论文海蜇?海蜇!解题报告步骤二部分
我们可以将每个点看作一个m位的四进制数
我们定义
A∗B=C,C(|x0−y0|,|x1−y1|.....|xm−1−ym−1|)
A
∗
B
=
C
,
C
(
|
x
0
−
y
0
|
,
|
x
1
−
y
1
|
.
.
.
.
.
|
x
m
−
1
−
y
m
−
1
|
)
(同样对点转换成的四进制数
a,b
a
,
b
也这样定义
a∗b
a
∗
b
,可以发现如果在二进制下这样定义就是异或)
令
h[i]
h
[
i
]
表示有多少个点转化成四进制为
i
i
再定义对于数组的乘法,
a=f∗g,a[k]=∑4m−1i=0∑4m−1j=0[i∗j=k]f[i]g[j]
a
=
f
∗
g
,
a
[
k
]
=
∑
i
=
0
4
m
−
1
∑
j
=
0
4
m
−
1
[
i
∗
j
=
k
]
f
[
i
]
g
[
j
]
那么我们只要求出 a=h∗h a = h ∗ h ,就能求出距离为0~3m中每个值的点对数量
我们令
ai
a
i
表示数组
a
a
中四进制最高位为的部分,有
a0=f0∗g0+f1∗g1+f2∗g2+f3∗g3
a
0
=
f
0
∗
g
0
+
f
1
∗
g
1
+
f
2
∗
g
2
+
f
3
∗
g
3
a1=f0∗g1+f1∗g0+f1∗g2+f2∗g1+f2∗g3+f3∗g2
a
1
=
f
0
∗
g
1
+
f
1
∗
g
0
+
f
1
∗
g
2
+
f
2
∗
g
1
+
f
2
∗
g
3
+
f
3
∗
g
2
a2=f0∗g2+f1∗g3+f2∗g0+f3∗g1
a
2
=
f
0
∗
g
2
+
f
1
∗
g
3
+
f
2
∗
g
0
+
f
3
∗
g
1
a3=f0∗g3+f3∗g0
a
3
=
f
0
∗
g
3
+
f
3
∗
g
0
我们可以类似 fwt f w t 的过程,用 f0+f1+f2+f3 f 0 + f 1 + f 2 + f 3 , f0−f1+f2−f3 f 0 − f 1 + f 2 − f 3 , f0+f3 f 0 + f 3 , f0−f3 f 0 − f 3 等优化数组之间乘的次数,把16次优化到6次,然后分治下去处理
复杂度 O(6m) O ( 6 m )
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
inline void read(int &x)
{
char c; while(!((c=getchar())>='0'&&c<='9'));
x=c-'0';
while((c=getchar())>='0'&&c<='9') (x*=10)+=c-'0';
}
const int maxn = 1<<18;
void fwt(ll a[],ll f[],ll g[],int N)
{
if(N==1) { a[0]=f[0]*g[0];return; }
int n=N/4;
ll u[n],v[n],ft[n],gt[n];
for(int i=0;i<n;i++) ft[i]=f[i]+f[i+n]+f[i+2*n]+f[i+3*n];
for(int i=0;i<n;i++) gt[i]=g[i]+g[i+n]+g[i+2*n]+g[i+3*n];
fwt(u,ft,gt,n);
for(int i=0;i<n;i++) ft[i]=f[i]-f[i+n]+f[i+2*n]-f[i+3*n];
for(int i=0;i<n;i++) gt[i]=g[i]-g[i+n]+g[i+2*n]-g[i+3*n];
fwt(v,ft,gt,n);
for(int i=0;i<n;i++) a[2*n+i]=(u[i]+v[i])>>1,a[n+i]=(u[i]-v[i])>>1;
for(int i=0;i<n;i++) ft[i]=f[i]+f[i+3*n];
for(int i=0;i<n;i++) gt[i]=g[i]+g[i+3*n];
fwt(u,ft,gt,n);
for(int i=0;i<n;i++) ft[i]=f[i]-f[i+3*n];
for(int i=0;i<n;i++) gt[i]=g[i]-g[i+3*n];
fwt(v,ft,gt,n);
for(int i=0;i<n;i++)
{
a[3*n+i]=(u[i]-v[i])>>1;
a[n+i]-=a[3*n+i];
}
for(int i=0;i<n;i++) a[i]=(u[i]+v[i])>>1;
for(int i=0;i<n;i++) ft[i]=f[n+i];
for(int i=0;i<n;i++) gt[i]=g[n+i];
fwt(u,ft,gt,n);
for(int i=0;i<n;i++) a[i]+=u[i];
for(int i=0;i<n;i++) ft[i]=f[2*n+i];
for(int i=0;i<n;i++) gt[i]=g[2*n+i];
fwt(v,ft,gt,n);
for(int i=0;i<n;i++)
{
a[i]+=v[i];
a[2*n+i]-=a[i];
}
}
int n,m;
ll a[maxn],f[maxn],g[maxn];
ll ans[maxn];
int main()
{
freopen("space.in","r",stdin);
freopen("space.out","w",stdout);
read(n); read(m);
for(int i=1;i<=n;i++)
{
int now=0;
for(int j=0;j<m;j++)
{
int x; read(x);
now|=x<<j*2;
}
f[now]++,g[now]++;
}
fwt(a,f,g,1<<m*2);
for(int i=0;i<(1<<m*2);i++)
{
int dis=0;
for(int j=0;j<m;j++) dis+=i>>j*2&3;
ans[dis]+=a[i];
}
for(int i=0;i<=3*m;i++) printf("%lld%c",ans[i],i!=3*m?' ':'\n');
return 0;
}
B.
考虑对每条边分开计算他贡献的概率
模拟Kruskal的过程,从小到大处理边
令
f[i][mask]
f
[
i
]
[
m
a
s
k
]
表示处理到第
i
i
条边,前条边已经确定了是否存在,
mask
m
a
s
k
是一个极大联通块的概率(这里极大联通块的定义为不能在往点集中添加点使得这些点仍然两辆联通)
对于第
i
i
条边,他贡献的概率是
f[i][j](j包含u不包含v)
f
[
i
]
[
j
]
(
j
包
含
u
不
包
含
v
)
考虑转移
若
j
j
不包含,
f[i+1][j]=f[i][j]
f
[
i
+
1
]
[
j
]
=
f
[
i
]
[
j
]
,这条边是否存在与这个联通块无关
若
j
j
包含中的一个,这条边一定不存在,
f[i+1][j]=f[i][j]∗pi
f
[
i
+
1
]
[
j
]
=
f
[
i
]
[
j
]
∗
p
i
若
j
j
同时包含,枚举
j
j
的子集,
f[i+1][j]=f[i][j]+∑k,k包含u不包含v f[i][k]∗f[i][j−k]∗(1−pi)/t
f
[
i
+
1
]
[
j
]
=
f
[
i
]
[
j
]
+
∑
k
,
k
包
含
u
不
包
含
v
f
[
i
]
[
k
]
∗
f
[
i
]
[
j
−
k
]
∗
(
1
−
p
i
)
/
t
其中
t
t
表示共有的边不存在的概率,因为这部分概率会被
k,j−k
k
,
j
−
k
各计算一次,要除回来,计算
t
t
可以用表示点集
mask
m
a
s
k
内所有边不存在的概率,
t=h[j]h[k]∗h[j−k]
t
=
h
[
j
]
h
[
k
]
∗
h
[
j
−
k
]
,因为我比较zz,考试时没想到这个,写了个分段打表的东西,带个
m20
m
20
的系数计算
t
t
,但因为状态总数不是满的所以也能跑过去
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn = 14;
const int maxm = 95;
const int mask2= 1<<14;
const int mask3= 4782969;
const double eps = 1e-15;
int to3[mask2];
int n,m,al,el=(1<<20)-1;
struct edge
{
int x,y,c;
double pi;
friend inline bool operator <(const edge x,const edge y){return x.c<y.c;}
}e[maxm];
double pe[5][1<<20];
struct data
{
int ei[5];
void add(int i) { ei[i/20]|=1<<(i%20); }
friend inline data operator &(const data &a,const data &b)
{
data re;
for(int i=0;i<5;i++) re.ei[i]=a.ei[i]&b.ei[i];
return re;
}
double cal()
{
double r=1.0;
for(int i=0;i<5;i++) r=r*pe[i][ei[i]];
return r;
}
}ei[mask2],now;
double ans;
double g[maxm][mask2];
int main()
{
freopen("bridge.in","r",stdin);
freopen("bridge.out","w",stdout);
scanf("%d%d",&n,&m); al=(1<<n)-1;
for(int i=0;i<=al;i++) to3[i]=to3[i>>1]*3+(i&1);
for(int i=0;i<5;i++) for(int j=0;j<=el;j++) pe[i][j]=1.0;
for(int i=0;i<m;i++)
{
scanf("%d%d%d",&e[i].x,&e[i].y,&e[i].c); e[i].x--,e[i].y--;
scanf("%lf",&e[i].pi);
if(e[i].x==e[i].y) { i--;m--;continue; }
}
sort(e,e+m);
for(int i=0;i<m;i++)
{
for(int S=al^(1<<e[i].x),s=S;;s=(s-1)&S)
{
ei[s|1<<e[i].x].add(i);
if(!s)break;
}
for(int S=al^(1<<e[i].y),s=S;;s=(s-1)&S)
{
ei[s|1<<e[i].y].add(i);
if(!s)break;
}
for(int t=i/20,j=i%20,S=el^(1<<j),s=S;;s=(s-1)&S)
{
pe[t][s|1<<j]*=e[i].pi;
if(!s)break;
}
}
for(int i=0;i<n;i++) g[0][1<<i]=1.0;
for(int i=0;i<m;i++)
{
for(int j=0;j<=al;j++) if((j>>e[i].x&1)&&!(j>>e[i].y&1)&&g[i][j]>eps)
ans+=g[i][j]*(1-e[i].pi)*e[i].c;
for(int j=0;j<=al;j++)
{
if(!(j>>e[i].x&1)&&!(j>>e[i].y&1)) g[i+1][j]=g[i][j];
else if((j>>e[i].x&1)^(j>>e[i].y&1)) g[i+1][j]=g[i][j]*e[i].pi;
else
{
g[i+1][j]=g[i][j];
for(int S=j^(1<<e[i].x)^(1<<e[i].y),s=S;;s=(s-1)&S)
{
int x=s|1<<e[i].x,y=(S^s)|1<<e[i].y;
if(g[i][x]>eps&&g[i][y]>eps)
{
double tc=(ei[x]&ei[y]&now).cal();
if(tc>eps) g[i+1][j]+=g[i][x]*g[i][y]/tc*(1-e[i].pi);
}
if(!s)break;
}
}
}
now.add(i);
}
printf("%.6lf\n",ans);
return 0;
}
C.
朴素的想法是直接跑费用流
考虑一些优化
源连白点流量1,黑点连汇流量1,把黑白点排列在数轴上,相邻点连边流量inf费用距离,这样可以把边数降到
因为费用流的特性,每次流必定是在剩下的点中选一对相邻的黑白点流,然后去掉这两个点。我们可以用数据结构代替费用流去跑这个东西,相邻的每对点维护两个vector表示两个方向反向边的费用和,按流量从大到小排,把不同色的相邻点放堆里,每次取出堆顶跑,然后将他和相邻的两条边合并,合并用vector的启发式合并
复杂度 O(nlogn) O ( n l o g n )
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define SZ(x) ((int)x.size())
#define End(x) (x.begin()+SZ(x)-1)
using namespace std;
inline void read(int &x)
{
int f=1; char c;
while(!((c=getchar())>='0'&&c<='9')) if(c=='-') f=-1;
x=c-'0';
while((c=getchar())>='0'&&c<='9') (x*=10)+=c-'0';
if(f==-1) x=-x;
}
const int maxn = 410000;
ll ans;
int n,m,K;
pair<int,int>xi[maxn];
int dis[maxn];
struct List
{
int fa[maxn];
void init(){for(int i=1;i<=n;i++) fa[i]=i;}
int findfa(const int x){return fa[x]==x?x:fa[x]=findfa(fa[x]);}
}L,R;
int vi[maxn][2],nowd[maxn];
vector<int>V[maxn<<1];
int qend(int x,int dir)
{
int ix=vi[x][dir];
return SZ(V[ix])?(*End(V[ix])):0;
}
void Run(int x,int dir)
{
int ix=vi[x][dir],ix2=vi[x][!dir];
int o=0;
if(SZ(V[ix])) o=(*End(V[ix])),V[ix].erase(End(V[ix]));
ans+=nowd[x]-o*2;
V[ix2].pb(nowd[x]-o);
}
void merge(int x,int y)
{
nowd[x]+=nowd[y];
int sx,sy;
int &ix0=vi[x][0],&iy0=vi[y][0]; if(SZ(V[ix0])<SZ(V[iy0])) swap(ix0,iy0);
sx=SZ(V[ix0]),sy=SZ(V[iy0]);
for(int i=0;i<sy;i++) V[ix0][sx-(sy-i)]+=V[iy0][i];
int &ix1=vi[x][1],&iy1=vi[y][1]; if(SZ(V[ix1])<SZ(V[iy1])) swap(ix1,iy1);
sx=SZ(V[ix1]),sy=SZ(V[iy1]);
for(int i=0;i<sy;i++) V[ix1][sx-(sy-i)]+=V[iy1][i];
}
set< pair<int,int> >S;
set< pair<int,int> >::iterator it;
int main()
{
freopen("friend.in","r",stdin);
freopen("friend.out","w",stdout);
read(n); read(m); read(K);
for(int i=1;i<=n;i++) read(xi[i].fir),xi[i].sec=0;
for(int i=1;i<=m;i++) read(xi[n+i].fir),xi[n+i].sec=1;
n+=m;
sort(xi+1,xi+n+1);
L.init(); R.init();
for(int i=2;i<=n;i++)
{
vi[i][0]=i,vi[i][1]=n+i;
nowd[i]=xi[i].fir-xi[i-1].fir;
dis[i]=dis[i-1]+nowd[i];
if(xi[i-1].sec^xi[i].sec) S.insert(mp(nowd[i],i));
}
while(K--)
{
it=S.begin(); pair<int,int>temp=(*it); S.erase(it);
int x=temp.sec,y=L.findfa(x-1),dir=!xi[y].sec;
if(R.findfa(x)!=x||!y||nowd[x]-qend(x,dir)*2!=temp.fir) {K++;continue;}
Run(x,dir);
L.fa[y]=y-1,R.fa[y]=y+1;
L.fa[x]=x-1,R.fa[x]=x+1;
if(!L.findfa(y-1)||!R.findfa(x+1)) continue;
merge(x,y);
y=R.findfa(x+1);
merge(y,x),x=y;
y=L.findfa(x-1),dir=!xi[y].sec;
if(xi[x].sec^xi[y].sec) S.insert(mp(nowd[x]-qend(x,dir)*2,x));
}
printf("%lld\n",ans);
return 0;
}