最近使用pytorch踩过的一些坑,记录一下,偏应用。
1.图片加载
pytorch中的datasets.ImageFolder函数直接可以读取自己的图片的数据集。
数据集存放:
把每一类的图片放到一个文件夹里面,加载时地址只用写到类别文件夹的上一级目录。例如下图中dataset文件夹存放了4个类别的图片,那么图片加载时写入的地址就是** ‘F:\dataset’** 。datasets.ImageFolder会自动根据文件夹类别给数据打上标签。
from torchvision import datasets, transforms
import torch
import os
ef load_data(root_path, dir, batch_size, phase):
transform_dict = {
'tar':transforms.Compose(
[transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),])}
data = datasets.ImageFolder(root=os.path.join(root_path,dir), transform=transform_dict[phase]) ##即各类别文件夹所在目录的上一级目录
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True,drop_last=False, num_workers=4)
#设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。drop_last=False不丢弃
return data_loader
上面这段示例代码可以当作模板使用,但其实我还有一个问题没搞懂,就是transforms.Normalize标准化时输入的平均值和标准差为什么不是0.5,而是一堆奇怪的小数,有哪位大佬路过可以帮忙解答一下。
2.模型搭建
pytorch的模型搭建这里,没有什么特别要记录的,网上很多例子,照着搭自己的模型就行。只有一点,是在我真正开始动手操作的时候才发现的,之前照着书学习的时候没发现或者看到了没有注意就略过了。
在pytorch中搭建一个My_Net类作为自己的模型,在调用时按照下面流程调用传入自己的数据就行,它会直接执行My_Net类中的forward 函数,完成前向传播,不需要单独调用forward函数。
class My_Net(nn.Module):
def __init__(self,):
super(My_Net,self).__init__()
……
def forward(data):
……
return result
def my_loss( ):
……
return loss
#调用
data
model = My_Net()
result=model(data)