传送门:【ZOJ】3942 Substring Counting
赛场上写了两个多小时没有过,回来后过了好久才肯补。结果写了4个多小时,还写了trie来对拍,调到电脑没电还没调好,第二天终于找到最后一个细节错,艰难的AC了。各种漏细节,没救了。(希哥说他这题定义是简单题,mdzz!)
总之就分三种情况模拟,两边是00,两边是11,两边是01或10(01或10只要reverse下数组就能用同一种方法计算)。
大致思路我不说了,我用了 $O(N^2 \log N)$ 的做法,实际上把我的 for 换成双指针,然后仔细思考下可以优化到 $O(N \log N)$,不过很难写就是了。
my code:
#include <bits/stdc++.h>
// File-scope state shared by every routine below. All arrays are used
// 1-indexed; being globals they start zero-initialised, and the code reads
// s[0] and fail[0] relying on that.
using namespace std ;
typedef long long LL ;
// Capacity of every array below.
const int MAXN = 2005 ;
// s: the input sequence read by solve(); p: scratch buffer holding the
// pattern (a contiguous slice of s) currently being searched for.
int s[MAXN] , p[MAXN] ;
// KMP failure table for the pattern in p, filled by getfail().
int fail[MAXN] ;
// v[a][b] is set to 1 by kmp()/kmp2()/kmp3() when the current pattern occurs
// at s[a..b]; solve() checks it to skip slices that were already matched.
int v[MAXN][MAXN] ;
// NOTE(review): c is not referenced anywhere in the visible source.
int c[MAXN] ;
// Scratch for kmp2(): L collects (left neighbour, right neighbour) pairs of
// occurrences, R is the stack built from the sorted L.
pair < int , int > L[MAXN] , R[MAXN] ;
// Difference map used by kmp() to measure a union of integer intervals.
map < int , int > mp ;
// Number of live entries in L and R respectively.
int nl , nr ;
// Length of s and the target value m, both read by solve().
int n , m ;
// Clears the occurrence table: sets v[r][c] = 0 for all 1 <= r, c <= n.
void init () {
    for ( int row = 1 ; row <= n ; ++ row ) {
        std::fill ( v[row] + 1 , v[row] + n + 1 , 0 ) ;
    }
}
// Builds the KMP failure table of the 1-indexed pattern p[1..n] into the
// global fail[]: fail[i] is the length of the longest proper prefix of
// p[1..i] that is also a suffix of it. fail[1] is always written, even
// when n < 2.
void getfail ( int p[] , int n ) {
    int border = 0 ;
    fail[1] = 0 ;
    for ( int pos = 2 ; pos <= n ; ++ pos ) {
        // Fall back through shorter borders until one can be extended.
        while ( border != 0 && p[pos] != p[border + 1] ) {
            border = fail[border] ;
        }
        if ( p[pos] == p[border + 1] ) {
            ++ border ;
        }
        fail[pos] = border ;
    }
}
// Searches s[1..n] for every occurrence of the pattern p[1..m], using the
// failure table already stored in the global fail[].
//
// Every occurrence ending at index e marks v[e - m + 1][e] = 1.
// In addition, an occurrence ending at an even e whose neighbours
// a = s[e - m] and b = s[e + 1] are both non-zero with a + b >= num adds the
// half-open interval [max(1, num - a), min(b, num - 1) + 1) to the global
// difference map mp.
//
// Returns the number of integers covered by the union of those intervals.
int kmp ( int p[] , int m , int s[] , int n , int num ) {
    mp.clear () ;

    int matched = 0 ;
    for ( int at = 1 ; at <= n ; ++ at ) {
        while ( matched != 0 && s[at] != p[matched + 1] ) {
            matched = fail[matched] ;
        }
        if ( s[at] != p[matched + 1] ) {
            // matched is 0 here; kept as in the original (reads fail[0]).
            matched = fail[matched] ;
            continue ;
        }
        ++ matched ;
        if ( matched != m ) continue ;

        const int before = s[at - m] ;
        const int after = s[at + 1] ;
        if ( before != 0 && after != 0 && at % 2 == 0 && before + after >= num ) {
            const int lo = ( before >= num ) ? 1 : num - before ;
            const int hi = min ( after , num - 1 ) + 1 ;
            mp[lo] ++ ;
            mp[hi] -- ;
        }
        v[at - m + 1][at] = 1 ;
        matched = fail[matched] ;
    }

    // Sweep the difference map in key order; whenever the running depth is
    // positive, the gap since the previous key belongs to the union.
    int covered = 0 ;
    int depth = 0 ;
    int last_key = 0 ;
    map < int , int > :: iterator it = mp.begin () ;
    while ( it != mp.end () ) {
        if ( depth > 0 ) {
            covered += it->first - last_key ;
        }
        depth += it->second ;
        last_key = it->first ;
        ++ it ;
    }
    return covered ;
}
// Searches s[1..n] for every occurrence of the pattern p[1..m], using the
// failure table already stored in the global fail[].
//
// Every occurrence ending at index e marks v[e - m + 1][e] = 1; occurrences
// ending at an odd e also record the pair (s[e - m], s[e + 1]) in L.
//
// The pairs are then sorted and reduced with a stack R that keeps only
// entries whose second component is strictly decreasing. The return value is
// sum over R of (first_k - first_{k-1}) * second_k with R[0] = (0, 0), i.e.
// the staircase area under the surviving pairs (for non-negative values this
// is the area of the union of the origin-anchored rectangles).
LL kmp2 ( int p[] , int m , int s[] , int n ) {
    nl = 0 ;
    nr = 0 ;

    int matched = 0 ;
    for ( int at = 1 ; at <= n ; ++ at ) {
        while ( matched != 0 && s[at] != p[matched + 1] ) {
            matched = fail[matched] ;
        }
        if ( s[at] != p[matched + 1] ) {
            // matched is 0 here; kept as in the original (reads fail[0]).
            matched = fail[matched] ;
            continue ;
        }
        ++ matched ;
        if ( matched != m ) continue ;

        if ( at % 2 != 0 ) {
            ++ nl ;
            L[nl] = make_pair ( s[at - m] , s[at + 1] ) ;
        }
        v[at - m + 1][at] = 1 ;
        matched = fail[matched] ;
    }

    sort ( L + 1 , L + nl + 1 ) ;

    // Monotonic stack: a later pair (larger first) dominates any earlier
    // pair whose second component is not larger.
    R[0] = make_pair ( 0 , 0 ) ;
    for ( int k = 1 ; k <= nl ; ++ k ) {
        while ( nr > 0 && R[nr].second <= L[k].second ) {
            -- nr ;
        }
        ++ nr ;
        R[nr] = L[k] ;
    }

    LL area = 0 ;
    for ( int k = 1 ; k <= nr ; ++ k ) {
        const int width = R[k].first - R[k - 1].first ;
        area += 1LL * width * R[k].second ;
    }
    return area ;
}
// Searches s[1..n] for every occurrence of the pattern p[1..m], using the
// failure table already stored in the global fail[].
//
// Every occurrence ending at index e marks v[e - m + 1][e] = 1. Among the
// occurrences ending at an even e whose right neighbour s[e + 1] is at least
// num, the largest left neighbour s[e - m] is returned (0 if there is none).
//
// x and y are accepted for the callers' sake but are not used.
int kmp3 ( int p[] , int m , int s[] , int n , int x , int y , int num ) {
    int best = 0 ;
    int matched = 0 ;
    for ( int at = 1 ; at <= n ; ++ at ) {
        while ( matched != 0 && s[at] != p[matched + 1] ) {
            matched = fail[matched] ;
        }
        if ( s[at] != p[matched + 1] ) {
            // matched is 0 here; kept as in the original (reads fail[0]).
            matched = fail[matched] ;
            continue ;
        }
        ++ matched ;
        if ( matched != m ) continue ;

        if ( at % 2 == 0 && s[at + 1] >= num && s[at - m] > best ) {
            best = s[at - m] ;
        }
        v[at - m + 1][at] = 1 ;
        matched = fail[matched] ;
    }
    return best ;
}
void solve () {
int maxv = 0 ;
scanf ( "%d%d" , &n , &m ) ;
for ( int i = 1 ; i <= n ; ++ i ) {
scanf ( "%d" , &s[i] ) ;
if ( i % 2 == 0 ) maxv = max ( maxv , s[i] ) ;
}
if ( !m ) {
printf ( "%d\n" , maxv ) ;
return ;
}
if ( n % 2 == 0 ) s[++ n] = 0 ;
s[n + 1] = 0 ;
LL ans = 0 ;
for( int i = 1 ; i <= n ; i += 2 ) {
if ( s[i] >= m ) {
++ ans ;
break ;
}
}
init () ;
for ( int i = 2 ; i <= n ; i += 2 ) {//00
int zero = 0 ;
for ( int j = i ; j <= n ; j += 2 ) {
if ( zero >= m ) break ;
if ( s[i - 1] + s[j + 1] >= m - zero && !v[i][j] ) {
int l = j - i + 1 ;
for ( int k = 1 ; k <= l ; ++ k ) {
p[k] = s[i + k - 1] ;
}
getfail ( p , l ) ;
int tmp = kmp ( p , l , s , n , m - zero ) ;
ans += tmp ;
}
zero += s[j + 1] ;
}
}
init () ;
for ( int i = 1 ; i <= n ; i += 2 ) {//11
int zero = 0 ;
for ( int j = i ; j <= n ; j += 2 ) {
zero += s[j] ;
if ( zero > m ) break ;
if ( zero == m ) {
if ( !v[i][j] ) {
int l = j - i + 1 ;
for ( int k = 1 ; k <= l ; ++ k ) {
p[k] = s[i + k - 1] ;
}
getfail ( p , l ) ;
ans += kmp2 ( p , l , s , n ) ;
}
break ;
}
}
}
init () ;
for ( int i = 1 ; i <= n ; i += 2 ) {//10
int zero = 0 ;
for ( int j = i ; j <= n ; j += 2 ) {
zero += s[j] ;
if ( zero >= m ) {
if ( j != i && !v[i][j - 1] ) {
int l = j - i ;
for ( int k = 1 ; k <= l ; ++ k ) {
p[k] = s[i + k - 1] ;
}
getfail ( p , l ) ;
ans += kmp3 ( p , l , s , n , i , j , m - zero + s[j] ) ;
}
break ;
}
}
}
for ( int i = 1 ; i <= n / 2 ; ++ i ) {
swap ( s[i] , s[n - i + 1] ) ;
}
init () ;
for ( int i = 1 ; i <= n ; i += 2 ) {//01
int zero = 0 ;
for ( int j = i ; j <= n ; j += 2 ) {
zero += s[j] ;
if ( zero >= m ) {
if ( j != i && !v[i][j - 1] ) {
int l = j - i ;
for ( int k = 1 ; k <= l ; ++ k ) {
p[k] = s[i + k - 1] ;
}
getfail ( p , l ) ;
int tmp = kmp3 ( p , l , s , n , i , j , m - zero + s[j] ) ;
ans += tmp ;
}
break ;
}
}
}
int Ml = 0 , Mr = 0 ;
for ( int i = 1 ; i <= n ; i += 2 ) {
if ( s[i] >= m ) Ml = max ( Ml , s[i - 1] ) ;
if ( s[i] >= m ) Mr = max ( Mr , s[i + 1] ) ;
}
ans += Ml + Mr ;
printf ( "%lld\n" , ans ) ;
}
// Reads the number of test cases T from stdin and runs solve() T times.
// If T cannot be read, exits without processing anything instead of looping
// on an uninitialised value.
int main () {
    int T = 0 ;
    if ( scanf ( "%d" , &T ) != 1 ) return 0 ;
    for ( int i = 1 ; i <= T ; ++ i ) {
        solve () ;
    }
    return 0 ;
}
/*
10
10 6
5 4 1 4 1 1 5 5 1 1
5 2
2 2 1 2 1
4 2
1 2 1 2
5 5
5 1 2 1 1
ans:
49
8
3
6
*/