- Variable 和 Tensor 的区别:Variable提供了自动求导的功能
Tensor→Variable :
Variable(a)
- 数据读取 Dataset
torch.utils.data.Dataset 是读取数据的抽象类,重写时只需要定义 __len__ 和 __getitem__。
class myDataset(Dataset):
def __init__(self, csv_file):
self.csv_data = pd.read_csv(csv_file)
def __len__(self):
return len(self.csv_data)
def __getitem__(self, idx):
data = self.csv_data[idx]
return data
定义batch, shuffle或多线程读取则需要定义迭代器DataLoader:
dataiter = DataLoader(mtDataset, batch_size = 32, shuffle = True,
collate_fn = default_collate)
### collate_fn是表示如何取样本的
- 定义神经网络 nn.Module
所有层结构和损失函数都来自torch.nn,都要从基类nn.Module继承
class net_name(nn.Module):
def __init__(self, other_argument):
super(net_name, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size)
def forward(self, x):
x = self.conv1(x)
return x
损失函数的定义
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
- 优化 torch.optim
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
# 梯度归零
optimizer.zeros()
# 反向传播
loss.backword()
# 参数更新
optimizer.step()
- 模型保存与加载
## 保存全部模型
torch.save(model, './model.pth')
## 保存模型参数
torch.save(model.state_dict(), path)
## 读取整个模型
torch.load('model.pth')
## 仅读取模型参数
model.load_state_dic(torch.load('model.pth')) # 需要先导入模型结构