第六章 线性回归的拓展 - 非线性回归


在这里插入图片描述

0 导入库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import optimize

1 定义方程形式

# 线性形式
def f_1(x, A, B):
    return A + B*x

# 复合函数形式
def f_2(x, A, B):
    return A * B**x 

2 数据

x:病人住院的天数
y:是病人出院后长期恢复的预后指数(指数的数值越大表示预后的结局越好)

x = np.array([2, 5, 7, 10, 14, 19, 26, 31, 34, 38, 45, 52, 53, 60, 65]) # 住院天数
y = np.array([54, 60, 45, 37, 35, 25, 20, 15, 18, 13, 8, 11, 8, 4, 6]) # 预后指数

3 拟合方程

A1, B1 = optimize.curve_fit(f_1, x, y)[0]
A2, B2 = optimize.curve_fit(f_2, x, y)[0]

y1 = A1 + B1*x
y2 = A2 * B2**x

4 画图

from matplotlib.font_manager import FontProperties
font_set = FontProperties(fname=r"/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", size=16) 

def runplt():
    plt.figure() # 定义figure
    plt.title(u'对重伤病人出院后的长期恢复情况进行预测',fontproperties=font_set, size=20)
    plt.xlabel(u'住院天数(day)',fontproperties=font_set)
    plt.ylabel(u'预后指数',fontproperties=font_set)
    plt.axis([0, 70, 0, 70]) # 坐标轴
    plt.grid(True) # 网格线
    return plt

plt = runplt()
plt.scatter(x, y, c='b', label='Data')

在这里插入图片描述
标准相关系数
非线性趋势

plt = runplt()
plt.scatter(x, y, c='b', label='Data')
plt.plot(x, y1, 'g-', label='Linear Regression')
plt.plot(x, y2, 'r:', label='Compound Regression')
plt.legend()

在这里插入图片描述
复合函数回归比线性回归拟合效果更好

5 计算残差

residual1 = y - y1 # 线性回归残差
residual2 = y - y2 # 复合函数回归残差

residual1

在这里插入图片描述

residual2

在这里插入图片描述

6 计算总平方和

# 线性回归残差平方和
sse1 = 0
for i in residual1:
    sse1 += i**2
# 复合函数回归残差平方和
sse2 = 0
for i in residual2:
    sse2 += i**2

sse1

在这里插入图片描述

sse2

在这里插入图片描述

在这里插入图片描述

常用的可转化成线性的非线性回归模型
在这里插入图片描述

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值