pytorch dataset自定义_pytorch入门练手:一个简单的CNN模型

v2-eca593cc330d84fe39fc6a85d9e0afe3_1440w.jpg?source=172ae18b

由于新型冠状肺炎疫情一直没能开学,在家自己学习了一下pytorch,本来说按着官网的60分钟教程过一遍的,但是CIFAR-10数据库的下载速度太慢了……

这台电脑里也没有现成的数据库,想起之前画了一些粒子的动量分量分布图,干脆拿来用了,也没期待它能表现得多好,主要图一个练手。(事实证明它表现相当差,不过这也在意料之中)

那么开始。

import 

接下来定义读取和处理图片的函数,图片尺寸是432x288,把它切成中间的288x288,再缩小成32x32。这样的处理单纯是为了让模型训练得快一点,毕竟这次练手本身的目的不是训练一个高精度的模型,而是训练一个模型。(而且话说回来这电脑也莫得英伟达高性能图形处理器,(笑))

这里其实两个函数写成一个就行,但是我懒得改了。

PATH = '/Users/huangyige/Downloads/fig/'

def load_img(imgname):
    #here, only consider pion at 7.7 and 14.5 GeV, px.
    img = Image.open(imgname).convert('RGB')
    return img

def process_img(img):
    img = img.crop((72,0,360,288))
    img = img.resize((32,32))
    return img

然后随意拉张图进来看看。

img = load_img(PATH+'Pion-7.7GeV-7-P1.png')
img = process_img(img)
plt.imshow(img)

v2-773ade98eeee0b35b4228191986bd637_b.jpg

只能隐约能看出来有两个峰2333333,这样的数据集能训练出来个鬼咯~

然后生成两个数据集的文件名列表(附带标签)的文档。

def generate_file(name,num_range):
    with open('./'+name+'.txt','w') as f:
        for energy in ['7.7','14.5']:
            for _ in num_range:
                imgname = PATH + 'Pion-' + energy + 'GeV-' + str(_+1) + '-P1.png'
                f.write(imgname+' '+energy+'n')
    return
generate_file('train',range(0,70))
generate_file('test',range(70,90))

就别问我为什么训练集就70张图,测试集就20张图了,只有这么点数据……可以打开文档看看效果。

v2-2cbedc8c7f920b62c24cdd1c9b7e9a50_b.jpg

差不多就这样,没什么问题。接下来定义自定义Dataset类。

class sets(torch.utils.data.Dataset):
    def __init__(self,datatxt,transform=None):
        super(sets,self).__init__()
        imgs = []
        with open(datatxt,'r') as f:
            for line in f:
                line = line.rstrip('n')
                words = line.split(' ')
                imgs.append((words[0],words[1]))
        self.imgs = imgs
        self.transform = transform
        return
    def __getitem__(self,index):
        imgname,label_o = self.imgs[index]
        img = load_img(imgname)
        img = process_img(img)
        if label_o == '7.7':
            label = 0
        else:
            label = 1
        if self.transform is not None:
            img = self.transform(img)
        return img,label
    def __len__(self):
        return len(self.imgs)

以及DataLoader。

train_set = sets('./train.txt',transforms.ToTensor())
test_set = sets('./test.txt',transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=1,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=1)

batch_size选1是不是很扯,哈哈哈我也这么觉得。如果需要做数据增强,在初始化sets时,transform参数用transforms.Compose[transforms.ToTensor(),...]这样多填几个就行了。

然后是模型的结构。

class 

直接把pytorch官网的tutorial里CIFAR-10的模型拉出来用了,正好我已经把数据变成了32x32,参数都不用改。(修改:最后一个全链接层的神经元数应该是2而不是10,还是得改一下的)

选损失函数和优化器。

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

pytorch没有现成的算accuracy的函数所以自己写一个。

def accuracy(net,test_loder):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs,labels = data
            outputs = net(inputs)
            _,pred = torch.max(outputs.data,1)
            total += labels.size(0)
            correct += (pred==labels).sum().item()
    acc = 100.0*correct/total
    return acc

然后就可以开始训练了,本来也只是个玩具模型,所以2代就够了。不得不插一嘴,keras用起来确实要方便一点。

for epoch in range(2):
    running_loss = 0.0
    for i,data in enumerate(train_loader,0):
        inputs,labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        acc = accuracy(net,test_loader)
        print('r[epoch %d >%3.d<] loss:%.3f,acc:%.1f%%'%(epoch+1,i+1,loss,acc),end='')
    print('')
print('Done!')

看看效果:

v2-c3e899adeee3a7b0b8abcb9d30313264_b.jpg

精度精准地锁定在50%,也就是说这模型在纯猜~whatever,这次练手主要是熟悉pytorch怎么用,模型本身的质量不重要。

最后保存一下模型:

torch.save(net.state_dict(),'./model.pth')

想开学啊啊啊啊啊……

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值