深度学习:使用timm库的resnet101模型对glaucoma图像集进行图像分类预测

"第一次训练是为了得出最好的best.mdl,用于第二次训练这样就可以用在更好的模型参数基础上进行训练"
import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader
from DIYdata_loader import DIYData_loader
import timm

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


batchsz = 32
picture_resize = 224
lr = 1e-4
epochs = 10
num_classes=3
device = torch.device('cuda')
torch.manual_seed(1234)

# 图片数据集路径&名字
picture_data_path = r'D:\python pycharm learning\清华大佬课程\青光眼ResNet\青光眼分类'
# 模型权重路径&名字
weight_data_path = r'D:\python pycharm learning\清华大佬课程\青光眼ResNet'
weight_data_name=  'resnet101_a1h-36d3f2aa.pth'
# best权重路径&名字
best_weight_path = r'D:\python pycharm learning\清华大佬课程\青光眼ResNet'
best_weight_name = 'best_glaucoma.mdl'

train_db =  DIYData_loader(picture_data_path, picture_resize, mode='train')
val_db   =  DIYData_loader(picture_data_path, picture_resize, mode='val')
test_db  =  DIYData_loader(picture_data_path, picture_resize, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=8)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=8)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=8)

viz = visdom.Visdom()


def evaluate(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()

    return correct / total


def main():
    # 创建一个 ResNet101 模型,输出类别数为5
    model = timm.create_model('resnet101', pretrained=False, num_classes=num_classes)

    # 从本地文件加载预训练权重
    state_dict = torch.load(os.path.join(weight_data_path, weight_data_name))

    # 修改全连接层的权重和偏置,使其适应新的类别数
    state_dict['fc.weight'] = state_dict['fc.weight'][:num_classes, :]
    state_dict['fc.bias'] = state_dict['fc.bias'][:num_classes]

    # 加载权重到模型
    model.load_state_dict(state_dict, strict=False)

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                # 保存模型时
                torch.save(model.state_dict(), best_weight_name)

                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load(best_weight_name))
    print('loaded from ckpt!')

    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()
"在第一次训练得出的最好模型参数基础上,进行第二次训练,进一步优化模型参数 "
import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader
from DIYdata_loader import DIYData_loader
import timm
import os
import first_train

batchsz = first_train.batchsz
picture_resize = first_train.picture_resize
lr = first_train.lr
epochs = 10
num_classes=first_train.num_classes
device = torch.device('cuda')
torch.manual_seed(1234)

# 使用first_train中定义的全局变量
picture_data_path =   first_train.picture_data_path
weight_data_path  =   first_train.weight_data_path

best_weight_path  =   first_train.best_weight_path
best_weight_name  =   first_train.best_weight_name



train_db =  DIYData_loader(picture_data_path, picture_resize, mode='train')
val_db   =  DIYData_loader(picture_data_path, picture_resize, mode='val')
test_db  =  DIYData_loader(picture_data_path, picture_resize, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=8)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=8)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=8)

viz = visdom.Visdom()


def evaluate(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()

    return correct / total


def main():
    # 创建一个 ResNet101 模型,输出类别数为5
    model = timm.create_model('resnet101', pretrained=False, num_classes=num_classes)

    # 从本地文件加载预训练权重
    state_dict = torch.load(os.path.join(best_weight_path, best_weight_name))

    model.load_state_dict(state_dict)
    # 加载权重到模型

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), best_weight_name)
                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load(best_weight_name))
    print('loaded from ckpt!')

    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()
# 数据集共有364张图:Glaucoma:32,Normal:225,Suspect Glaucoma:107
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms


# 自定义数据加载类
class DIYData_loader(Dataset):
    def __init__(self, root, resize, mode):  # root:文件所在目录,resize:图像分辨率调整一致,mode:当前类何功能
        super(DIYData_loader, self).__init__()

        self.root = root
        self.resize = resize

        self.name2label = {}  # 对每个加载的文件进行编码:'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4
        for name in sorted(os.listdir((os.path.join(root)))):  # 对指定root中的文件进行排序
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())  # keys返回列表当中的value,len计算列表长度
        #print(self.name2label)  # 根据文件顺序,以idx:文件名,vlaue:0,1,2,3,4,生成列表
        # images labels
        self.images, self.labels = self.load_csv('images.csv')  # load_csv要么先创建images.csv,要么直接读取images.csv,
        #print('data_len:',len(self.images))
        if mode == 'train':  # train dataset 60% of ALL DATA
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'validation':  # val dataset 60%-80% of ALL DATA
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # test dataset 80%-100% of ALL DATA
            self.images = self.images[int(0.8 * len(self.images)):int(len(self.images))]
            self.labels = self.labels[int(0.8 * len(self.labels)):int(len(self.labels))]
        # images[0]: D:\python pycharm learning\清华大佬课程\fisrt\pokemon\mewtwo\00000081.png
        # #labels[0]:2
        # images 还是图片的地址列表,需要__getitem__继续转换

    # image,label 不能把所有图片全部加载到内存,可能会爆内存
    def load_csv(self, filename):  # 生成,读取filename文件
        # filename 不存在:生成filename
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                # .../pokemen/mewtwo/00001.png 加载进images列表
                # 实际上是加载每张图片的地址
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))

            print(len(images), images[0])
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # .....\bulbasaur\00000000.png
                    name = img.split(os.sep)[-2]  # 指:bulbasaur 图片真实类别
                    label = self.name2label[name]  # 在name2label列表根据name找出对应的value:0,1...
                    # .....\bulbasaur\00000000.png , 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # filename 存在:直接读取filename
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                # '...pokemon\bulbasaur\00000000.png', 0
                img, label = row
                label = int(label)

                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):  # 对已经进行规范化处理的totensor,去除规范化
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # x_hat = (x-mean)/std
        # x = x_hat*std +mean
        # x:[c,h,w]
        # mean:[3]=>[3,1,1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)

        x = x_hat * std + mean

        return x

    def __getitem__(self, idx):
        pass
        # idx~[0~len(images)]
        # self.iamges,self.labels
        # images[0]: D:\python pycharm learning\清华大佬课程\fisrt\pokemon\mewtwo\00000081.png
        # #labels[0]:2
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string image => image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),  # 压缩到稍大
            transforms.RandomRotation(20),  # 图片旋转,增加图片的复杂度,但是又不会使网络太复杂
            transforms.CenterCrop(self.resize),  # 可能会有其他的底存在
            transforms.ToTensor(),
            #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
            # R mean:0.854,std:0.229
        ])
        img = tf(img)
        label = torch.tensor(label)
        # Pokemon类根据一个索引每次返回一个img(三位张量),一个label(0维张量)
        return img, label  # img,label打包成元组返回


def main():
    import visdom  # 启动 python -m visdom.server,http://localhost:8097
    import time
    viz = visdom.Visdom()
    db = DIYData_loader('D:\python pycharm learning\清华大佬课程\second\青光眼分类', 224, 'train')
    # x,y = next(iter(db))
    # print('sample:',x.shape,y.shape,y)
    # viz.images(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    # DataLoader加载器按batch_size打乱所有依次在内存当中按批次顺序加载每次批次,
    # 每个批次内含batch个Pokemon类返回的对象(元组,列表,字符串)
    loader = DataLoader(db, batch_size=20, shuffle=True)
    for x, y in loader:
        print('x_shape:',x.shape)

        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)

        break
if __name__ == '__main__':
    main()

 

 

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值