梯度下降法(代码实现)

数据来源
标注:在代码中,Living area里的数值都缩小了1000倍,Price里的数值都缩小了10倍
在这里插入图片描述

# 梯度下降法
import decimal
decimal.getcontext().prec = 100
decimal.getcontext().rounding = getattr(decimal, 'ROUND_CEILING')  # 总是趋向无穷大向上取整


def gradientdescent(learning_rate, a0, a1, a2, x1, x2, yi):  # x1:房屋面积 单位:千feet**2; x2:房间数量; yi:房屋价格的真实值 单位:万$
    count, a = 0, []
    for i in range(len(yi)):
        while count < 10000:
            a0, a1, a2, x1[i], x2[i], yi[i], learning_rate = decimal.Decimal(a0), decimal.Decimal(a1), \
                                                             decimal.Decimal(a2), decimal.Decimal(x1[i]), \
                                                             decimal.Decimal(x2[i]), decimal.Decimal(yi[i]), \
                                                             decimal.Decimal(learning_rate)
            J = decimal.Decimal(1 / 2) * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) ** 2  # J:损失函数
            a0 = a0 - learning_rate * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i])          # 梯度下降计算a0
            a1 = a1 - learning_rate * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) * x1[i]  # 梯度下降计算a1
            a2 = a2 - learning_rate * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) * x2[i]  # 梯度下降计算a2
            if 0 <= J - decimal.Decimal(1 / 2) * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) ** 2 < 0.00001 and \
                    0 <= a0 + a1 * x1[i] + a2 * x2[i] - yi[i] <= 0.000001 and \
                    0 <= (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) * x1[i] <= 0.000001 and \
                    0 <= (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) * x2[i] <= 0.000001:
                a.append([a0, a1, a2])
                break
            count += 1
    print(f'第一组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[0]},'
          f'第二组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[1]},'
          f'第三组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[2]},'
          f'第四组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[3]},'
          f'第五组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[4]}.')
    return a


def verify(A, x1, x2):
    Y_list = []  # 用梯度下降法得到的线性回归模型算出来的预测值,用来与真实值比较
    for i in range(len(A)):
        Y = A[i][0] + A[i][1] * decimal.Decimal(x1[i]) + A[i][2] * decimal.Decimal(x2[i])
        Y_list.append(Y)
    print(f'第一个预测值:{Y_list[0]}, 第二个预测值:{Y_list[1]}, 第三个预测值:{Y_list[2]}, 第四个预测值:{Y_list[3]}, 第五个预测值:{Y_list[4]}')


grad = gradientdescent(0.1, 0, 1, 1, [2.104, 1.600, 2.400, 1.416, 3.000], [3, 3, 3, 2, 4], [40.0, 33.0, 36.9, 23.2, 54.0])
verify(grad, [2.104, 1.600, 2.400, 1.416, 3.000], [3, 3, 3, 2, 4])

运行结果
第一组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘3.673876718510160019204096914391439045200052863069541805224781098532263420653296705003597778074836454’), Decimal(‘7.956852954170839274877895466755233374776639040226408202927314923289800060092752530733368662920025895’), Decimal(‘6.528301555304797803824627476965244426719948529777984312552428810084841757314553422970788265218600869’)]
第二组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘3.353446100142499098521371677672852926644124684676268631396267923277963110803997013853648209750483478’), Decimal(‘7.495432863721407526326806306254629650582593903127004441195724700753264064267482882924187465251742628’), Decimal(‘5.884620529127840587211435894661066871093601478411576870601004786757149856463281371312402722098662467’)]
第三组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘3.135494425141821066112276029498382388863205909527359943995108924971284162034302310654418104363015159’), Decimal(‘7.024657245719942996649057004873011295940022904084983333038654930770407022777703405028841245012233646’), Decimal(‘5.635109451587064370976472825653698091613144122668925474037909641308687565706506877418520949275162127’)]
第四组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘2.932664512997315719856367901580875346317498138265050518039295808522626058363133719158401876407202801’), Decimal(‘6.766170805682985398594123474790777912286411864689085906541723388972190721483751744936315466728731920’), Decimal(‘5.343218969545421141943483626620091223835148329903316963997779537164939132378513760247860835592288697’)]
第五组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘3.824120604491861713890478335961970487109515983399412363241708292688614716083923211794014256095138669’), Decimal(‘9.173102252718259567640495474325807705895776858862129039452930123936541738674159393254422123218779881’), Decimal(‘5.664143162483457537482490554983683650726319461018005602613825363893848802096210244702860349958431301’)].
第一个预测值:40.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000,
第二个预测值:33.00000026948027356800674103850984781898315054905535011149023455240818981736569586449650965059031408,
第三个预测值:36.90000016963087674708453417429737639448552730722127339012969189134115341549318534685331979245893810,
第四个预测值:23.20000031293526482334899493166328343044307765994809569429857142169715129090130832355990228704099924,
第五个预测值:54.00000001258047056674192697887412820770212440405782189205580012007363514049124237036872202558520353

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值