[TOC] #微分方程概述 微分方程在各个领域应用颇多。 形如
y′=f(x,y)
的微分方程表示了系统的变化信息,如果在加上初始条件
(x0,y0)
,那么就可以求出系统整体随时间变化的信息。 可以说,**正是微分方程将物理世界模型化**。 #方向场与积分曲线 方向场(`direction field`)与积分曲线(`integral curve`)的关系,可以用下面的式子简要表示:
y′和(x0,y0) ,根据下面方法迭代:
由上图可见,欧拉法存在一定的误差,并且误差会累计。当步长越小误差也就越小拟合效果越好。 这种情况下,误差和步长的关系是:
{y′=f(x,y)y(x)⟺DirectionField⟺IntegralCurve
其中,当
f(x,y),f′(x,y)
在邻域内连续时,积分曲线不会相交也不会相切,解存在且唯一(exist and unique)。 下面,举函数
y′=−x/y
的方向场与积分曲线:
%下面的函数dirfield需要用到参考资料中的函数
%定义函数y'
f = @(x,y) -x/y
%方向场的一些简单可视化
ezplot(f,[-2,2,-2,2])
ezsurf(f,[-2,2,-2,2])
ezcontour(f,[-2,2,-2,2]); colorbar
%画出方向场与积分曲线
dirfield(f,-2:0.2:2,-2:0.2:2)
hold on
for y0=-0.2:0.5:2
[ts,ys] = ode45(f,[-2,2],y0); plot(ts,ys)
end
title('dy/dx=-x/y的方向场与积分曲线')
hold off
![这里写图片描述](http://7xlwwh.com1.z0.glb.clouddn.com/sXshot-0367.png) #微分方程的解析解法 微分方程的解析解法通常是将
x,y
分别移到等式的一边。 下面以
y′=2y+1
为例,移项后
dy2y+1=dx
,所以有
12ln(|2y+1|)+c1=x+c2
,进而有
|2y+1|=Ce2x
,最后解得:
y=Ce2x−12
其实,
ex
就是根据微分方程
y′=y
在
(0,1)
的初始条件下确定的。 使用matlab的解析解法为:
dsolve('Dy=2*y+1','x')
%输出为: (C2*exp(2*x))/2 - 1/2
%求解e^x
dsolve('Dy=y','y(0)=1','x')
%输出为: exp(x)
#微分方程的数值解法 ##欧拉法 欧拉法的核心是,设定步长为
h
,然后已知
⎧⎩⎨⎪⎪xn+1yn+1Slope=xn+h=yn+h∗Slope=y′n
ODE数值解法的matlab程序为:
[xs,ys] = ode45(f,[-2,2],y0)
##欧拉法的缺点
e∼c∗h
如果函数时而`convex`时而`concave`,这时候误差的变化便难以预测。
#---------------------------------------------------------
##凸函数
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
#定义产生下一个点的函数
def nextPoint(x,y,h):
xn = x + h
slope = 2*x
yn = y + h*slope
return (xn,yn)
#定义产生点的生成器
def pointGenerator(x,y,h):
while True:
yield nextPoint(x,y,h)
(x,y) = nextPoint(x,y,h)
#根据输入的起始终止点以及步长,输出可以用于画图的参数
def getXY(x,y,h):
x1,y1=[],[]
for i in pointGenerator(x,y,h):
xi = i[0]
yi = i[1]
if xi>2.5:
break
else:
x1.append(xi)
y1.append(yi)
x1.insert(0,-2)
y1.insert(0,4)
return (x1,y1)
#---------------------------------------------------------
#凹函数
#大部分和上面相同,只是将`nextPoint`函数重新定义
def nextPoint(x,y,h):
xn = x + h
slope = -2*x
yn = y + h*slope
return (xn,yn)
def pointGenerator(x,y,h):
while True:
yield nextPoint(x,y,h)
(x,y) = nextPoint(x,y,h)
def getXY(x,y,h):
x1,y1=[],[]
for i in pointGenerator(x,y,h):
xi = i[0]
yi = i[1]
if xi>2.5:
break
else:
x1.append(xi)
y1.append(yi)
x1.insert(0,-2)
y1.insert(0,-4)
return (x1,y1)
x = np.arange(-2,2.1,0.1)
y = -x**2
x1,y1 = getXY(-2,-4,0.1)
x2,y2 = getXY(-2,-4,0.4)
x3,y3 = getXY(-2,-4,0.6)
plt.plot(x,y,'b--', linewidth=1,label='raw line')
plt.plot(x1,y1,'r',label='h=0.1')
plt.plot(x2,y2,'g',label='h=0.4')
plt.plot(x3,y3,'c',label='h=0.6')
plt.autoscale()
plt.xlim(-2.5,2.5)
plt.legend(loc='best')
plt.title('concave function with different h')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
##改进欧拉法之步长 步长的改进参考上文,步长越小误差越小。 ##改进欧拉法之斜率 核心是:计算斜率不只考虑当前的点,也考虑之后的点的斜率。 该方法一般被称作`runge-kutta`法,上文只用到一个斜率的被称为`RK1`,下面将要阐述的是`RK2`,同时在绝大多数数值计算工具中,`RK4`的使用最为广泛。
⎧⎩⎨⎪⎪xn+1yn+1Slope=xn+h=yn+h∗Slope=(y′n+y′n+1)/2
由上图可看,RK2的效果已经比RK1好太多的。
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
#RK1
def nextPoint(x,y,h):
xn = x + h
slope = 2*x
yn = y + h*slope
return (xn,yn)
def pointGenerator(x,y,h):
while True:
yield nextPoint(x,y,h)
(x,y) = nextPoint(x,y,h)
def getXY(x,y,h):
x1,y1=[],[]
for i in pointGenerator(x,y,h):
xi = i[0]
yi = i[1]
if xi>2.5:
break
else:
x1.append(xi)
y1.append(yi)
x1.insert(0,-2)
y1.insert(0,4)
return (x1,y1)
#RK2
def nextPoint2(x,y,h):
xn = x + h
slope = (2*x + 2*xn)/2
yn = y + h*slope
return (xn,yn)
def pointGenerator2(x,y,h):
while True:
yield nextPoint2(x,y,h)
(x,y) = nextPoint2(x,y,h)
def getXY2(x,y,h):
x1,y1=[],[]
for i in pointGenerator2(x,y,h):
xi = i[0]
yi = i[1]
if xi>2.5:
break
else:
x1.append(xi)
y1.append(yi)
x1.insert(0,-2)
y1.insert(0,4)
return (x1,y1)
#DATA
x = np.arange(-2,2.1,0.1)
y = x**2
x1,y1 = getXY(-2,4,1)
x2,y2 = getXY2(-2,4,1)
x3,y3 = getXY(-2,4,0.5)
x4,y4 = getXY2(-2,4,0.5)
#PLOT
plt.plot(x,y,'k', linewidth=1,label='raw line')
plt.plot(x4,y4,'r--',label='h=0.5 RK2')
plt.plot(x2,y2,'b--',label='h=1 RK2')
plt.plot(x3,y3,'c--',label='h=0.5 RK1')
plt.plot(x1,y1,'g--',label='h=1 RK1')
plt.autoscale()
plt.xlim(-2.5,2.5)
plt.legend(loc='best')
plt.title('convex function with different h and RKn')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()