刘二大人《PyTorch深度学习实践》多维特征输入

07.处理多维特征的输入_哔哩哔哩_bilibili

对于多维样本的分类问题,一个样本不再只有一个值(特征)

所以\hat{y}^{(i)} =\sigma(\sum_{n=1}^{N}x_n^{(i)}*w_n+b)

每一个特征值都有自己的权重w_n

\sum_{n=1}^{N}x_n^{(i)}*w_n=[x_1^{(i)}.......x_N^{(i)}]\begin{bmatrix}w_1 \\ . \\ . \\ . \\ w_N \end{bmatrix}

 

对于Mini_Batch情况

\begin{bmatrix}\hat{y}^{(1)} \\ . \\ . \\ . \\ \hat{y}^{(N)} \end{bmatrix}=\begin{bmatrix}\sigma({z}^{(1)}) \\ . \\ . \\ . \\ \sigma({z}^{(N)})\end{bmatrix}=\sigma(\begin{bmatrix}{z}^{(1)} \\ . \\ . \\ . \\ {z}^{(N)}\end{bmatrix})

z^{(1)}=[x_1^{(1)}...x_N^{(1)}]\begin{bmatrix}w_1 \\ . \\ . \\ . \\ w_N\end{bmatrix}+b\\...\\...\\...\\z^{(N)}=[x_1^{(N)}...x_N^{(N)}]\begin{bmatrix}w_1 \\ . \\ . \\ . \\ w_N\end{bmatrix}+b

即 

输入N维,输出1维(N维空间到1维空间的非线性变换)

也可以不直接从N维到1维,最后是一维即可,引入多次非线性变换,参数具有更好的拟合性,有利于提高网络的泛化能力

 

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np



xy=np.loadtxt('./diabetes.csv',delimiter=',',dtype=np.float32)
x_data = torch.from_numpy(xy[:,:-1])      #最后一列不要 ,向量形式
y_data = torch.from_numpy(xy[:,[-1]])       #只要最后一列,且为矩阵形式


class Multiple_DimensionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1=torch.nn.Linear(8,6)
        self.linear2=torch.nn.Linear(6,4)
        self.linear3=torch.nn.Linear(4,1)
        self.sigmoid=torch.nn.Sigmoid()  # 将其看作是网络的一层,而不是简单的函数使用

    def forward(self,x):
        x=self.sigmoid(self.linear1(x))
        x=self.sigmoid(self.linear2(x))
        x=self.sigmoid(self.linear3(x))
        return x



model=Multiple_DimensionModel()

criterion=torch.nn.BCELoss(size_average=True)
optimizer=torch.optim.Adam(model.parameters(),lr=0.05)

epoch_list=[]
cost_list=[]

for epoch in range(5000):
    y_pred=model(x_data)
    loss=criterion(y_pred,y_data)
    print(epoch,loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    epoch_list.append(epoch)
    cost_list.append(loss.item())  # loss


plt.plot(epoch_list,cost_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()



if __name__ == "__main__":
    main()

这样8-6-4-1维,使用adam优化器效果更好

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值