Pytorch 学习笔记(一)

  •  

  • Variable 和 Tensor 的区别:Variable提供了自动求导的功能

        Tensor→Variable : 

Variable(a)

  • 数据读取 Dataset

torch.utils.data.Dataset 是读取数据的抽象类,重写时只需要定义 __len__ 和 __getitem__。

class myDataset(Dataset):
    def __init__(self, csv_file):
        self.csv_data = pd.read_csv(csv_file)
    
    def __len__(self):
        return len(self.csv_data)

    def __getitem__(self, idx):
        data = self.csv_data[idx]
        return data

定义batch, shuffle或多线程读取则需要定义迭代器DataLoader:

dataiter = DataLoader(mtDataset, batch_size = 32, shuffle = True, 
                      collate_fn = default_collate)
### collate_fn是表示如何取样本的
  • 定义神经网络 nn.Module

所有层结构和损失函数都来自torch.nn,都要从基类nn.Module继承

class net_name(nn.Module):
    def __init__(self, other_argument):
        super(net_name, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
        
    def forward(self, x):
        x = self.conv1(x)
        return x

损失函数的定义

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
  • 优化 torch.optim 
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)

# 梯度归零
optimizer.zeros()

# 反向传播
loss.backword()

# 参数更新
optimizer.step()
  • 模型保存与加载
## 保存全部模型
torch.save(model, './model.pth')

## 保存模型参数
torch.save(model.state_dict(), path)

## 读取整个模型
torch.load('model.pth')

## 仅读取模型参数 
model.load_state_dic(torch.load('model.pth')) # 需要先导入模型结构

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值