这题我只想到了线段树和四边形优化,其实不加四边形优化也能过,不过湖大OJ过不了
其实这题有一个比线段树更好的方法去算区间[i,j]的花费:
设区间[i,j]的总花费为w[i,j],区间[i,j]对j造成的花费为ss[i,j],i对j的花费为c[i,j]
则
w[i,j]= w[i,j-1]+ s[i,j]
s[i,j]= s[i+1, j]+ c[i,j]
线段树+四边形优化代码O(n^2logn+nk):
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxn 1010
struct node
{
int l, r, sum;
}a[4*maxn];
int A[maxn], t[maxn][maxn];
int w[maxn][maxn], dp[maxn][maxn], xx[maxn][maxn];
void Build(int l, int r, int n)
{
a[n].l= l;
a[n].r= r;
a[n].sum= 0;
if(a[n].l== a[n].r)
return;
int mid= (a[n].l+ a[n].r)/ 2;
Build(l, mid, 2*n);
Build(mid+1, r, 2*n+1);
}
void add(int n, int v)
{
if(a[n].l== v && a[n].r== v)
{
a[n].sum++;
return;
}
int mid= (a[n].l + a[n].r)/ 2;
if(v<= mid)
add(2*n, v);
else
add(2*n+1, v);
a[n].sum= a[2*n].sum+ a[2*n+1].sum;
}
int query(int l, int r, int n)
{
if(a[n].l== l && a[n].r== r)
return a[n].sum;
int mid= (a[n].l + a[n].r)/ 2;
if(r<= mid)
return query(l, r, 2*n);
else if(l> mid)
return query(l, r, 2*n+1);
else
return query(l, mid, 2*n)+ query(mid+1, r, 2*n+1);
}
int main()
{
int n, s, k;
while(scanf("%d %d %d",&n,&s,&k)!=EOF)
{
Build(1, s, 1);
memset(w, 0, sizeof w);
memset(xx, 0, sizeof xx);
for(int i= 1; i<= n; i++)
{
scanf("%d",&A[i]);
add(1, A[i]);
for(int j= 1; j< A[i]; j++)
xx[j][A[i]]+= query(j, A[i]-1, 1);
}
for(int i= 1; i<= s; i++)
for(int j= i+1; j<= s; j++)
w[i][j]= w[i][j-1]+ xx[i][j];
memset(dp, -1, sizeof dp);
for(int i= 1; i<= s; i++)
{
dp[i][1]= w[1][i];
t[i][1]= 1;
t[s+1][i]= s;
}
for(int j= 2; j<= k; j++)
for(int i= s; i>= j; i--)
{
int aa= t[i][j-1];
int bb= t[i+1][j];
for(int u= aa; u<= bb; u++)
{
int temp= dp[u][j-1]+ w[u+1][i];
if(dp[i][j]== -1 || temp< dp[i][j])
{
dp[i][j]= temp;
t[i][j]= u;
}
}
}
printf("%d\n",dp[s][k]);
}
return 0;
}
好的方法O(n^2+nk):
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxn 1010
int w[maxn][maxn], ss[maxn][maxn], dp[maxn][maxn];
int t[maxn][maxn], c[maxn][maxn];
int A[maxn];
int main()
{
int n, k, s;
while(scanf("%d %d %d",&n,&s,&k)!=EOF)
{
memset(c, 0, sizeof c);// 表示i对j的花费
memset(ss, 0, sizeof ss);//表示[i,j]排对j的花费
memset(w, 0, sizeof w);//表示区间[i,j]的总花费
for(int i= 1; i<= n; i++)
{
scanf("%d",&A[i]);
for(int j= 1; j<= i; j++)
if(A[j]< A[i])
c[A[j]][A[i]]++;
}
for(int j= 1; j<= s; j++)
for(int i= j; i> 0; i--)
ss[i][j]= ss[i+1][j]+ c[i][j];
for(int i= 1; i<= s; i++)
for(int j= i; j<= s; j++)
w[i][j]= w[i][j-1]+ ss[i][j];
memset(dp, -1, sizeof dp);
for(int i= 1; i<= s; i++)
{
dp[i][1]= w[1][i];
t[i][1]= 1;
t[s+1][i]= s;
}
for(int j= 2; j<= k; j++)
for(int i= s; i>= j; i--)
{
int aa= t[i][j-1];
int bb= t[i+1][j];
for(int u= aa; u<= bb; u++)
{
int temp= dp[u][j-1]+ w[u+1][i];
if(dp[i][j]== -1 || temp< dp[i][j])
{
dp[i][j]= temp;
t[i][j]= u;
}
}
}
printf("%d\n",dp[s][k]);
}
return 0;
}