PyTorch 线性模型
from matplotlib import pyplot as plt
import torch
from torch import nn
# Synthetic training data for a one-feature linear fit.
# Inputs: 100 evenly spaced points on [-1, 1]; the trailing axis added by
# unsqueeze gives X the shape (100, 1), i.e. one row per sample.
X = torch.linspace(-1, 1, 100).unsqueeze(1)
# Targets: the line y = 4x + 5 plus noise drawn uniformly from [0, 1).
# NOTE(review): torch.rand noise is not zero-mean (its mean is 0.5), so a
# fitted intercept should land near 5.5 rather than 5 -- confirm intended.
Y = (4 * X + 5) + torch.rand(X.size())
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression, s
原创 · 2021-06-16 20:53:26 · 92 阅读 · 0 评论