问题描述
如果我们进行过投资行为,都知道一个简单的原则:低价买进,高价卖出。例如下图所示的曲线,股票价格是一直在波动的,我们想知道哪天买入,哪天卖出,收益最高。如下图所示,是一个很好的示例:
这个问题的实质是要求出一个最大子数组,使得首尾两个元素的差值最大。
思路
为简化问题,我们把原始数组[100, 113, 110, 85, 105, 102, 86, 63, 81, 101, 94, 106, 101, 79, 94, 90,97]
中的相邻两个值相减,得到一个差值数组A=[13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7]
,现在的问题转化成:求数组A
的一个子数组,这个子数组所有元素的和最大。
一个基本的思路是使用暴力算法,计算出A
中任意两个元素之间的和,再求出这些和中的最大值。这里介绍一种更优雅的算法,这个算法应用了分治策略的思想。
所谓分治策略,是指把一个原始问题分解成若干个足够简单的子问题,解决完这些子问题后,再通过合并的方法,逐层回溯解决原始问题。应用分治策略,往往会使用到递归的方法,递归方法往往需要进行如下几个步骤:
- 分解步骤将问题划分为一些子问题,子问题的形式与原问题一样,只是规模更小。
- 解决步骤递归地求解出子问题。如果子问题的规模足够小,则停止递归,直接求解。
- 合并步骤将子问题的解组合成原问题的解。
回到这个问题,我们可以这样来分析问题:
假设要寻找A[low..high]
的最大子数组(下文简称为maxSubArray),那么按照分治策略的思想,首先将其拆分成两个子数组。假设从中间位置mid
作为分割点,那么存在如下几种情况:
- maxSubArray完全位于子数组
A[low..mid]
中,因此low<=i<=j<=mid
- maxSubArray完全位于子数组
A[mid+1..high]
中,因此mid<i<=j<=high
- maxSubArray完全跨越了中点,因此
low<=i<=mid<j<=high
对于上述第1点和第2点,对应于问题的拆分,我们可以将问题拆分到足够简单的地步,例如,拆分到子数组只有一个元素,这时,该子数组本身就是一个maxSubArray。第3点则对应上述合并的步骤。我们要设计一个合并算法,以任意一个子数组的中点为界,找出包含该中点的最大值(详见下面代码中的findMaxCrossingSubArray
函数),再和左右两边数组的maxSubArray进行对比,三者中的最大值就是对应子数组的maxSubArray。
代码
最新代码请参考本人github
# find MaxSubArray of subArray where mid element sits in
def findMaxCrossingSubArray(A, low, mid, high):
maxLeft = 0
leftSum = float('-inf')
sum = 0
for i in range(mid, -1, -1):
sum = sum + A[i]
if sum > leftSum:
leftSum = sum
maxLeft = i
maxRight = 0
rightSum = float('-inf')
sum = 0
for i in range(mid+1, high + 1, 1):
sum = sum + A[i]
if sum > rightSum:
rightSum = sum
maxRight = i
return (maxLeft, maxRight, leftSum + rightSum)
# recursive method to find maxSubArray
def findMaxSubArray(A, low, high):
print("low: ", low, "high: ", high)
# recursive stop contidion
if low == high:
return(low, high, A[low])
else:
mid = (low + high)//2
# subArray sits in left subarray
# or right subarray
# or cross them
# find maxium of them as the final result
leftLow, leftHigh, leftSum = findMaxSubArray(A, low, mid)
righLow, righHigh, rightSum = findMaxSubArray(A, mid + 1, high)
crossLow, crossHigh, crossSum = findMaxCrossingSubArray(
A, low, mid, high)
if leftSum >= rightSum and leftSum >= crossSum:
return(leftLow, leftHigh, leftSum)
elif rightSum >= leftSum and rightSum >= crossSum:
return(righLow, righHigh, rightSum)
else:
return(crossLow, crossHigh, crossSum)
if __name__ == '__main__':
A = [13, -3, -25, 20, -3, -16, -23, 18, 20, -7, 12, -5, -22, 15, -4, 7]
maxSubArray = findMaxSubArray(A, 0, len(A) - 1)
print("maxSubArrayLeftIndex: ", maxSubArray[0])
print("maxSubArrayRightIndex: ", maxSubArray[1])
print("maxSubArraySum: ", maxSubArray[2])