题目描述:有n个车站,第i(1<= i <= n - 1)个车站可以买从i到i+1 , i + 2 , ...,a[i]的火车票,用p(i , j)表示从车站i到车站j最少买多少张车票,问
sum = Σp(i , j) (1 <= i < j <= n)是多少
思路:在车站i处,可以用一张车票到达[i + 1 , a[i]]中的一站,那么,应选择m车站再买一张车票,其中m∈[i + 1 , a[i] ]且a[m]最大,为什么是这样呢?如图:
设m'是不同于m的一个车站,a[m'] < a[m] , 可以看到,假设现在人在位置i,
线段( i , a[m'] ]部分通过两种换乘方式到达所需要的票数一样;
而线段( a[m'] , a[m] ]的部分, 从m换乘的话两张票就能到, 但从m'换乘的话两张票不一定能到;
再看线段( a[m] , n ]部分,假设这个区间中有一点x ,如果从m'换乘最终到此区间需要k张票,那么从m换乘需要的票数小于等于k张票,因为这两种换乘方式相比,大于a[i]的部分,蓝色的包在了红色的里面,所以蓝色能有的买票方式,红的一定能有,但红色能有的买票方式,蓝色的不一定能有;
如此一来,从m处换乘的原因得证。
设dp[i]表示Σp(i , j) (其中i + 1 <= j <= n) ,那么dp[i] = dp[m] + n - i -( a[i] - m ) , 为什么呢?
想象从i+1到n每个位置都对应了一张票:
那么在( i , m ]区间,每个位置p上的票用来从i走到p;
在( m , n ] 区间,每个位置p上的票用来从i走到m,再从m走到p;
但是可以看到,按照上述规则( m , a[i] ] 部分的买票的方式是先从i走到m,再从m走到每个位置, 而事实上,从i走到每个位置只要一张票,因此要减掉 a[i] - m
另外提一句,有在找m的时候,如果遇到多个a[m]相等,找任一个都可以,这是可以从图上看出来的,很多题解上说要找dp[m]最小的是错的,因为所有dp[m]一样大。
#pragma warning(disable:4786)
#pragma comment(linker, "/STACK:102400000,102400000")
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<stack>
#include<queue>
#include<map>
#include<set>
#include<vector>
#include<cmath>
#include<string>
#include<sstream>
#define LL long long
#define FOR(i,f_start,f_end) for(int i=f_start;i<=f_end;++i)
#define mem(a,x) memset(a,x,sizeof(a))
#define lson l,m,x<<1
#define rson m+1,r,x<<1|1
using namespace std;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const double PI = acos(-1.0);
const double eps=1e-6;
const int maxn = 1e5 + 5 ;
int a[maxn];
LL dp[maxn] ;
struct node
{
int v , p ;
}t[4 * maxn];
void pushup(int l , int r , int x)
{
if(t[x<<1].v < t[x<<1|1].v){
t[x].v = t[x<<1|1].v;
t[x].p = t[x<<1|1].p;
}
else if(t[x<<1].v > t[x<<1|1].v){
t[x].v = t[x<<1].v;
t[x].p = t[x<<1].p;
}
else{
t[x].v = t[x<<1].v;
t[x].p = t[x<<1|1].p ;
}
return ;
}
void build(int l , int r , int x)
{
if(l == r ){
t[x].v = a[l] ;
t[x].p = l ;
return ;
}
int m = l + (r - l) / 2 ;
build(lson) ;
build(rson);
pushup(l , r , x) ;
}
node query(int L , int R , int l , int r , int x )
{
if(L == l && R == r){
return t[x];
}
int m = l + ( r - l ) / 2 ;
if(R <= m)
return query(L , R , lson) ;
else if( L > m)
return query(L , R , rson) ;
else{
node ret1 = query(L , m , lson) ;
node ret2 = query( m + 1 , R , rson) ;
if(ret2.v > ret1.v)
return ret2;
else
return ret1 ;
}
}
int main()
{
int n ;
scanf("%d",&n);
for(int i = 1 ; i<= n - 1 ; i++){
scanf("%d",&a[i]);
}
build(1 , n , 1);
LL ans = 1;
dp[n - 1] = 1 ;
for(int i = n - 2 ; i>= 1 ; i--){
node st = query(i + 1 , a[i] , 1 , n , 1 ) ;
int m = st.p ;
dp[i] = dp[m] + n - i - (a[i] - m) ;
ans += dp[i] ;
}
printf("%lld\n",ans);
return 0;
}