pytorch实现交叉验证

pytorch实现交叉验证

一般的交叉验证是对神经网络回归分类的代码
我这里是针对图像分类来的,对于目标检测这些的话,把对应读取数据的函数修改一下就行了

实现交叉验证的Dataset

import torch
import torch.nn as nn
from torch.utils.data.dataset import *
from PIL import Image
from torch.nn import functional as F
import random

class KZDataset(Dataset):
    def __init__(self, txt_path=None, ki=0, K=5, typ='train', transform=None, rand=False):
        '''
        txt_path: 所有数据的路径,我的形式为(单张图片路径 类别\n)
        	img1.png 0
        	...
     	    img100.png 1
     	ki:当前是第几折,从0开始,范围为[0, K)
     	K:总的折数
     	typ:用于区分训练集与验证集
     	transform:对图片的数据增强
     	rand:是否随机
        '''

        self.all_data_info = self.get_img_info(txt_path)
        
        if rand:
	        random.seed(1)
        	random.shuffle(self.all_data_info)
        leng = len(self.all_data_info)
        every_z_len = leng // K
        if typ == 'val':
            self.data_info = self.all_data_info[every_z_len * ki : every_z_len * (ki+1)]
        elif typ == 'train':
            self.data_info = self.all_data_info[: every_z_len * ki] + self.all_data_info[every_z_len * (ki+1) :]
            
        self.transform = transform

    def __getitem__(self, index):
    	# Dataset读取图片的函数
        img_pth, label = self.data_info[index]
        img = Image.open(img_pth).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(txt_path):
    	# 解析输入的txt的函数
    	# 转为二维list存储,每一维为 [ 图片路径,图片类别]
        data_info = []
        data = open(txt_path, 'r')
        data_lines = data.readlines()
        for data_line in data_lines:
            data_line = data_line.split()
            img_pth = data_line[0]
            label = int(data_line[1])
            data_info.append((img_pth, label))
        return data_info   

运用KZDataset

这里我只写调用的伪代码,按自己需求改
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
transfm = transforms.Compose([
		transforms.Resize((384, 384)),
		trainsforms.ToTensor(),
		transforms.Normalize(norm_mean, norm_std)])
for ki in range(K):
	trainset = KZDataset(txt_path='data.txt', ki=ki, K=K, typ='train', transform=transfm, rand=True)
	valset = KZDataset(txt_path='data.txt', ki=ki, K=K, typ='val', transform=transfm, rand=True)
	train_loader = DataLoader(
         dataset=trainset,
         batch_size=batchs,
         shuffle=True)
	val_loader = DataLoader(
         dataset=valset,
         batch_size=batchs,
     )
     for epoch in range(epoches):
	     for i, (inputs, labels) in enumerate(train_loader):
	     	pass
	     	'''
	     	训练过程
	     	'''
	     for i, (inputs, labels) in enumerate(val_loader):
	     	pass
	     	'''
	     	验证过程
	     	'''
  • 15
    点赞
  • 60
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值