pytorch实践线性模型绘图2D

该代码示例定义了一个线性模型类,包括正向传播函数和计算损失的静态方法。它遍历不同权重值,计算对应的数据集上的均方误差(MSE),并通过matplotlib可视化MSE与权重的关系。
摘要由CSDN通过智能技术生成
import matplotlib.pyplot as plt
import numpy as np


class LinearModel:
    
    def __int__(self, w, b):
        self.w = w
        self.b = b

    @staticmethod
    def forward(w, x):
        return w * x

    @staticmethod
    def forward_with_intercept(w, x, b):
        return w * x + b

    @staticmethod
    def get_loss(w, x, y_origin, exp=2, b=None):
        if b:
            y = LinearModel.forward(w, x, b)
        else:
            y = LinearModel.forward(w, x)
        return pow(y_origin - y, exp)


if __name__ == '__main__':
    x_data = [1.0, 2.0, 3.0]
    y_data = [2.0, 4.0, 6.0]
    weight_data = []
    MSE_data = []

    # 设定实验的权重范围
    for w in np.arange(0.0, 4.1, 0.1):
        weight_data.append(w)
        loss_total = 0
        MSE = 0
        # 计算每个权重在数据集上的MSE平均平方方差
        for x_val, y_val in zip(x_data, y_data):
            loss_total += Linear
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值