WTF交换
题意:假定给出一个包含N个整数的数组A,包含N+1个整数的数组ID,与整数R。其中ID数组中的整数均在区间[1,N-1]中。用下面的算法对A进行Warshall-Turing-Fourier变换(WTF):
sum = 0
for i = 1 to N
index = min{ ID[i], ID[i+1 ] }
sum = sum + A[index ]
将数组A往右循环移动R位
将数组A内所有的数取相反数
for i = 1 to N
index = max{ ID[i], ID[i+1 ] }
index = index + 1
sum = sum + A[index ]
将数组A往右循环移动R位
给出数组A以及整数R,但没有给出数组ID。在对数组A进行了WTF算法后,变量sum的可能出现的最大值数多少?并输出此时的ID数组。对于100%的数据,2<=N<=3000, 1<=R
解法:可以发现一个ID只和上一个ID与下一个ID有关,我们可以考虑dp。设f[i][j]表示第i个ID,填j最大的sum,g[i][j]表示前驱是f[i-1][g[i][j]]。
f [ i ] [ j ] = m a x f [ i − 1 ] [ k ] + a [ m i n ( j , k ) ] − a [ m a x ( j , k ) + 1 ]
这里的a[x]表示已经位移好了的,显然x=((x-(i-1)*r)%n+n)%n,这里a存在a[0..n-1].
当k<=j时,显然
f [ i ] [ j ] = m a x f [ i − 1 ] [ k ] + a [ k ] − a [ j + 1 ]
而f[i-1][k]+a[k]与j无关,可以预先处理最大值。当k>=j时类似。
贴上代码
#include<set>
#include<cmath>
#include<vector>
#include<cstdio>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
#define mo(x) ((x-(i-1)*r)%n+n)%n
#define fo(i,j,k) for(int i=j;i<=k;i++)
#define fd(i,j,k) for(int i=j;i>=k;i--)
using namespace std ;
int const oo=2147483647 ;
int const maxn=4000 ;
typedef struct {int x,y;}note;
inline int get(){
char ch=getchar();while ((ch!='-' )&&((ch<'0' )||(ch>'9' )))ch=getchar();
int u=1 ,v=0 ;if (ch=='-' )u=-1 ;else v=ch-'0' ;ch=getchar();
while ((ch>='0' )&&(ch<='9' ))v=v*10 +ch-'0' ,ch=getchar();
return u*v;
}
inline char getch(){
char ch=getchar();
while ((ch<'A' )||(ch>'Z' ))ch=getchar();
return ch;
}
int n,r,a[maxn+10 ],f[maxn+10 ][maxn+10 ],g[maxn+10 ][maxn+10 ];
inline void scan(){
n=get(),r=get();
fo(i,0 ,n-1 )
a[i]=get();
}
void print(int i,int j){
if (i!=0 )print(i-1 ,g[i][j]);
printf ("%d " ,j+1 );
}
inline void solve(){
memset (f,200 ,sizeof (f));
memset (f[0 ],0 ,sizeof (f[0 ]));
fo(i,1 ,n){
int tmp=-oo,pos;
fo(j,0 ,n-2 ){
if (f[i-1 ][j]+a[mo(j)]>tmp){
tmp=f[i-1 ][j]+a[mo(j)];
pos=j;
}
if (tmp-a[mo(j+1 )]>f[i][j]){
f[i][j]=tmp-a[mo(j+1 )];
g[i][j]=pos;
}
}
tmp=-oo;
fd(j,n-2 ,0 ){
if (f[i-1 ][j]-a[mo(j+1 )]>tmp){
tmp=f[i-1 ][j]-a[mo(j+1 )];
pos=j;
}
if (tmp+a[mo(j)]>f[i][j]){
f[i][j]=tmp+a[mo(j)];
g[i][j]=pos;
}
}
}
int tmp=-oo,pos;
fo(j,0 ,n-2 )
if (f[n][j]>tmp){
tmp=f[n][j];
pos=j;
}
printf ("%d\n" ,tmp);
print(n,pos);
}
int main(){
scan();
solve();
return 0 ;
}