说在前面
这个题真的是,卡的me心累
复杂度只要劣一点都过不去,T了一页最后学着别人的写法过了
题目
题目大意
给出一张
n
n
个点 条边的图,点有点权
现在对于图中的每个三元环,定义其价值为
max(ai,aj,ak)
max
(
a
i
,
a
j
,
a
k
)
求出所有三元环的总价值
输入输出格式
输入格式:
第一行两个整数
n
n
,,含义如题
接下来一行
n
n
个整数,第 个整数表示第
i
i
个点的点权
接下来
m
m
行,每行两个整数描述一条边
输出格式:
输出一行一个整数表示答案
解法
首先,我们需要感性的理解一下三元环的总个数,它大概不会很多
反正me也不会证…然后考虑暴力枚举
我们把点分成两类,一类点的度数小于,称之为「小点」;另一类度数大于 m−−√ m ,称之为「大点」
那么对于「小点」,我们暴力枚举与他相连的另外两个点 u,v u , v ,Hash判断 u,v u , v 是否相连,如果相连则累计贡献。复杂度为 ∑degi2 ∑ d e g i 2 (当这是一个 m−−√ m 个点的完全图时,复杂度为 mm−−√ m m 。如果拆边加点,就是一个类似均值不等式的玩意了,所以上限应该就是 mm−−√ m m )
然后对于「大点」因为其不超过 m−−√ m 个,我们直接 k3 k 3 枚举三个点,Hash判断它们之间是否有边,复杂度也是 mm−−√ m m 的
然后就可以过这道题了?并不
我们这样会把一些三元环枚举多次,即使使用剪枝,仍然会有多余的枚举次数
所以我们把双向边全部改成:权值大的 向 权值小的 连边。然后按照权值升序去枚举点
u
u
,再枚举与之相连的点 :
- 如果 v v 的度数小于,那么我们就枚举与 v v 相连的点 ,check p p 和 是否相连。由于边单向,每条边导致每个点被枚举一次,复杂度 mm−−√ m m
- 如果 v v 的度数大于,那就枚举与 u u 相邻的点 ,判断 q q 是否和 相连。复杂度分析类似,如果 u u 度数小于根号 和上述类似。如果 的度数也大于根号,那么连接 u,v u , v 的边数肯定很少,最后应该也能是 mm−−√ m m 级别的,但是me不会证。
这样写,因为所有三元组保证仅枚举一次,所以可以通过此题
下main是代码
#include <set>
#include <cmath>
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
int N , M , sa[100005] , rk[100005] , val[100005] , lim ;
vector<int> E[100005] ;
set<int> mp[100005] ;
bool cmp( const int &a , const int &b ){
return val[a] < val[b] || ( val[a] == val[b] && a < b ) ;
}
int tim_c , ape[100005] ;
void solve(){
long long ans = 0 ;
for( int i = 3 ; i <= N ; i ++ ){
int u = sa[i] , usiz = E[u].size() ;
tim_c ++ ;
for( int j = 0 ; j < usiz ; j ++ )
ape[ E[u][j] ] = tim_c ;
for( int j = 0 ; j < usiz ; j ++ ){
int v = E[u][j] , vsiz = E[v].size() ;
if( vsiz <= lim ){
for( int k = 0 ; k < vsiz ; k ++ )
if( ape[ E[v][k] ] == tim_c ) ans += val[u] ;
} else for( int k = 0 ; k < usiz ; k ++ )
if( mp[v].count( E[u][k] ) ) ans += val[u] ;
mp[u].insert( v ) ;
}
} printf( "%lld\n" , ans ) ;
}
int main(){
scanf( "%d%d" , &N , &M ) , lim = sqrt( M ) ;
for( int i = 1 ; i <= N ; i ++ )
scanf( "%d" , &val[i] ) , sa[i] = i ;
sort( sa + 1 , sa + N + 1 , cmp ) ;
for( int i = 1 ; i <= N ; i ++ ) rk[ sa[i] ] = i ;
for( int i = 1 , u , v ; i <= M ; i ++ ){
scanf( "%d%d" , &u , &v ) ;
if( rk[u] < rk[v] ) swap( u , v ) ;
E[u].push_back( v ) ;
} solve() ;
}
再附一份me TLE的代码
#include <cmath>
#include <ctime>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
typedef unsigned long long ull ;
int N , M , val[100005] , deg[100005] , tp , head[100005] ;
ull Base1 = 23333 , Base2 = 10007 ;
ull B1[100005] , B2[100005] ;
struct Path{
int pre , to ;
}p[500005] ;
struct Hash_Table{
int head[1048576] , tp , pre[500005] ;
ull h1[500005] , h2[500005] ;
void Insert( ull t1 ){
for( int i = head[t1 & 1048575] ; i ; i = pre[i] )
if( h1[i] == t1 ) return ;
pre[++tp] = head[t1 & 1048575] , head[t1 & 1048575] = tp ;
h1[tp] = t1 ;
}
bool find( ull t1 ){
for( int i = head[t1 & 1048575] ; i ; i = pre[i] )
if( h1[i] == t1 ) return true ;
return false ;
}
} Hs ;
void In( int t1 , int t2 ){
p[++tp] = ( Path ){ head[t1] , t2 } ; head[t1] = tp ;
p[++tp] = ( Path ){ head[t2] , t1 } ; head[t2] = tp ;
}
void init(){
B1[0] = B2[0] = 1 ;
for( int i = 1 ; i <= N ; i ++ )
B1[i] = B1[i-1] * Base1 , B2[i] = B2[i-1] * Base2 ;
}
ull Hs1( int u , int v ){
return B1[u] * u + v * B2[v] ;
}
int sma[100005] , big[100005] , ts , tb , cnt ;
bool sm[100005] ;
void solve(){
int llim = sqrt( M ) ;
for( int i = 1 ; i <= N ; i ++ )
if( deg[i] <= llim ) sma[++ts] = i , sm[i] = true ;
else big[++tb] = i ;
long long ans = 0 ;
for( int k = 1 ; k <= ts ; k ++ ){
int u = sma[k] ;
for( int i = head[u] ; i ; i = p[i].pre ){
int x = p[i].to , y ;
if( sm[x] && x > u ) continue ;
for( int j = p[i].pre ; j ; j = p[j].pre ){
y = p[j].to ;
if( sm[y] && y > u ) continue ;
if( Hs.find( Hs1( x , y ) ) )
ans += max( val[u] , max( val[x] , val[y] ) ) ;
}
}
}
for( int i = 1 ; i <= tb ; i ++ ){
int u = big[i] , x , y ;
for( int j = i + 1 ; j <= tb ; j ++ ){
x = big[j] ;
if( !Hs.find( Hs1( u , x ) ) ) continue ;
for( int k = j + 1 ; k <= tb ; k ++ ){
y = big[k] ;
if( !Hs.find( Hs1( u , y ) ) || !Hs.find( Hs1( x , y ) ) ) continue ;
ans += max( val[u] , max( val[x] , val[y] ) ) ;
}
}
} printf( "%lld\n" , ans ) ;
}
int main(){
scanf( "%d%d" , &N , &M ) , init() ;
for( int i = 1 ; i <= N ; i ++ )
scanf( "%d" , &val[i] ) ;
for( int i = 1 , u , v ; i <= M ; i ++ ){
scanf( "%d%d" , &u , &v ) ;
deg[u] ++ , deg[v] ++ , In( u , v ) ;
Hs.Insert( Hs1( u , v ) ) ;
Hs.Insert( Hs1( v , u ) ) ;
} solve() ;
}