在本次学习中,利用PyTorh实现了线性模型,总结如下:
经过之前的学习,我们知道无论是使用pytorh实现线性模型或者是手动实现线性模型,实际上步骤都分为以下几步:
代码展示(附有释义):
①准备数据集
import matplotlib.pyplot as plt
import torch
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
#准备数据集
x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])
print(x_data.shape)
②设计模型
#构造模型
class LinearModel(torch.nn.Module):#将模型定义成一个类,模型类继承至Module
def __init__(self):#构造函数,初始化对象调用
super(LinearModel,self).__init__()#调用父类的构造
#!!注意这里的(1,1)是指输入和输出的特征维度,并不代表权重和偏置
self.linear = torch.nn.Linear(1,1)#