Pytorch写二分类框架

刚完成一个torch的二分类的工作,是实际的工作。在次对torch的分类框架进行一下总结。也是一个再次梳理思路的过程。

要做一个分类的工作,首先就是数据集。数据集的格式,以及训练时传入网络的数据集的形式。这里的torch代码就是我之前上传的资源的代码。链接:

pytorch框架,分类网络-深度学习文档类资源-CSDN下载

先进行数据集的制作。对应的是datasets.py,二分类,正样本是Yes,负样本是No。将Yes和No两个文件夹平行放入一个data的文件夹,取其他的名字也可以。

import troch
from torch.utils.data import Dataset,DataLoader
import cv2
from config import config
import os
form torchvision import transforms
import numpy as np
from PIL import Image
import math
import random

np.random.seed(666)

def get_files(file_dir,ratio):
    Yes = []
    labels_Yes = []
    No = []
    labels_No = []
    dir_list = []
    
    dir = os.listdir(file_dir)   #data的路径  
    for file in dir:
        dir_name = os.path.join(file_dir,file)
        dir_list.append(dir_name)    #No和Yes的绝对路径   No路径在前,Yes路径在后
    file1 = os.listdir(dir_list[0])  #No文件夹下的文件
    for i in file1:
        file_path = os.path,join(dir_list[0],i)
        No.append(file_path)
        labels_No.append(0)
    file2 = os.listdir(dir_list[1])
    for i in file2:
        file2_path = os.path.join(dir_list[1],i)
        Yes.append(file2_path)
        labels_Yes.append(1)
    image_list = np.hstack((Yes,No))
    
    labels_list = np.hstack((labels_Yes,labels_No))
    temp = np.array([image_list,labels_list])    #shape:  2*389
    temp = temp.transpose()   #389*2
    
    np.random.shrffle(temp)
    all_image_list = list(temp[:,0])   #image_path
    all_label_list = list(temp[:,1])
    
    all_label_list = [int(i) for i in all_label_list]
    length = len(all_label_list)
    
    n_test = int(math.ceil(length * ratio)
    n_train = length -  n_test
    
    tra_image = all_image_list[:n_train]
    tra_label = all_label_list[:n_train]

    test_image = all_image_list[n_train:]
    test_label = all_label_list[n_train:]

    train_data = [(tra_image[i],tra_label[i]) for i in range(len(tra_image))]
    test_dat   = [(test_image[i],test_label[i]) for i in range(len(test_image))]

    return train_data test_data    

以上的代码即是数据集的制作。接下来的代码是torch中的数据集类,作用就是加载训练和测试数据集。

class datasets(Dataset):
    
    def __init__(self,data,transform = None,test = False):
        imgs = []
        labels = []
        self.test = test
        self.transform = transform
        self.data = data
        self.len = len(data)
        for i in self.data:
            imgs.append(i[0])
            self.imgs = imgs
            labels.append(i[1])
            self.labels = labels
    def __getitem__(self,index):
        if self.test:
            filename = self.imgs[index]
            filename = filename
            img_path = self.imgs[index]
            img = cv2.imdecode(np.fromfile(img_path,dtype = np.uint8),-1) #-1是完整图像,0                                                
                                                                          #是灰度,1是彩图
            img = cv2.cvtColor(img,cv2.COLOR_GRAY2RGB)    
            img = cv2.resize(img,(config_width,config_height))

            img = transforms.ToTensor()(img)
            return img,filename
        else:
            img_paht = self.imgs[index]
            labels = self.labels[index]
            img = cv2.imdecode(np.fromfile(img_path,dtype = np.uint8),-1)
            img = cv2.cvtColor(img,cv2.COLOR_GRAY3RGB)
            img = cv2.resize(img,(config.img_width,config.img_height))
            if self.transorm is not None:
                img = Image.fromarray(img)   #torch中的transforms操作的是image,不是array
                img = self.transform(img)            
            slse:
                img = transforms.ToTensor()(img)
            return img,label
    def __len__(self):
        return len(self.data)

def collate_fn(batch):   #将多个样本拼接成一个batch   这段代码在项目中没有运用。
    imgs = []
    labels = []
    for sample in batch:
        imgs.append(sample[0])
        labels.append(sample[1])
    return torch.stack(imgs,0),labels          
   

编写配置文件 config.py

class MyConfigs():
    data_folder = ''
    test_data_folder = ''
    model_name = 'resnet'
    weights = './checkpoints'
    logs = './logs'
    examples_folder = './example'
    epochs = 2000
    batch_size = 8
    img_height = 100
    img_width = 100
    num_classes = 2
    lr = 1e-2
    weight_decay = 2e-4
    ratio = 0.1
config = MyConfigs()
  

开始训练的代码 train.py

from test import *
from utils.utils import *
import torch
import matplotilb.pyplot as plt
import Model
from torch import nn,optim
import time
from time import srtftime,gmtime

np.set_printoptions(threshold = np.inf)

if __name__ == '__main__':
    print(torch.__version__)
    print(torch.cuda.is_available())

    # device_ids = [5, 6]

    # 1.创建文件夹
    if not os.path.exists(config.example_folder):
        os.mkdir(config.example_folder)
    if not os.path.exists(config.weights):
        os.mkdir(config.weights)
    if not os.path.exists(config.logs):
        os.mkdir(config.logs)

    # model = torch.nn.DataParallel(model,device_ids)
    model = Model.get_net()
    if torch.cuda.is_available():
        model = model.cuda()



    optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=0.9, weight_decay=config.weight_decay)

    # optimizer = optim.Adam(model.parameters(), lr=config.lr, betas=(0.9, 0.999), weight_decay=config.weight_decay)

    criterion = nn.CrossEntropyLoss().cuda()

    # start_epoch = 0
    current_F_Score = 0
    resume = True
    if resume:
        checkpoint = torch.load(config.weights + config.model_name + '.pth')
        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

    transform = transforms.Compose([
        # transforms.RandomResizedCrop(90),  #将PIL图像裁剪成任意大小和纵横比
        transforms.ColorJitter(0.05, 0.05, 0.05),  #随机改变图像的亮度对比度和饱和度
        transforms.RandomRotation(5,expand=True),    #随机旋转一定的角度,中心旋转
        # transforms.RandomGrayscale(p = 0.5),   #将图像以一定的概率转换为灰度图像
        #transforms.RandomHorizontalFlip(0.5),        #以0.5的概率水平翻转给定的图像
        #transforms.RandomVerticalFlip(0.5),  #以0.5的概率竖直翻转给定的图像
        transforms.Resize((config.img_width, config.img_height)),  # 把给定的图片resize到given size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])])  # 用均值和标准差归一化张量图像


    _, train_list = get_files(config.data_folder, config.ratio)
    input_data = datasets(train_list, transform=transform)

    train_loader = DataLoader(input_data, batch_size=config.batch_size, shuffle=True,
                              pin_memory=True, num_workers=1)
    test_list, _ = get_files(config.data_folder, config.ratio)
    test_loader = DataLoader(datasets(test_list, transform=None), batch_size=config.batch_size, shuffle=False,
                              num_workers=1)


    # optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=0.9, weight_decay=config.weight_decay)
    #学习率衰减
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5,last_epoch=-1)
    #描述:等间隔调整学习率,每次调整为 lr
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

枫桥夜泊1003

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值