Description
给你一个长度为 n n n 的字符串,和一个正整数 m m m ,其中 m m m 能整除 n n n ,现有一算法就是把整个字符串分成 n m \dfrac{n}{m} mn 段,然后每一段按照一个相同的 1 1 1 ~ m m m 的排列方式进行重新排列,然后把这 n m \dfrac{n}{m} mn 段合并,算法把字符串中连续相同的字符合并成一个字符,得到一个新字符串。
请你求出新的字符串可能的最短长度。
1 ≤ m ≤ 16 , 1 ≤ n ≤ 5 ≤ 1 0 4 1 \leq m \leq 16,1 \leq n \leq 5 \leq 10^4 1≤m≤16,1≤n≤5≤104 。
时间限制为 4s ,空间限制为 65536 KB 。
Solution
发现 m m m 很小,考虑使用 状压 DP 来解题 。
设 f [ S ] [ i ] [ j ] f[S][i][j] f[S][i][j] 表示当前已经选的数的状态为 S S S ,排列中的第一个数是 i i i , 当前 已经选择的最后一个数是 j j j 的情况下的新字符串可能的最小长度。
但是我们发现这样会空超。
考虑删掉一维,于是我们得到了下面的定义:
设 f [ S ] [ i ] f[S][i] f[S][i] 表示当前已经选的数的状态为 S S S , 当前 已经选择的最后一个数是 i 的情况下的新字符串可能的最小长度。
然后枚举 排列中的第一个数 即可。
这样就不会空超了。
然后我们设 s u m [ i ] [ j ] sum[i][j] sum[i][j] 表示如果在排列中, i i i 在 j j j 的后面一个位置,那么所新产生的长度为 s u m [ i ] [ j ] sum[i][j] sum[i][j] 。
这个是可以用 O ( n m ) O(nm) O(nm) 的时间复杂度预处理出来的。
并且设 s a m e ( x , y ) same(x,y) same(x,y) 表示当排列中的第 m m m 个数为 x x x ,排列中的第一个数为 y y y 时能够消掉的长度为 s a m e ( x , y ) same(x,y) same(x,y) 。
这个是可以用 O ( n m ) O(\dfrac{n}{m}) O(mn) 的时间复杂度求出来的,然后就可以利用它们来转移了。
转移方程如下:
f [ S ∣ 2 j ] [ j ] = { min { f [ S ] [ i ] + s u m [ i ] [ j ] } ( S ∣ 2 j ) ≠ 2 m − 1 min { f [ S ] [ i ] + s u m [ i ] [ j ] + s a m e ( j , h e ) } ( S ∣ 2 j ) ≠ 2 m − 1 f[S|2^j][j]=\begin{cases} \min\{f[S][i]+sum[i][j]\}\quad\quad\quad\quad\quad\quad\quad\quad\; (S|2^j) \not= 2^m-1 \\ \min\{f[S][i]+sum[i][j]+same(j,he)\} \quad\quad (S|2^j) \not= 2^m-1 \end{cases} f[S∣2j][j]={min{f[S][i]+sum[i][j]}(S∣2j)=2m−1min{f[S][i]+sum[i][j]+same(j,he)}(S∣2j)=2m−1
。
其中 h e he he 表示的是排列中的第一个数。
- 时间复杂度
O ( 2 m m 2 + n m ) O(2^mm^2+nm) O(2mm2+nm)
- ,记得注意位运算的优先级!
- 然后这道题目就做完啦。
Code
#include <cstdio>
#include <cstring>
int f[ 65537 ][ 17 ] , sum[ 17 ][ 17 ] ;
char st[ 1000001 ] ;
int ans = 2147483647 , n = 0 , m = 0 ;
int min( int x , int y )
{
return x < y ? x : y ;
}
int same( int x , int y )
{
int da = 0 ;
for(int i = 1 ; i <= n / m - 1 ; i ++ )
{
char tx = st[ ( i - 1 ) * m + x ] ;
char ty = st[ i * m + y ] ;
if( tx == ty )
{
da ++ ;
}
}
return da ;
}
void work( int he )
{
memset( f , 127 / 3 , sizeof( f ) ) ;
f[ 1 << ( he - 1 ) ][ he ] = n / m ;
int ma = ( 1 << m ) - 1 ;
for(int S = 1 ; S <= ma ; S ++ )
{
if( ( S & ( 1 << ( he - 1 ) ) ) == 0 )
{
continue ;
}
for(int i = 1 ; i <= m ; i ++ )
{
if( ( S & ( 1 << ( i - 1 ) ) ) == 0 )
{
continue ;
}
for(int j = 1 ; j <= m ; j ++ )
{
if( ( S & ( 1 << ( j - 1 ) ) ) != 0 )
{
continue ;
}
if( ( S | ( 1 << ( j - 1 ) ) ) != ma )
{
f[ S | ( 1 << ( j - 1 ) ) ][ j ] = min( f[ S | ( 1 << ( j - 1 ) ) ][ j ] , f[ S ][ i ] + sum[ i ][ j ] ) ;
}
else
{
f[ S | ( 1 << ( j - 1 ) ) ][ j ] = min( f[ S | ( 1 << ( j - 1 ) ) ][ j ] , f[ S ][ i ] + sum[ i ][ j ] - same( j , he ) ) ;
}
}
}
}
for(int i = 1 ; i <= m ; i ++ )
{
if( he == i )
{
continue ;
}
ans = min( ans , f[ ma ][ i ] ) ;
}
}
int main()
{
scanf("%d" , &m ) ;
scanf("%s" , st + 1 ) ;
n = strlen( st + 1 ) ;
for(int i = 1 ; i <= m ; i ++ )
{
for(int j = 1 ; j <= m ; j ++ )
{
if( i == j )
{
continue ;
}
for(int k = 1 ; k <= n / m ; k ++ )
{
if( st[ ( k - 1 ) * m + i ] != st[ ( k - 1 ) * m + j ] )
{
sum[ i ][ j ] ++ ;
}
}
}
}
for(int i = 1 ; i <= m ; i ++ )
{
work( i ) ;
}
printf("%d" , ans ) ;
return 0 ;
}