算法导论(2)分治策略

算法导论(2)分治策略

分治策略递归地求解一个问题时,每层递归中实行三个步骤:

  • 分解:将问题划分为一些形式与原问题一致但规模更小的子问题
  • 解决:递归地求解出子问题。如果子问题的规模足够下则通知递归直接求解
  • 合并:将子问题的解合成原问题的解

1.最大子数组问题

假如你获得投资某个公司的机会,这个公司的股票价格是不稳定的,你只能在某个时刻买进股票然后在之后的某个时刻卖出。假如你知道这家公司将来的股票价格。你需要根据这家 公司的股票走势确定什么时候买入股票什么时候卖出股票以实现利益的最大化。毫无疑问最佳策略是在股票价格最低时期买入在价格最高时期卖出,但假如价格最高的时期在价格最低时期之前,这时就没那么容易得出结果了。

这时我们可以暴力求解这个问题,n天中一共有 C n 2 C^2_n Cn2种组合,运行时间为Θ( n 2 n^2 n2)。那我们需要更好的方法的话就要用到分治策略。

首先我们将这个问题转换一下,我们不看每天的价格而是看每天的价格变化,这样问题就变成了寻找总和最大的连续非空子数组。我们称为最大子数组。

分治策略求解

首先理所当然的将原数组划分为两个规模相当的子数组。对于数组A[low…high],我们找到数组中间位置mid,得到两个子数组A[low…mid]和A[mid+1…high]。数组A[low…high]中的任何连续子数组A[i,j]处的位置必然为以下三种情况之一:

  • 完全处于子数组A[low…mid]中
  • 完全处于子数组A[mid+1…high]中
  • 跨越了中点,low≤i≤mid≤j≤high

我们将子数组越分越小,但递归过程中一个重要的一步是求出形如A[i…mid]和A[mid+1,j]中的最大子数组,然后将其合并即可。过程find_max_c_s接受数组A和下标low,mid,high为输入,返回一个下标元组划定跨越中点的最大子数组的边界,并返回最大数组中值的和。

find_max_c_s(A,low,mid,high)

left-sum=-INFTY
sum=0
for i=mid downto low
	sum=sum+A[i]
	if sum>left-sum
		left-sum=sum
		max-left=i
right-sum=-INFTY
sum=0
for j=mid+1 to high
	sum=sum+A[j]
	if sum>right-sum
		right-sum=sum
		max-right=j
return(max-left,max-right,left-sum+right-sum)

由此我们便有了求解最大子数组问题的分治算法的伪代码:

find_max_s(A,low,high)

if high==low
	return (low,high,A[low])
else mid=[(low+high)//2]
	(left-low,left-high,left-sum)=find_max_s(A,low,mid)
    (right-low,right-high,right-sum)=find_max_s(A,mid+1,high)
    (cross-low,cross-high,cross-sum)=find_max_c_s(A,low,mid,high)
    if left-sum>=right-sum and left-sum>=cross-sum
    	return (left-low,left-high,left-sum)
    elseif right-sum>=left-sum and right-sum>=cross-sum
    	return (right-low,right-high,right-sum)
    else return(cross-low,cross-high,cross-sum)

C++版本,写代码的时候由于一个非常低级的小错误结果导致我找了好久的bug,所以做任何事情都要细心才行。

#include<iostream>
#include<vector>
#include<climits>

void find_max_c_s(int a[], int low, int mid, int high, 
	int& ml, int& mr, int& ms) {
	int ls = -INT_MAX;
	int sum = 0;
	for (int i = mid; i >= low; --i) {
		sum += a[i];
		if (sum > ls) {
			ls = sum;
			ml = i;
		}
	}
	int rs = -INT_MAX;
	sum = 0;
	for (int j = mid + 1; j <= high; ++j) {
		sum += a[j];
		if (sum > rs) {
			rs = sum;
			mr = j;
		}
	}
	ms =ls+rs;
}
void find_max_s(int a[], int low, int high,
	int &low_v,int &high_v,int &sum_v) {
	if (high == low) {
		low_v = high_v = high;
		sum_v = a[high];
		return;
	}
	else {
		int mid = (low + high) / 2;
		int ll, lh, ls,rl,rh,rs,cl,ch,cs;
		find_max_s(a, low, mid, ll, lh, ls);
		find_max_s(a, mid + 1, high, rl, rh, rs);
		find_max_c_s(a, low, mid, high, cl, ch, cs);
		if (ls >= rs && ls >= cs) {
			low_v = ll;
			high_v = lh;
			sum_v = ls;
			return;
		}
		else if (rs >= ls && rs >= cs) {
			low_v = rl;
			high_v = rh;
			sum_v = rs;
			return;
		}
		else {
			low_v = cl;
			high_v = ch;
			sum_v = cs;
			return;
		}
	}
}
int main() {
	int a[]= { 12,-3,42,23,-22,-33,23,12,0,-3,-5,-66,-22,1,32,2,-1,2,-3,44,25,-22,-44,2,22,3,-2};
	//int a[] = {1,2,3,4 };
	int low_v, high_v, sum_v;
	find_max_s(a, 0, std::size(a) - 1, low_v, high_v, sum_v);
	std::cout << "(low,high,sum)=(" << low_v << ","
		<< high_v << "," << sum_v << ")" << std::endl;
	//暴力解法,验证结果是否正确
	int l, h, max=-INT_MAX;
	for (int i = 0; i < std::size(a); ++i) {
		int sum = a[i];
		for (int j = i + 1; j < std::size(a); ++j) {
			sum += a[j];
			if (sum > max) {
				max = sum;
				l = i;
				h = j;
			}
		}
	}
	std::cout << l << "---" << h << "--" << max << std::endl;
	return 0;
}

Java版本的和C++版本的很接近,但其中又有一个值得注意的点,Java在定义函数时没法像C++那样,Java没有指针是一个问题,引用也应该没有,没法像C++那样通过引用类型的形参来充当返回值,所以此时需要稍微变通一下,改成返回一个数组,这样就没任何问题,而且和伪代码的意思也更接近。(输入数组是和C++版本的一样的,所以结果和C++版本的结果对比就知道程序的结果是不是正确的,就不需要后面再用暴力求解一编)

public class algo_sort_1 {
    public int[] find_max_c_s(int[] a, int low, int mid, int high){
        int ls = -Integer.MAX_VALUE;
        int[] test=new int [3];
        int sum = 0;
        for (int i = mid; i >= low; --i) {
            sum += a[i];
            if (sum > ls) {
                ls = sum;
                test[0] = i;
            }
        }
        int rs = -Integer.MAX_VALUE;
        sum = 0;
        for (int j = mid + 1; j <= high; ++j) {
            sum += a[j];
            if (sum > rs) {
                rs = sum;
                test[1] = j;
            }
        }
        test[2] =ls+rs;
        return test;
    }
    public int[] find_max_s(int[] a,int low,int high){
        int[] test=new int[3];
        if (high == low) {
            test[0] = high;
            test[1]=high;
            test[2] = a[high];
            return test;
        }
        else {
            int mid =(int) (low + high) / 2;
            int[] lt=find_max_s(a, low, mid);
            int[] rt=find_max_s(a, mid + 1, high);
            int[]ct= find_max_c_s (a, low, mid, high);
            if (lt[2] >= rt[2] && lt[2] >= ct[2]) {
                test=lt.clone();
            }
            else if (rt[2] >= lt[2] && rt[2] >= ct[2]) {
                test=rt.clone();
            }
            else {
                test=ct.clone();
            }
        }
        return test;
    }
    public static void main(String[] args) {
        int []a={ 12,-3,42,23,-22,-33,23,12,0,-3,-5,-66,-22,1,32,2,-1,2,-3,44,25,-22,-44,2,22,3,-2};
        int l=0,h=0,s=0;
        int[] test={0,0,0};
        int[] tt=new algo_sort_1().find_max_s(a,0,a.length-1);
        System.out.println(tt[0]+"--"+tt[1]+"---"+tt[2]);
    }
}

python版本的一些细节倒是和Java类似,python也没有引用和指针,所以输出值以return数组(列表)的形式才能正确的得到结果,而不能像C++那样把返回值放到形参中,但这样倒是差别不大,而且这样可读性更高点。

import sys
def find_mcs(a,low,mid,high):
    test=[0,0,0]
    ls=-sys.maxsize
    sum=0
    for i in range(mid,low-1,-1):
        sum+=a[i]
        if sum>ls:
            ls=sum
            test[0]=i
    rs=-sys.maxsize
    sum=0
    for j in range(mid+1,high+1):
        sum+=a[j]
        if sum>rs:
            rs=sum
            test[1]=j
    test[2]=ls+rs
    return test
def find_max_s(a,low,high):
    if high==low:
        test=[high,high,a[high]]
    else:
        mid=(low+high)//2
        lt=find_max_s(a,low,mid)
        rt=find_max_s(a,mid+1,high)
        ct=find_mcs(a,low,mid,high)
        if lt[2]>=rt[2] and lt[2]>=ct[2]:
            test=lt.copy()
        elif rt[2]>=lt[2] and rt[2]>=ct[2]:
            test=rt.copy()
        else:
            test=ct.copy()
    return test
aa=[  12,-3,42,23,-22,-33,23,12,0,-3,-5,-66,-22,1,32,2,-1,2,-3,44,25,-22,-44,2,22,3,-2]
a=[1,2,3,4]
testt=find_max_s(aa,0,len(aa)-1)
print(testt)

然后这个算法的时间花费为Θ(n lg n),n足够大时是比暴力解法的Θ( n 2 n^2 n2)要更有优势。

2.矩阵乘法的Strassen算法

学过线性代数都知道矩阵相乘的规则为:对于C=A*B
c i j = ∑ i = 1 n a i j ∗ b k j c_{ij}=\sum_{i=1}^na_{ij}*b_{kj} cij=i=1naijbkj
A和B为n×n的矩阵,我们可以通过下面过程来求两个矩阵的乘积:

SQUARE-MATRIX-MULTIPLY(A,B)

n=A.rows
let C be a new n*n matrix
for i= 1 to n
	for j=1 to n
    	c_ij=0
        for k=1:n
            c_ij=c_ij+a_ik*bk_kj
return C

显然上面过程的时间花费为Θ( n 3 n^3 n3)。但我们其实有办法再更短时间完成矩阵乘法:Strassen的著名n×n矩阵相乘的递归算法,其时间花费为Θ( n l g 7 n^{lg7} nlg7),lg7在2.80和2.81之间,即花费时间Θ( n 2.81 n^{2.81} n2.81)。

分治算法

我们首先假定n为2的幂,方便我们能将矩阵划分为更小的子矩阵。

我们的递归分治算法的伪代码如下:

SMMR(A,B)

n=A.rows
let C be a new n*n matrix
if n==1
	c11=a11*b11
else partition A,B,and C as in equation (4,9)
	C11=SMMR(A11,B11)+SMMR(A12,B21)
    C12=SMMR(A11,B12)+SMMR(A12,B22)
    C21=SMMR(A21,B11)+SMMR(A22,B21)
    C22=SMMR(A21,B12)+SMMR(A22,B22)
return C

这个过程还是很简单易懂的。我们可以按照这个算法来实际操作一下,Fortran是可以直接计算两个矩阵的乘积的,但C++不知道能不能,应该是可以的,就算标准库不行其他一些数值计算的库应该有,矩阵乘法本来是很基本的。但我们不关心编程语言能不能直接计算两个矩阵的乘法,说到底对于二维数组的乘法用三层循环来计算也不费事。

C++版本,本身这个程序并不复杂,原理非常简单,但由于C++语言很严谨,对于我们这种矩阵运算实际上实现起来有点繁琐的,但也仅此而已,小心一点其实也不容易出错

#include<iostream>
#include<vector>
using namespace std;

vector<vector<int>>add(vector<vector<int>>a, vector<vector<int>>b) {
	int n = a.size();
	vector<vector<int>>c;
	for (int i = 0; i < n; ++i) {
		vector<int>vc;
		for (int j = 0; j < n; ++j) {
			vc.push_back(a[i][j] + b[i][j]);
		}
		c.push_back(vc);
	}
	return c;
}

vector<vector<int>> smmr(vector < vector<int>>  a, vector<vector<int>>  b) {
	int n = a.size();
	vector<vector<int>>c;
	if (n == 1) {
		vector<int>vc;
		vc.push_back(a[0][0] * b[0][0]);
		c.push_back(vc);
	}
	else {
		//矩阵分块过程
		vector<vector<int>>c11, c12, c21, c22, a11, a12, a21, a22,
			b11, b12, b21, b22;
		for (int i = 0; i < n / 2; ++i) {
			vector<int>va11,va12,va21,va22, vb11,vb12,vb21,vb22;
			for (int j = 0; j < n / 2; ++j) {
				va11.push_back(a[i][j]);
				va12.push_back(a[i][j + n / 2]);
				va21.push_back(a[i + n / 2][j]);
				va22.push_back(a[i + n / 2][j + n / 2]);
				vb11.push_back(b[i][j]);
				vb12.push_back(b[i][j + n / 2]);
				vb21.push_back(b[i + n / 2][j]);
				vb22.push_back(b[i + n / 2][j + n / 2]);
			}
			a11.push_back(va11);
			a12.push_back(va12);
			a21.push_back(va21);
			a22.push_back(va22);
			b11.push_back(vb11);
			b12.push_back(vb12);
			b21.push_back(vb21);
			b22.push_back(vb22);
		}
		c11 = add(smmr(a11, b11), smmr(a12, b21));
		c12 = add(smmr(a11, b12), smmr(a12, b22));
		c21 = add(smmr(a21, b11), smmr(a22, b21));
		c22 = add(smmr(a21, b12), smmr(a22, b22));
		for (int i = 0; i < n; ++i) {
			vector<int>vc;
			for (int j = 0; j < n; ++j) {
				if (i < n / 2 && j < n / 2) {
					vc.push_back(c11[i][j]);
				}
				else if (i < n / 2) {
					vc.push_back(c12[i][j - n / 2]);
				}
				else if (j < n / 2) {
					vc.push_back(c21[i - n / 2][j]);
				}
				else {
					vc.push_back(c22[i - n / 2][j - n / 2]);
				}
			}
			c.push_back(vc);
		}
	}
	return c;
}


int main() {
	//试了一下发现用数组有诸多困难,感觉还是用vector好了
	vector<vector<int>>a, b, c,cc;
	const int n = 8;
	for (int i = 0; i < n; ++i) {
		vector<int>va, vb;
		for (int j = 0; j < n; ++j) {
			va.push_back(i == j ? 1 : 2);
			vb.push_back(i == j ? 2 : 1);
		}
		a.push_back(va);
		b.push_back(vb);
	}
	cc = smmr(a, b);
	for (int i = 0; i < n; ++i) {
		vector<int>vc;
		for (int j = 0; j < n; ++j) {
			int value = 0;
			for (int k = 0; k < n; ++k) {
				value += a[i][k] * b[k][j];
			}
			vc.push_back(value);
		}
		c.push_back(vc);
	}
	cout << "分治算法求解结果:" << endl;
	for (auto& t : cc) {
		for (auto tt : t) {
			cout << tt << " ";
		}
		cout << endl;
	}
	cout << "普通算法的结果" << endl;
	for (auto& t : c) {
		for (auto tt : t) {
			cout << tt << " ";
		}
		cout << endl;
	}
}

Java版本,与C++版本的相差无几,Java的数据和C++的vector很像,但Java的数据更方便省事很多,但总体上没太大差别

public class algo_matrix {
    public int[][] add(int[][] a,int[][]b){
        int n=a.length;
        int[][] c=new int[n][n];
        for(int i=0;i<n;++i){
            for(int j=0;j<n;++j){
                c[i][j]=a[i][j]+b[i][j];
            }
        }
        return c;
    }
    public int[][] smmr(int[][] a,int[][]b){
        int n=a.length;
        int[][] c=new int[n][n];
        if(n==1){
            c[0][0]=a[0][0]*b[0][0];
        }else{
            int[][] a11=new int[n/2][n/2],a12=new int[n/2][n/2],a21=new int[n/2][n/2],a22=new int[n/2][n/2]
                    ,b11=new int[n/2][n/2],b12=new int[n/2][n/2],b21=new int[n/2][n/2],b22=new int[n/2][n/2]
                    ,c11=new int[n/2][n/2],c12=new int[n/2][n/2],c21=new int[n/2][n/2],c22=new int[n/2][n/2];
            for(int i=0;i<n/2;++i){
                for(int j=0;j<n/2;++j){
                    a11[i][j]=a[i][j];
                    a12[i][j]=a[i][j+n/2];
                    a21[i][j]=a[i+n/2][j];
                    a22[i][j]=a[i+n/2][j+n/2];
                    b11[i][j]=b[i][j];
                    b12[i][j]=b[i][j+n/2];
                    b21[i][j]=b[i+n/2][j];
                    b22[i][j]=b[i+n/2][j+n/2];
                }
            }
            c11 = add(smmr(a11, b11), smmr(a12, b21));
            c12 = add(smmr(a11, b12), smmr(a12, b22));
            c21 = add(smmr(a21, b11), smmr(a22, b21));
            c22 = add(smmr(a21, b12), smmr(a22, b22));
            for(int i=0;i<n;++i){
                for(int j=0;j<n;++j){
                    //用三目运算符,判断语句的简写,一句语句就解决问题,C++也可以这么做的
                    c[i][j]=(i<n/2 &&j<n/2)?c11[i][j]:(i<n/2)?c12[i][j-n/2]:(j<n/2)?c21[i-n/2][j]:c22[i-n/2][j-n/2];
                }
            }
        }
        return c;
    }
    public static void main(String[] args) {
        int[][]a=new int[8][8],b=new int[8][8];
        for(int i=0;i<a.length;++i){
            for(int j=0;j<a.length;++j){
                a[i][j]=i==j?1:2;
                b[i][j]=i==j?2:1;
            }
        }
        int[][]c=new algo_matrix().smmr(a,b);
        for(int[] t:c){
            for(int tt:t){
                System.out.print(tt+" ");
            }
            System.out.println("-`");
        }
    }
}

python版本,充分利用列表推导式可以极大地简化代码编写,但需要注意的是其中存在一个坑,python的list的加法是定义了的,但它的加法并不是我们所期待的将两个list中对应的元素相加而是合并两个list,所以我们如果直接相加的会运行会出错,这个错误还不是容易找的。我们还是需要自己定义一个list的加法。

def add(a,b):
    c=[]
    for i in range(len(a)):
        c.append([a[i][x]+b[i][x] for x in range(len(a))])
    return c
def smmr(a,b):
    n=len(a)
    c=[]
    if n==1:
        c.append([a[0][0]*b[0][0]])
    else:
        a11,a12,a21,a22,b11,b12,b21,b22,c11,c12,c21,c22=[],[],[],[],[],[],[],[],[],[],[],[]
        for i in range(n//2):
            a11.append([x for x in a[i][:n//2]])
            a12.append([x for x in a[i][n//2:]])
            a21.append([x for x in a[i+n//2][:n//2]])
            a22.append([x for x in a[i+n//2][n//2:]])
            b11.append([x for x in b[i][:n // 2]])
            b12.append([x for x in b[i][n // 2:]])
            b21.append([x for x in b[i + n //2][:n // 2]])
            b22.append([x for x in b[i + n // 2][n // 2:]])
        c11 = add(smmr(a11, b11), smmr(a12, b21));
        c12 = add(smmr(a11, b12), smmr(a12, b22));
        c21 = add(smmr(a21, b11), smmr(a22, b21));
        c22 = add(smmr(a21, b12), smmr(a22, b22));
        c=[]
        for i in range(n):
            c.append([x for x in range(n)])
            for j in range(n):
                c[i][j]=c11[i][j] if i<n//2 and j<n//2 else c12[i][j-n//2] if \
                    i<n//2 else c21[i-n//2][j] if j<n//2 else c22[i-n//2][j-n//2]
    return c
a,b,c=[],[],[]
for i in range(8):
    a.append([1 if i==j else 2 for j in range(8)])
    b.append([2 if i==j else 1 for j in range(8)])

c=smmr(a,b)
print(c)

虽然是要简洁很多,但个人感觉还是Java更省事,因为针对这个算法我用Java实现时花的时间是最短的,python过于追求简洁反而有时容易出错。

分治算法的时间花费为Θ( n 3 n^3 n3),这样看来这个简单的分治算法并不具有优越性。

Strassen方法

而我们的Strassen方法则是要绕了一点,分为4步:

  1. 将输入矩阵A,B分解为n/2×n/2的子矩阵,同我们前面的SMMR算法
  2. 创建10个n/2×n/2的矩阵, S 1 , . . . , S 1 0 S_1,...,S_10 S1,...,S10,每个矩阵为第一步中的子矩阵的和或差
  3. 用前面创建的所有矩阵,递归的计算7个矩阵积 P 1 , . . . , P 7 P_1,...,P_7 P1,...,P7。每个矩阵都是n/2×n/2的
  4. 通过 P i P_i Pi计算出矩阵C的子矩阵 C 1 1 , C 1 2 , C 2 1 , C 2 2 C_11,C_12,C_21,C_22 C11,C12,C21,C22

其中步骤2中的矩阵:
S 1 = B 12 − B 22 S 2 = A 11 + A 12 S 3 = A 21 + A 22 S 4 = B 21 − B 11 S 5 = A 11 + A 22 S 6 = B 11 + B 22 S 7 = A 12 − A 22 S 8 = B 21 + B 22 S 9 = A 11 − A 21 S 10 = B 11 + B 12 S_1=B_{12}-B_{22}\\S_2=A_{11}+A_{12}\\S_3=A_{21}+A_{22}\\S_4=B_{21}-B_{11}\\S_5=A_{11}+A_{22}\\S_6=B_{11}+B_{22}\\S_7=A_{12}-A_{22}\\S_8=B_{21}+B_{22}\\S_9=A_{11}-A_{21}\\S_{10}=B_{11}+B_{12} S1=B12B22S2=A11+A12S3=A21+A22S4=B21B11S5=A11+A22S6=B11+B22S7=A12A22S8=B21+B22S9=A11A21S10=B11+B12
步骤3中,递归的计算7次矩阵乘法:
P 1 = A 11 ∗ S 1 P 2 = S 2 ∗ B 22 P 3 = S 3 ∗ B 11 P 4 = A 22 ∗ S 4 P 5 = S 5 ∗ S 6 P 6 = S 7 ∗ S 8 P 7 = S 9 ∗ S 10 P_1=A_{11}*S_1\\P_2=S_2*B_{22}\\P_3=S_3*B_{11}\\P_4=A_{22}*S_4\\P_5=S_5*S_6\\P_6=S_7*S_8\\P_7=S_9*S_{10} P1=A11S1P2=S2B22P3=S3B11P4=A22S4P5=S5S6P6=S7S8P7=S9S10
步骤4中,C的4个子矩阵:
C 11 = P 5 + P 4 − P 2 + p 6 C 12 = P 1 + P 2 C 21 = P 3 + P 4 C 22 = P 5 + P 1 − P 3 − P 7 C_{11}=P_5+P_4-P_2+p_6\\C_{12}=P_1+P_2\\C_{21}=P_3+P_4\\C_{22}=P_5+P_1-P_3-P_7 C11=P5+P4P2+p6C12=P1+P2C21=P3+P4C22=P5+P1P3P7
根据这4步,我们可以容易的在前面SMMR算法的基础上修改一下就能得到Strassen算法了。

虽然这个过程太绕了,但带来的收益是巨大的,如我们前面提到的,它的时间花费是Θ( n l g 7 n^{lg7} nlg7),对于一些大型矩阵,这个算法的优势还是可以的。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值