【DL】imges2classIMG 生成Torch专用的图片文件目录

1 原文件目录

images
label

2 生成的文件

（图1–3：运行脚本后生成的 classIMG 目录结构截图，原图已缺失）

3 代码

3.1 imges2classIMG.py

import os
import shutil
import cv2 as cv
'''
@Author:xhh
@Email:xhh0608@foxmail.com

@description:
Sort a flat folder of images into one sub-folder per class so the result can
be loaded by a PyTorch Dataset (folder name = class, files = images).
The class name is the first 9 characters of each image file name
(e.g. 'n0153282900000005.jpg' -> 'n01532829').
Edit [root] & [saveRoot] below if your folders are named differently.

[in] existing dirs and files
root/
 |- images/n0153282900000005.jpg ...
 |- imges2classIMG.py
 |- classIMG/

[out] files that can be imported in PyTorch  # folder name [label] / files [jpg...]
root/
 |- images/n0153282900000005.jpg ...
 |- imges2classIMG.py
 |- classIMG/
 |-          |- n01532829/n0153282900000005.jpg \ ...jpg
 |-          |- n01532820/n0153282000000005.jpg \ ...jpg

'''

root = 'images'  # flat folder holding all images
saveRoot = 'classIMG'  # output folder: one sub-folder per class

# exist_ok: re-running the script must not fail on folders created by an earlier run
os.makedirs(saveRoot, exist_ok=True)

name2labels = {}  # {'n01532829': 0, 'n04418357': 1, ...}

# sorted so that the class -> label numbering is deterministic
imageNameList = sorted(os.listdir(root))

idx = 0  # number of images copied so far
for imageName in imageNameList:
    src = os.path.join(root, imageName)
    # only regular files are images; skip stray sub-directories
    if not os.path.isfile(src):
        continue

    floderName = imageName[:9]  # class name = first 9 characters of the file name

    if floderName not in name2labels:
        # len(name2labels) hands out consecutive labels 0, 1, 2, ...
        name2labels[floderName] = len(name2labels)
        # create ./classIMG/n01532829
        os.makedirs(os.path.join(saveRoot, floderName), exist_ok=True)

    dst = os.path.join(saveRoot, floderName, imageName)
    shutil.copyfile(src, dst)
    idx += 1
    print('idx {0} copy {1} to {2}'.format(idx, src, dst))

3.2 xhhDataset

import torch
import os, glob
import random, csv

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class XHHDataset(Dataset):
    """Image-classification dataset backed by a ``root/<class_name>/<image>`` tree.

    Class indices come from the sorted sub-directory names of ``root``.
    The first time the dataset is built, all image paths are shuffled once and
    cached with their labels in ``root/images.csv``. Every later instance reads
    that file, so the train / val / test split (60% / 20% / 20%) is the same
    for all three modes and across runs.
    """

    # ImageNet per-channel statistics: applied by Normalize, undone by denormalize.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode='train'):
        '''
        :param root: dataset root directory (e.g. classIMG); every sub-directory is one class
        :param resize: side length of the square output image, e.g. 224 -> 224 x 224
        :param mode: 'train' selects the first 60% of the cached list, 'val' the next 20%;
                     any other value (e.g. 'test' or None) selects the last 20%
        '''
        super(XHHDataset, self).__init__()

        self.root = root
        self.resize = resize
        self.mode = mode

        # {'n01532829': 0, ...}; sorting keeps the class -> index mapping stable
        self.name2lable = {}
        for name in sorted(os.listdir(root)):
            # skip plain files such as the cached images.csv
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2lable[name] = len(self.name2lable)

        self.images, self.labels = self.load_csv('images.csv')

        n = len(self.images)
        if mode == 'train':  # 60%
            start, end = 0, int(0.6 * n)
        elif mode == 'val':  # 20%
            start, end = int(0.6 * n), int(0.8 * n)
        else:  # 20%
            start, end = int(0.8 * n), n
        self.images = self.images[start:end]
        self.labels = self.labels[start:end]

        # Built once here instead of on every __getitem__. It holds no lambda,
        # so the dataset stays picklable for DataLoader worker processes.
        self.transform = self._build_transform()

    def _build_transform(self):
        """Return the pipeline that turns an RGB PIL image into a normalized tensor."""
        steps = [
            # enlarge to 1.3x, later cropped back to resize x resize
            transforms.Resize((int(self.resize * 1.3), int(self.resize * 1.3))),
        ]
        if self.mode == 'train':
            # random augmentation only for training; val/test must be deterministic
            steps.append(transforms.RandomRotation(15))  # random rotation within +-15 degrees
        steps += [
            transforms.CenterCrop(size=self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ]
        return transforms.Compose(steps)

    def load_csv(self, filename):
        """Return ``(image_paths, labels)`` listed in ``root/filename``.

        If the file does not exist yet, it is created first. All png/jpg/jpeg
        files under every class directory are collected, shuffled, and written
        as ``path,label`` rows.
        """
        csv_path = os.path.join(self.root, filename)

        # create the csv file only once
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2lable:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            # e.g. 60000, 'classIMG\\n01532829\\n0153282900000005.jpg'
            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # the parent directory name is the class name
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2lable[name]])
                print('writen into csv file:', filename)

        # read from csv (newline='' as recommended by the csv module)
        images, labels = [], []
        with open(csv_path, mode='r', newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)

        return images, labels

    def __len__(self):
        # number of samples in the selected split (60% / 20% / 20%)
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo Normalize per channel: ``x = x_hat * std + mean``.

        :param x_hat: normalized tensor shaped [3, h, w] or [b, 3, h, w]
        """
        # [3] -> [3, 1, 1] so the statistics broadcast over h and w
        mean = torch.tensor(self.MEAN).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(self.STD).unsqueeze(1).unsqueeze(1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx`` of this split."""
        img_path, label = self.images[idx], self.labels[idx]
        # 'with' closes the file handle that Image.open would otherwise leave open
        with Image.open(img_path) as im:
            img = self.transform(im.convert('RGB'))
        return img, torch.tensor(label)


def main():
    """Visually sanity-check XHHDataset in visdom.

    Start the server first with ``python -m visdom.server``.
    """
    import time
    import visdom

    viz = visdom.Visdom()
    db = XHHDataset(root='classIMG', resize=224, mode='train')

    # a single sample via the iterator protocol
    sample_x, sample_y = next(iter(db))
    print('sample:', sample_x.shape, sample_y.shape, sample_y)
    viz.image(db.denormalize(sample_x), win='sample_x', opts=dict(title='sample__x'))

    # batches of 32 images, loaded by 8 worker processes
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for batch_x, batch_y in loader:
        viz.images(db.denormalize(batch_x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(batch_y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)


if __name__ == '__main__':
    main()

3.3 train & val & test

import torch
import math
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader
import torch.nn.init as init # 初始化权重

from xhhDataset import XHHDataset
from xhhModels.resnet import ResNet18
from xhhModels.vgg16norm import VGG16
from xhhModels.mogaA import MoGaA
from xhhModels.newmogaA import NewMoGaA



# ---- hyper-parameters ----
n_class = 100
batchsz = 16
lr = 1e-3
epochs = 100

# live plots need a running server: python -m visdom.server
usingVIZ = True

usingModelName = 'NewMoGaA'

# report once, at import time, whether CUDA is usable
printOnce = True
if printOnce:
    print('*********** use cuda is {} ************'.format(torch.cuda.is_available()))
    printOnce = False

# seed torch's global random number generator
torch.manual_seed(19960608)

### data loading - the important part --------------------------------------
# 60% / 20% / 20% splits of classIMG, images resized to 320 x 320
train_db, val_db, test_db = (XHHDataset('classIMG', 320, mode=split)
                             for split in ('train', 'val', 'test'))

# num_workers = number of worker processes; only training batches are shuffled
train_loader = DataLoader(train_db, batchsz, shuffle=True, num_workers=8)
val_loader = DataLoader(val_db, batchsz, num_workers=4)
test_loader = DataLoader(test_db, batchsz, num_workers=4)
### data loading - the important part --------------------------------------

if usingVIZ:
    viz = visdom.Visdom()

# learning-rate decay
def adjust_learning_rate(optimizer, lr):
    '''
    Halve the learning rate and write it into every parameter group.

    :param optimizer: optimizer whose param_groups are updated in place
    :param lr: current learning rate
    :return: the new learning rate, lr * 0.5
    '''
    new_lr = lr * 0.5
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr

## weight initialisation
def xavier(param):
    """Fill ``param`` in place with Xavier/Glorot-uniform values; returns None."""
    init.xavier_uniform_(tensor=param)

def weights_init(self):
    """Re-initialise every Conv2d, BatchNorm2d and Linear layer inside ``self``.

    Usable both as ``weights_init(model)`` and as ``model.apply(weights_init)``.
    With ``apply`` the function runs for every sub-module, and each call walks
    that sub-module's whole subtree. Layers are therefore re-drawn several
    times, but always from the same distributions:

    - Conv2d: weight ~ N(0, sqrt(2 / n)), n = k_h * k_w * out_channels; bias -> 0
    - BatchNorm2d: weight -> 1, bias -> 0 (nothing to do when affine=False)
    - Linear: weight ~ U(-1/sqrt(out_features), 1/sqrt(out_features)); bias -> 0
    """
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has weight and bias set to None
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            n = m.weight.size(0)  # out_features (fan-out)
            init_range = 1.0 / math.sqrt(n)
            m.weight.data.uniform_(-init_range, init_range)
            # Linear(bias=False) has bias set to None
            if m.bias is not None:
                m.bias.data.zero_()


def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch yielded by ``loader``.

    The model is put in eval mode while scoring. BatchNorm then uses its
    running statistics instead of updating them with evaluation data, and
    Dropout is disabled. Afterwards the model is returned to whatever mode it
    was in before. An empty loader gives 0.0.
    """
    was_training = model.training
    model.eval()
    total_num, total_correct = 0, 0
    try:
        with torch.no_grad():
            for x, y in loader:
                # move the batch to the GPU when one is available (the model is expected there too)
                if torch.cuda.is_available():
                    x = x.cuda()
                    y = y.cuda()
                pred = model(x).argmax(dim=1)
                total_correct += torch.eq(pred, y).sum().item()
                total_num += x.size(0)
    finally:
        # restore the caller's mode so training can continue unaffected
        model.train(was_training)

    if total_num == 0:
        return 0.0
    return total_correct / total_num



def main():
    """Train the model selected by ``usingModelName`` and validate every 2 epochs.

    When validation accuracy improves and exceeds 0.7, the whole model is saved
    as ``<name>_epoch<e>_acc<acc>.pt``. After 3 consecutive validations without
    improvement, the learning rate is halved.
    """
    if usingModelName == 'ResNet18':
        model = ResNet18(num_class=n_class)
    elif usingModelName == 'VGG16':
        model = VGG16(num_class=n_class)
    elif usingModelName == 'MoGaA':
        model = MoGaA(n_class = n_class)
    elif usingModelName == 'NewMoGaA':
        model = NewMoGaA(n_class = n_class)
    else:
        print('Model name', usingModelName, 'is error !')
        return None
    print('Model:', usingModelName, 'is created successful!')

    # re-initialise the weights
    model.apply(weights_init)

    if torch.cuda.is_available():
        model = model.cuda()

    # Work on a local copy of the global lr. Assigning to `lr` inside this
    # function would make it a local name throughout, and the first read below
    # would raise UnboundLocalError.
    cur_lr = lr
    optimizer = optim.Adam(model.parameters(), lr=cur_lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    val_acc_is_not_update = 0  # consecutive validations without improvement

    global_step = 0

    if usingVIZ:
        viz.line([0], [0], win='train_loss', opts=dict(title='train_loss'))  # ([y], [x]...)
        viz.line([0], [0], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        # start every epoch in training mode (BatchNorm/Dropout), whatever
        # mode the model was left in by evaluation
        model.train()
        for step, (x, y) in enumerate(train_loader):
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            if usingVIZ:
                viz.line([loss_val], [global_step], win='train_loss', update='append')

            global_step += 1
            print("epoch:{}---step{}---loss{}---lr{}".format(epoch, step, loss_val, cur_lr))

        if epoch % 2 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                val_acc_is_not_update = 0

                # only save once accuracy exceeds 0.7
                if best_acc > 0.7:
                    torch.save(model, '{}_epoch{}_acc{}.pt'.format(usingModelName, epoch, round(best_acc, 4)))
            else:
                val_acc_is_not_update += 1

            # 3 validations (= 6 epochs) without improvement -> halve the learning rate
            if val_acc_is_not_update == 3:
                cur_lr = adjust_learning_rate(optimizer, cur_lr)
                val_acc_is_not_update = 0
                print(f'************ epoch:{epoch}---update lr is:{cur_lr} ************')

            if usingVIZ:
                viz.line([val_acc], [global_step], win='val_acc', update='append')

            print('best acc:', best_acc, 'best epoch:', best_epoch)


def finalTestModel(modelPT):
    '''
    Load a complete model saved with torch.save(model, ...) and report its test accuracy.

    :param modelPT: path to the complete-model .pt file
    :return: test accuracy on test_loader
    :raises ValueError: if modelPT does not end in '.pt'
    '''
    if not modelPT.endswith('.pt'):
        raise ValueError(f'expected a .pt file, got {modelPT!r}')

    # map_location lets a checkpoint saved during a GPU run load on a CPU-only machine
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE: torch.load unpickles arbitrary objects - only load checkpoints you trust.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which cannot
    # restore a whole pickled model; pass weights_only=False on those versions.
    model = torch.load(modelPT, map_location=device)
    model = model.to(device)

    print(f'loaded from {modelPT} is ok')

    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)
    return test_acc

if __name__ == '__main__':
    # To score a saved checkpoint on the test split instead, use e.g.:
    # finalTestModel('savePTS/resNet_best_4.pt')
    main()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值