Pytorch搭建神经网络的一个简单框架
1.获取神经网络能用的数据(batch)
2.构建网络(具体看自己实现的网络,先不写)
3.建立优化器和训练
torch使用torch.utils.data.DataLoader读取数据,将数据集转换成最终能用的batch,一般搭建网络都需要这个。
import torch.utils.data as data
# dataset 假设dataset是原始数据集
# 先要定义一个class继承torch.utils.data.Dataset
class Mydataset(data.Dataset):
def __init__(self):
pass
def __getitem__(self):
pass
def __len__(self):
pass
# 通过这个子类将原始的数据集转换成DataLoader需要的Dataset类型(或者是IterableDataset 默认是Dataset, 具体看官方文档)。
dataset = Mydataset()
# 得到了Dataset类型的dataset了,然后带入到DataLoader中
# 通过DataLoader得到的就是batch了
batch = data.DataLoader(dataset,
batch_size=1,
shuffle=False,
num_workers=0)
# 得到的batch就可以提取标签和图片进行训练了
batchiterator = iter(batch)
# 建立优化器
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
for step in range(iteration):
imgs, labels = next(batchiterator)
optimizer.zero_grad()
out = model(img)
loss = loss_fn(out, labels)
loss.backward()
optimizer.step()
# 以上可以简单地看做是一个网络搭建的基本框架