import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Load the dataset and take a first look at it.
data = pd.read_csv('csv_data/test.csv')
# Show the first 3 rows. print() is required: a bare expression (notebook
# style) is silently discarded when this file is run as a plain script.
print(data.head(3))
data.info()  # prints column dtypes and non-null counts
# Scatter plot of debt against num to inspect their relationship.
plt.scatter(data.num, data.debt)
plt.xlabel('num')
plt.ylabel('debt')
from torch import nn
import torch
# Data preprocessing: convert each pandas column into a float32 tensor of
# shape (n, 1), i.e. one row per sample and one feature column.
def _column_to_tensor(column):
    """Return the pandas Series *column* as a float32 tensor of shape (n, 1)."""
    as_matrix = column.values.reshape(-1, 1)
    return torch.from_numpy(as_matrix.astype(np.float32))


x = _column_to_tensor(data.num)
y = _column_to_tensor(data.debt)

# Single-feature linear model: debt ~ weight * num + bias.
model = nn.Linear(in_features=1, out_features=1)
loss_fn = nn.MSELoss()  # mean squared error loss
opt = torch.optim.SGD(params=model.parameters(), lr=0.001)  # plain SGD optimizer
# Train with per-sample (stochastic) gradient descent for 100 epochs.
for epoch in range(100):
    for a, b in zip(x, y):
        # a and b are single rows of x and y (each of shape (1,)).
        y_pred = model(a)  # prediction for one sample
        loss = loss_fn(y_pred, b)  # MSELoss signature is (input, target)
        opt.zero_grad()  # clear gradients accumulated by the previous step
        loss.backward()  # compute gradients
        opt.step()  # update the parameters
    # Report the loss over the whole dataset once per epoch; no_grad skips
    # building an autograd graph for this evaluation-only forward pass.
    with torch.no_grad():
        print(epoch, ' loss: ', loss_fn(model(x), y).item())
# Overlay the fitted line (red) on the raw data points.
plt.scatter(data.num, data.debt)
# detach() drops the autograd graph so the predictions can become a NumPy
# array (preferred over the legacy .data attribute).
plt.plot(x.numpy(), model(x).detach().numpy(), c='r')
plt.show()  # needed to actually display the figure when run as a script
# plt、nn 使用 (usage notes for plt and nn)
# 最新推荐文章于 2021-10-08 09:19:53 发布 (scraped page footer: "latest recommended article published 2021-10-08 09:19:53")