Pytorch极简入门教程(一)—— 数据集的读取

# -*- coding: utf-8 -*-
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

data = pd.read_csv("dataset/Income1.csv")
data.info()

print("data:\t", data)

# 绘制散点图
plt.scatter(data.Education, data.Income)
# x轴名称
plt.xlabel("Education")
# y轴名称
plt.ylabel("Income")
# 查看图形
plt.show()

from torch import nn
#X = data.Education.values.reshape(-1, 1).shape
X = data.Education.values.reshape(-1, 1).astype(np.float32)
print("X:\t", X)
# 将numpy 转换成 Tensor
X = torch.from_numpy(X)
print("X.type:\t", X.type)

Y = torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float32))
print("Y.type:\t", Y.type)

model = nn.Linear(1, 1) # out = w@input + b 等价于 model(input)
# 损失函数
loss_fn = nn.MSELoss()
# 优化器 主要优化model 中的权重w和b
opt = torch.optim.SGD(model.parameters(), lr=0.001)

for epoch in range(5000):
    for x, y in zip(X, Y):
        # 使用模型预测
        y_pred = model(x)
        # 根据预测结果计算损失
        loss = loss_fn(y, y_pred)
        # 把变量的梯度清零
        opt.zero_grad()
        # 反向传播求解梯度
        loss.backward()
        # 优化模型参数
        opt.step()
print("model.bias:\t", model.bias)
print("model.weight:\t", model.weight)
print("model(X):\t", model(X))
plt.scatter(data.Education, data.Income)
# model(X).data 取出里面的值
plt.plot(X.numpy(), model(X).data.numpy(), c='r')
plt.show()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30 entries, 0 to 29
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   Unnamed: 0  30 non-null     int64  
 1   Education   30 non-null     float64
 2   Income      30 non-null     float64
dtypes: float64(2), int64(1)
memory usage: 848.0 bytes
data:	     Unnamed: 0  Education     Income
0            1  10.000000  26.658839
1            2  10.401338  27.306435
2            3  10.842809  22.132410
3            4  11.244147  21.169841
4            5  11.645485  15.192634
5            6  12.086957  26.398951
6            7  12.488294  17.435307
7            8  12.889632  25.507885
8            9  13.290970  36.884595
9           10  13.732441  39.666109
10          11  14.133779  34.396281
11          12  14.535117  41.497994
12          13  14.976589  44.981575
13          14  15.377926  47.039595
14          15  15.779264  48.252578
15          16  16.220736  57.034251
16          17  16.622074  51.490919
17          18  17.023411  61.336621
18          19  17.464883  57.581988
19          20  17.866221  68.553714
20          21  18.267559  64.310925
21          22  18.709030  68.959009
22          23  19.110368  74.614639
23          24  19.511706  71.867195
24          25  19.913043  76.098135
25          26  20.354515  75.775218
26          27  20.755853  72.486055
27          28  21.157191  77.355021
28          29  21.598662  72.118790
29          30  22.000000  80.260571
X:	 [[10.      ]
 [10.401338]
 [10.84281 ]
 [11.244147]
 [11.645485]
 [12.086957]
 [12.488295]
 [12.889632]
 [13.29097 ]
 [13.732442]
 [14.13378 ]
 [14.535117]
 [14.976588]
 [15.377927]
 [15.779264]
 [16.220736]
 [16.622074]
 [17.02341 ]
 [17.464884]
 [17.86622 ]
 [18.26756 ]
 [18.70903 ]
 [19.110369]
 [19.511705]
 [19.913044]
 [20.354515]
 [20.755854]
 [21.15719 ]
 [21.598661]
 [22.      ]]
X.type:	 <built-in method type of Tensor object at 0x0000022A601ABC28>
Y.type:	 <built-in method type of Tensor object at 0x0000022A65680DB8>
model.bias:	 Parameter containing:
tensor([-32.6934], requires_grad=True)
model.weight:	 Parameter containing:
tensor([[5.1265]], requires_grad=True)
model(X):	 tensor([[18.5715],
        [20.6289],
        [22.8921],
        [24.9496],
        [27.0070],
        [29.2702],
        [31.3277],
        [33.3851],
        [35.4426],
        [37.7058],
        [39.7632],
        [41.8207],
        [44.0839],
        [46.1413],
        [48.1988],
        [50.4620],
        [52.5194],
        [54.5769],
        [56.8401],
        [58.8975],
        [60.9550],
        [63.2182],
        [65.2756],
        [67.3331],
        [69.3905],
        [71.6537],
        [73.7112],
        [75.7686],
        [78.0318],
        [80.0893]], grad_fn=<AddmmBackward>)

在这里插入图片描述
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值