import matplotlib.pyplot as plt
import numpy as np
class LinearModel:
def __int__(self, w, b):
self.w = w
self.b = b
@staticmethod
def forward(w, x):
return w * x
@staticmethod
def forward_with_intercept(w, x, b):
return w * x + b
@staticmethod
def get_loss(w, x, y_origin, exp=2, b=None):
if b:
y = LinearModel.forward(w, x, b)
else:
y = LinearModel.forward(w, x)
return pow(y_origin - y, exp)
if __name__ == '__main__':
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
weight_data = []
MSE_data = []
# 设定实验的权重范围
for w in np.arange(0.0, 4.1, 0.1):
weight_data.append(w)
loss_total = 0
MSE = 0
# 计算每个权重在数据集上的MSE平均平方方差
for x_val, y_val in zip(x_data, y_data):
loss_total += Linear
pytorch实践线性模型绘图2D
于 2023-04-05 19:55:50 首次发布
该代码示例定义了一个线性模型类,包括正向传播函数和计算损失的静态方法。它遍历不同权重值,计算对应的数据集上的均方误差(MSE),并通过matplotlib可视化MSE与权重的关系。
摘要由CSDN通过智能技术生成