ResNet + Attention: a complete code walkthrough

Project name: CBAM.PyTorch-master

Source paper: CBAM: Convolutional Block Attention Module (ECCV 2018)

Project files covered:

train.py

import os
from collections import OrderedDict
import argparse
import torch
import torch.nn as nn 
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import transforms, models, datasets
import matplotlib.pyplot as plt
from data_loader.ImageNet_datasets import ImageNetData
import model.resnet_cbam as resnet_cbam
from model.Medical import CovNet
from trainer.trainer import Trainer
from utils.logger import Logger
from PIL import Image
from torchnet.meter import ClassErrorMeter
from tensorboardX import SummaryWriter
import torch.backends.cudnn as cudnn

import warnings
warnings.filterwarnings("ignore")
resize=224
def load_state_dict(model_dir, is_multi_gpu):
    state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)['state_dict']
    if is_multi_gpu:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]       # remove `module.`
            new_state_dict[name] = v
        return new_state_dict
    else:
        return state_dict

def main(args):

    if 0 == len(args.resume):
        logger = Logger('./logs/'+args.model+'.log')
    else:
        logger = Logger('./logs/'+args.model+'.log', True)


    logger.append(vars(args))

    if args.display:
        writer = SummaryWriter()
    else:
        writer = None
    gpus = args.gpu.split(',')
    data_transforms = {
        'train': transforms.Compose([
            # transforms.RandomResizedCrop(224),
            # transforms.RandomHorizontalFlip(),
            transforms.Resize((args.imagesize, args.imagesize)),
            # transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),

            # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((args.imagesize, args.imagesize)),
            transforms.ToTensor(),
            # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }
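
    # Editor's note: the Normalize transforms are commented out, so inputs stay in [0, 1].
    # The ImageNet-pretrained backbones selected below (e.g. resnet18 / resnet101) were trained
    # with mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225]; re-enabling Normalize is
    # usually advisable when starting from those pretrained weights.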

    train_datasets = datasets.ImageFolder(os.path.join(args.data_root, 't256'), data_transforms['train'])

    val_datasets   = datasets.ImageFolder(os.path.join(args.data_root, 'v256'), data_transforms['val'])
    train_dataloaders = torch.utils.data.DataLoader(train_datasets, batch_size=args.batch_size*len(gpus), shuffle=True, num_workers=4)

    val_dataloaders   = torch.utils.data.DataLoader(val_datasets, batch_size=16, shuffle=True, num_workers=4)
    unloader = transforms.ToPILImage()
    if args.debug:
        x, y =next(iter(train_dataloaders))
        # image = x[0].squeeze(0)  # remove the fake batch dimension
        # image = unloader(image)
        # image.save('example.jpg')
        #
        plt.text(2, -20, "labels:" + str(y.numpy()), fontsize=15)
        grid_img = torchvision.utils.make_grid(x, nrow=8)
        plt.imshow(grid_img.permute(1, 2, 0))
        plt.show()

        print("x.shape",x.shape)
        # print(y.shape)
        # print("y",y)
        # logger.append([x, y])

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    is_use_cuda = torch.cuda.is_available()
    cudnn.benchmark = True

    if  'resnet50' == args.model:
        my_model = models.resnet50(pretrained=False)
        my_model.fc = nn.Linear(2048, 5)
    elif 'resnet18' == args.model:
        my_model = models.resnet18(pretrained=True)
        my_model.fc = nn.Linear(512, 5)
    elif 'resnet50_cbam' == args.model:
        my_model = resnet_cbam.resnet50_cbam(pretrained=True)
        my_model.fc = nn.Linear(2048, 5)
    elif 'resnet101_cbam' == args.model:
        my_model = resnet_cbam.resnet101_cbam(pretrained=True)
        my_model.fc = nn.Linear(2048, 2)
        # my_model.sfT = nn.Sigmoid()
    elif 'resnet101' == args.model:
        my_model = models.resnet101(pretrained=True)
        my_model.fc = nn.Linear(2048, 2)
        # my_model.sfT = nn.Sigmoid()
    elif 'resnet152_cbam' == args.model:
        my_model = resnet_cbam.resnet152_cbam(pretrained=True)
        my_model.fc = nn.Linear(2048, 2)
        # my_model.sfT = nn.Sigmoid()
    elif 'resnet152' == args.model:
        my_model = models.resnet152(pretrained=True)
        my_model.fc = nn.Linear(2048, 2)
        # my_model.sfT = nn.Sigmoid()
    elif 'vgg19' == args.model:
        my_model = models.vgg19(pretrained=True)
        # torchvision's VGG has no .fc attribute; replace the final classifier layer instead
        my_model.classifier[6] = nn.Linear(4096, 5)
        # my_model.sfT = nn.Sigmoid()
    elif 'CovNet' == args.model.split('_')[0]:
        my_model=CovNet(5)


    else:
        raise ModuleNotFoundError

    #my_model.apply(fc_init)
    if is_use_cuda and 1 == len(gpus):
        my_model = my_model.cuda()
    elif is_use_cuda and 1 < len(gpus):
        my_model = nn.DataParallel(my_model.cuda())
    print(my_model)

    loss_fn = [nn.CrossEntropyLoss()]
    optimizer = optim.SGD(my_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    lr_schedule = lr_scheduler.MultiStepLR(optimizer, milestones=[20,40, 60], gamma=0.1)

    metric = [ClassErrorMeter([1,2], True)]

    start_epoch = 0
    num_epochs  = 50


    my_trainer = Trainer(my_model, args.model, loss_fn, optimizer, lr_schedule, 6, is_use_cuda, train_dataloaders, \
                        val_dataloaders, metric, start_epoch, num_epochs, args.debug, logger, writer)
    my_trainer.fit()
    # logger.append('Optimize Done!')


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='PyTorch Template')
    parser.add_argument('-r', '--resume', default='', type=str,
                        help='path to latest checkpoint (default: None)')
    parser.add_argument('--debug', action='store_true', default=True,dest='debug',
                        help='trainer debug flag')
    parser.add_argument('-g', '--gpu', default='0', type=str,
                        help='GPU ID Select')                    
    parser.add_argument('-d', '--data_root', default='./datasets',
                         type=str, help='data root')
    parser.add_argument('-t', '--train_file', default='./datasets/train.txt',
                         type=str, help='train file')
    parser.add_argument('-v', '--val_file', default='./datasets/val.txt',
                         type=str, help='validation file')
    parser.add_argument('-m', '--model', default='CovNet',
                         type=str, help='model type')
    parser.add_argument('--batch_size', default=32,
                         type=int, help='model train batch size')
    parser.add_argument('--display', action='store_true', dest='display',default=True,
                        help='Use TensorboardX to Display')
    parser.add_argument('--imagesize', default=224,
                        type=int, help='input image size')
    args = parser.parse_args()
    main(args)
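
One caveat about train.py: the load_state_dict() helper defined at the top is never called inside main(), so passing --resume only switches the Logger into append mode and does not restore any weights. A minimal sketch of how the helper could be wired in (the checkpoint path is hypothetical, following the file naming used by trainer.py below):

# sketch: restore weights when a resume checkpoint is given (to be placed inside main(),
# after the model has been built)
if len(args.resume) > 0:
    # e.g. args.resume = './checkpoint/resnet50_cbam/Models_epoch_10.ckpt'
    my_model.load_state_dict(load_state_dict(args.resume, is_multi_gpu=len(gpus) > 1))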

test.py

import os
from collections import OrderedDict
from PIL import Image
import torch
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torchvision import transforms, models
from model import *
# import pretrainedmodels
import numpy as np
import model.resnet_cbam as resnet_cbam
#DATA_ROOT = './datasets/xuelang_round1_test_a_20180709'
#DATA_ROOT = './datasets/xuelang_round1_test_b'
DATA_ROOT = './datasets/xuelang_round2_test_a_20180809'
RESULT_FILE = 'result.csv'

import warnings
warnings.filterwarnings("ignore")
def test_and_generate_result(epoch_num, model_name='resnet101', img_size=320, is_multi_gpu=False):
    data_transform = transforms.Compose([
        transforms.Resize(img_size, Image.ANTIALIAS),
        transforms.ToTensor(),
        transforms.Normalize([0.53744068, 0.51462684, 0.52646497], [0.06178288, 0.05989952, 0.0618901])
    ])

    os.environ['CUDA_VISIBLE_DEVICES'] = '4'
    is_use_cuda = torch.cuda.is_available()

    if  'resnet152' == model_name.split('_')[0]:
        model_ft = models.resnet152(pretrained=True)
        my_model = resnet152.MyResNet152(model_ft)
        del model_ft
    elif 'resnet50' == model_name.split('_')[0]:
        model_ft = models.resnet50(pretrained=True)
        my_model = resnet50.MyResNet50(model_ft)
        del model_ft
    elif 'resnet101' == model_name.split('_')[0]:
        model_ft = models.resnet101(pretrained=True)
        my_model = resnet101.MyResNet101(model_ft)
        del model_ft
    elif 'densenet121' == model_name.split('_')[0]:
        model_ft = models.densenet121(pretrained=True)
        my_model = densenet121.MyDenseNet121(model_ft)
        del model_ft
    elif 'densenet169' == model_name.split('_')[0]:
        model_ft = models.densenet169(pretrained=True)
        my_model = densenet169.MyDenseNet169(model_ft)
        del model_ft
    elif 'densenet201' == model_name.split('_')[0]:
        model_ft = models.densenet201(pretrained=True)
        my_model = densenet201.MyDenseNet201(model_ft)
        del model_ft
    elif 'densenet161' == model_name.split('_')[0]:
        model_ft = models.densenet161(pretrained=True)
        my_model = densenet161.MyDenseNet161(model_ft)
        del model_ft
    elif 'ranet' == model_name.split('_')[0]:
        my_model = ranet.ResidualAttentionModel_92()
    elif 'senet154' == model_name.split('_')[0]:
        model_ft = pretrainedmodels.models.senet154(num_classes=1000, pretrained='imagenet')
        my_model = MySENet154(model_ft)
        del model_ft
    else:
        raise ModuleNotFoundError

    state_dict = torch.load('./checkpoint/' + model_name + '/Models_epoch_' + epoch_num + '.ckpt', map_location=lambda storage, loc: storage.cuda())['state_dict']
    if is_multi_gpu:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]       # remove `module.`
            new_state_dict[name] = v
        my_model.load_state_dict(new_state_dict)
    else:
        my_model.load_state_dict(state_dict)

    if is_use_cuda:
        my_model = my_model.cuda()
    my_model.eval()

    with open(os.path.join('checkpoint', model_name, model_name+'_'+str(img_size)+'_'+RESULT_FILE), 'w', encoding='utf-8') as fd:
        fd.write('filename|defect,probability\n')
        test_files_list = os.listdir(DATA_ROOT)
        for _file in test_files_list:
            file_name = _file
            if '.jpg' not in file_name:
                continue
            file_path = os.path.join(DATA_ROOT, file_name)
            img_tensor = data_transform(Image.open(file_path).convert('RGB')).unsqueeze(0)
            if is_use_cuda:
                img_tensor = Variable(img_tensor.cuda(), volatile=True)
            output = F.softmax(my_model(img_tensor), dim=1)
            defect_prob = round(output.data[0, 1], 6)
            if defect_prob == 0.:
                defect_prob = 0.000001
            elif defect_prob == 1.:
                defect_prob = 0.999999
            target_str = '%s,%.6f\n' % (file_name, defect_prob)
            fd.write(target_str)
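
# Editor's note: this round-1 variant relies on wrapper classes (resnet152.MyResNet152,
# densenet121.MyDenseNet121, pretrainedmodels.senet154, ...) pulled in via `from model import *`;
# those modules are not included in this walkthrough, and only test_and_generate_result_round2
# is actually called in __main__ below.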

def test_and_generate_result_round2(epoch_num, model_name='resnet101', img_size=224, is_multi_gpu=False):
    data_transform = transforms.Compose([
        transforms.Resize((img_size,img_size),Image.ANTIALIAS),
        transforms.ToTensor(),
        # transforms.Normalize([0.53744068, 0.51462684, 0.52646497], [0.06178288, 0.05989952, 0.0618901])
    ])

    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    is_use_cuda = torch.cuda.is_available()

    print(epoch_num)
    print(model_name)
    print(img_size)
    print(is_multi_gpu)
    #
    # if  'resnet152' == model_name.split('_')[0]:
    #     model_ft = models.resnet152(pretrained=True)
    #     my_model = resnet152.MyResNet152(model_ft)
    #     del model_ft
    # elif 'resnet152-r2' == model_name.split('_')[0]:
    #     model_ft = models.resnet152(pretrained=True)
    #     my_model = resnet152.MyResNet152_Round2(model_ft)
    #     del model_ft
    # elif 'resnet152-r2-2o' == model_name.split('_')[0]:
    #     model_ft = models.resnet152(pretrained=True)
    #     my_model = resnet152.MyResNet152_Round2_2out(model_ft)
    #     del model_ft
    # elif 'resnet152-r2-2o-gmp' == model_name.split('_')[0]:
    #     model_ft = models.resnet152(pretrained=True)
    #     my_model = resnet152.MyResNet152_Round2_2out_GMP(model_ft)
    #     del model_ft
    # elif 'resnet152-r2-hm-r1' == model_name.split('_')[0]:
    #     model_ft = models.resnet152(pretrained=True)
    #     my_model = resnet152.MyResNet152_Round2_HM_round1(model_ft)
    #     del model_ft
    # elif 'resnet50' == model_name.split('_')[0]:
    #     model_ft = models.resnet50(pretrained=True)
    #     my_model = resnet50.MyResNet50(model_ft)
    #     del model_ft
    # elif 'resnet101' == model_name.split('_')[0]:
    #     model_ft = models.resnet101(pretrained=True)
    #     my_model = resnet101.MyResNet101(model_ft)
    #     del model_ft
    # elif 'densenet121' == model_name.split('_')[0]:
    #     model_ft = models.densenet121(pretrained=True)
    #     my_model = densenet121.MyDenseNet121(model_ft)
    #     del model_ft
    # elif 'densenet169' == model_name.split('_')[0]:
    #     model_ft = models.densenet169(pretrained=True)
    #     my_model = densenet169.MyDenseNet169(model_ft)
    #     del model_ft
    # elif 'densenet201' == model_name.split('_')[0]:
    #     model_ft = models.densenet201(pretrained=True)
    #     my_model = densenet201.MyDenseNet201(model_ft)
    #     del model_ft
    # elif 'densenet161' == model_name.split('_')[0]:
    #     model_ft = models.densenet161(pretrained=True)
    #     my_model = densenet161.MyDenseNet161(model_ft)
    #     del model_ft
    # elif 'ranet' == model_name.split('_')[0]:
    #     my_model = ranet.ResidualAttentionModel_92()
    # elif 'senet154' == model_name.split('_')[0]:
    #     model_ft = pretrainedmodels.models.senet154(num_classes=1000, pretrained='imagenet')
    #     my_model = MySENet154(model_ft)
    #     del model_ft
    # else:
    #     raise ModuleNotFoundError

    if  'resnet50' == model_name.split('_')[0]:
        my_model = models.resnet50(pretrained=False)
    elif 'resnet50-cbam' ==  model_name.split('_')[0]:
        my_model = resnet_cbam.resnet50_cbam(pretrained=False)
    elif 'resnet101' == model_name.split('_')[0]:
        my_model = models.resnet101(pretrained=True)
        my_model.fc = nn.Linear(2048, 2)
        # my_model.sfT = nn.Sigmoid()
    else:
        raise ModuleNotFoundError

    print('./checkpoint/' + model_name + '/Models_epoch_' + epoch_num + '.ckpt')
    state_dict = torch.load('./checkpoint/' + model_name + '/Models_epoch_' + epoch_num + '.ckpt', map_location=lambda storage, loc: storage.cuda())['state_dict']


    if is_multi_gpu:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]       # remove `module.`
            new_state_dict[name] = v
        my_model.load_state_dict(new_state_dict)
    else:
        my_model.load_state_dict(state_dict)

    if is_use_cuda:
        my_model = my_model.cuda()
    my_model.eval()


    with open(os.path.join('checkpoint', model_name, model_name+'_'+str(img_size)+'_'+RESULT_FILE), 'w', encoding='utf-8') as fd:
        print("566", is_multi_gpu)

        fd.write('filename|defect,probability\n')
        test_files_list = os.listdir(DATA_ROOT)
        print("566", test_files_list)
        ii=0
        for _file in test_files_list:
            # print("566")

            file_name = _file
            # if '.jpg' not in file_name:
            #     continue
            file_path = os.path.join(DATA_ROOT, file_name)
            print(ii)
            ii += 1
            # print("5667",file_path)
            img_tensor = data_transform(Image.open(file_path).convert('RGB')).unsqueeze(0)
            # print("5667",img_tensor)

            if is_use_cuda:
                img_tensor = Variable(img_tensor.cuda(), volatile=True)
            # _, output, _ = my_model(img_tensor)
            print( ":", img_tensor.shape)
            output = my_model(img_tensor)
            print( "2222222222222:", output.data)
            output = F.softmax(output, dim=1)

            print( "33333333333333:", output.data[0, 0])


            for k in range(2):
                # print(k,":",output.data)
                print("44444444444:", output.data[0, k])
                defect_prob =np.round(output.data[0, k].cpu().numpy(), 6)
                print("np.round:", defect_prob)
                if defect_prob == 0.:
                    defect_prob = 0.000001
                elif defect_prob == 1.:
                    defect_prob = 0.999999
                target_str = '%s,%.6f\n' % (file_name + '|' + ('norm' if 0 == k else 'defect_'+str(k)), defect_prob)
                print("target_str:",target_str)
                fd.write(target_str)

if __name__ == '__main__':
    #test_and_generate_result('10', 'resnet152_2018073100', 416, True)
    #test_and_generate_result('2', 'resnet50_2018072500', 416, True)
    #test_and_generate_result('7','resnet101_2018072600', 416, True)
    #test_and_generate_result_round2('14','resnet152-r2-2o-gmp_2018081600', 600, True)
    #test_and_generate_result_round2('14', 'resnet152-r2-2o_2018081300', 600, True)
    #test_and_generate_result('12', 'densenet161_new_stra', 352, True)
    #test_and_generate_result('25', 'ranet_2018072400', 416, True)
    #test_and_generate_result('8', 'senet154_2018072500', 416, True)
    # test_and_generate_result_round2('9','resnet152-r2-hm-r1_2018082000', 576, True)
    test_and_generate_result_round2('9','resnet101', 224, False)
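
A side note on the inference loops above: Variable(img_tensor.cuda(), volatile=True) uses the pre-0.4 autograd API; on current PyTorch, volatile is ignored (with a warning) and gradient tracking stays enabled. A minimal sketch of the modern equivalent for the body of the loop, under the same variable names:

with torch.no_grad():
    if is_use_cuda:
        img_tensor = img_tensor.cuda()
    output = F.softmax(my_model(img_tensor), dim=1)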

logger.py

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import traceback

class Logger(object):
    '''Write training progress to a log file.'''
    def __init__(self, fpath,resume=False): 
        self.file = None
        self.resume = resume

        if os.path.isfile(fpath):
            if resume:
                self.file = open(fpath, 'a') 
            else:
                self.file = open(fpath, 'w')
        else:
            self.file = open(fpath, 'w')



    def append(self, target_str):

        if not isinstance(target_str, str):

            try:
                target_str = str(target_str)
            except:
                traceback.print_exc()
            else:
                # print(self.file)
                # print(target_str)
                self.file.write(target_str + '\n')
                self.file.flush()
        else:


            self.file.write(target_str + '\n')
            self.file.flush()

    def close(self):
        if self.file is not None:
            self.file.close()
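
A short usage sketch of this Logger (the path and values are illustrative):

logger = Logger('./logs/resnet50_cbam.log')            # overwrites an existing file unless resume=True
logger.append({'lr': 0.01, 'model': 'resnet50_cbam'})  # non-string values are converted with str()
logger.append('Epoch 0/49')
logger.close()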

trainer.py

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import time
import sys
import os
import torchvision
import matplotlib.pyplot as plt

class Trainer():
    def __init__(self, model, model_type, loss_fn, optimizer, lr_schedule, log_batchs, is_use_cuda, train_data_loader, \
                valid_data_loader=None, metric=None, start_epoch=0, num_epochs=25, is_debug=False, logger=None, writer=None):
        self.model = model
        self.model_type = model_type
        self.loss_fn  = loss_fn
        self.optimizer = optimizer
        self.lr_schedule = lr_schedule
        self.log_batchs = log_batchs
        self.is_use_cuda = is_use_cuda
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.metric = metric
        self.start_epoch = start_epoch
        self.num_epochs = num_epochs
        self.is_debug = is_debug

        self.cur_epoch = start_epoch
        self.best_acc = 0.
        self.best_loss = sys.float_info.max
        self.logger = logger
        self.writer = writer
        self.global_step=0
    def fit(self):


        for epoch in range(0, self.start_epoch):
            self.lr_schedule.step()
        for epoch in range(self.start_epoch, self.num_epochs):

            self.logger.append('Epoch {}/{}'.format(epoch, self.num_epochs - 1))
            self.logger.append('-' * 60)
            self.cur_epoch = epoch

            # print(self.optimizer.state_dict()['param_groups'][0]['lr'])
            if self.is_debug:
                self._dump_infos()

            self._train()

            self.lr_schedule.step()
            self._valid()
            self._save_best_model()
        #     print()

    def _dump_infos(self):
        self.logger.append('---------------------Current Parameters---------------------')
        self.logger.append('is use GPU: ' + ('True' if self.is_use_cuda else 'False'))
        self.logger.append('lr: %f' % (self.lr_schedule.get_lr()[0]))
        self.logger.append('model_type: %s' % (self.model_type))
        self.logger.append('current epoch: %d' % (self.cur_epoch))
        self.logger.append('best accuracy: %f' % (self.best_acc))
        self.logger.append('best loss: %f' % (self.best_loss))
        self.logger.append('------------------------------------------------------------')

    def _train(self):
        self.model.train()  # Set model to training mode
        losses = []

        if self.metric is not None:
            # print("self.metric11",self.metric)
            # print("self.metric12",self.metric[0])
            self.metric[0].reset()
        print("self.train_data_loader.len()",len(self.train_data_loader))
        for i, (inputs, labels) in enumerate(self.train_data_loader):              # Notice



            self.writer.add_image("label:"+str(labels[0]), inputs[0], global_step=i, walltime=None, dataformats='CHW')
            if self.is_use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
                labels = labels.squeeze()

            else:
                labels = labels.squeeze()

            self.optimizer.zero_grad()
            outputs = self.model(inputs)            # Notice
            # print("outputs.shape",outputs.shape)
            # print("labels.shape",labels.shape)
            # print("labels",labels)
            # print("outputs :",outputs )
            # print("prob :",prob )
            # print("pass:",torch.argmax(outputs,1))


            # plt.text(2, -20, "labels:" + str(labels.cpu().numpy()), fontsize=15)
            # grid_img = torchvision.utils.make_grid(inputs.cpu(), nrow=8)
            # plt.imshow(grid_img.permute(1, 2, 0))
            # plt.title("TEST")
            # plt.show()

            loss = self.loss_fn[0](outputs, labels)
            if i%10==0:
                print("epoch:{},iter:{}, loss:{}".format(self.cur_epoch,i,loss.item()))

            if self.metric is not None:

                # print("outputsoutputs", outputs)
                prob     = F.softmax(outputs, dim=1).data.cpu()
                # print("probprobprobprob",prob)
                # print("probprobprobprob",labels)
                self.metric[0].add(prob, labels.data.cpu())

            loss.backward()
            self.optimizer.step()

            losses.append(loss.item())       # Notice
            # print("0 == i % self.log_batchs0 == i % self.log_batchs",0 == i % self.log_batchs)

            if 0 == i % self.log_batchs or (i == len(self.train_data_loader) - 1):

                local_time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))

                batch_mean_loss  = np.mean(losses)

                print_str = '[%s]\tTraining Batch[%d/%d]\t Class Loss: %.4f\t'           \
                            % (local_time_str, i, len(self.train_data_loader) - 1, batch_mean_loss)


                if i == len(self.train_data_loader) - 1 and self.metric is not None:
                    top1_acc_score = self.metric[0].value()[0]
                    top5_acc_score = self.metric[0].value()[1]
                    print_str += '@Top-1 Score: %.4f\t' % (top1_acc_score)
                    print_str += '@Top-5 Score: %.4f\t' % (top5_acc_score)

                self.logger.append(print_str)
                self.writer.add_scalar('loss/loss_c', batch_mean_loss, self.global_step)
                self.global_step+=1

    def _valid(self):
        self.model.eval()
        losses = []
        acc_rate = 0.
        if self.metric is not None:
            self.metric[0].reset()

        with torch.no_grad():              # Notice
            for i, (inputs, labels) in enumerate(self.valid_data_loader):
                if self.is_use_cuda:
                    inputs, labels = inputs.cuda(), labels.cuda()
                    labels = labels.squeeze()
                else:
                    labels = labels.squeeze()

                outputs = self.model(inputs)            # Notice 
                loss = self.loss_fn[0](outputs, labels)

                if self.metric is not None:
                    prob     = F.softmax(outputs, dim=1).data.cpu()
                    # print("abels :", labels)
                    # print("outputs :",outputs )
                    # print("prob :",prob )
                    # print("pass:",torch.argmax(prob,1))
                    self.metric[0].add(prob, labels.data.cpu())
                    # print("self.metric[0].value():",self.metric[0].value())
                losses.append(loss.item())
            
        local_time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        #self.logger.append(losses)
        batch_mean_loss = np.mean(losses)
        print_str = '[%s]\tValidation: \t Class Loss: %.4f\t'     \
                    % (local_time_str, batch_mean_loss)
        if self.metric is not None:
            top1_acc_score = self.metric[0].value()[0]
            top5_acc_score = self.metric[0].value()[1]
            print_str += '@Top-1 Score: %.4f\t' % (top1_acc_score)
            print_str += '@Top-5 Score: %.4f\t' % (top5_acc_score)
        self.logger.append(print_str)
        print("cur_epoch:",self.cur_epoch,"top1_acc_s:",top1_acc_score,"best_acc:",self.best_acc,"batch_mean_loss:",batch_mean_loss,"best_loss",self.best_loss)
        if top1_acc_score >= self.best_acc:
            self.best_acc = top1_acc_score
            self.best_loss = batch_mean_loss

    def _save_best_model(self):
        # Save Model
        self.logger.append('Saving Model...')
        state = {
            'state_dict': self.model.state_dict(),
            'best_acc': self.best_acc,
            'cur_epoch': self.cur_epoch,
            'num_epochs': self.num_epochs
        }
        if not os.path.isdir('./checkpoint/' + self.model_type):
            os.makedirs('./checkpoint/' + self.model_type)
        torch.save(state, './checkpoint/' + self.model_type + '/Models' + '_epoch_%d' % self.cur_epoch + '.ckpt')   # Notice
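
One detail worth noting: train.py constructs the metric as ClassErrorMeter([1, 2], True), so metric[0].value() returns top-1 and top-2 accuracy (in percent), even though the log strings above label the second value "Top-5 Score". A minimal sketch of the meter on its own:

from torchnet.meter import ClassErrorMeter

meter = ClassErrorMeter(topk=[1, 2], accuracy=True)
# meter.add(probabilities, targets) is called once per batch in _train() / _valid();
# meter.value() then returns [top-1 accuracy, top-2 accuracy] as percentages.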

model/


Medical.py

import torch,cv2
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torchvision
from torch.nn import functional as F


try:
    from skimage import data_dir
    from skimage import io
    from skimage import color
    from skimage import img_as_float,transform
    from skimage.transform import resize
except ImportError:
    raise ImportError("This example requires scikit-image")

class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        out=x.view(-1, shape)
        return out
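
# Editor's note: this Flatten collapses all non-batch dimensions, i.e. it is equivalent to
# x.view(x.size(0), -1), and could be replaced by torch.nn.Flatten() on recent PyTorch versions.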


class ConBlk(nn.Module):

    def __init__(self):
        super(ConBlk,self).__init__()
        self.conv1 = nn.Conv2d(3, 36, kernel_size=3, stride=2, padding=1)
        self.pool1=nn.MaxPool2d(2,2)
        self.bn1 = nn.BatchNorm2d(36)
        self.conv2 = nn.Conv2d(36, 36, kernel_size=3, stride=2, padding=1)
        self.pool2 = nn.MaxPool2d(2,2)
        self.bn2 = nn.BatchNorm2d(36)
        self.conv3 = nn.Conv2d(36, 36, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(2,2)
        # self.bn3 = nn.BatchNorm2d(36)
        self.fla=Flatten()

    def forward(self,x):
        x1=F.relu(self.bn1(self.pool1(self.conv1(x))))
        # print("x1.shape",x1.shape)
        x2=F.relu(self.bn2(self.pool2(self.conv2(x1))))
        # print("x2.shape", x2.shape)

        x3=F.relu(self.pool3(self.conv3(x2)))
        # print("x3.shape", x3.shape)
        out=self.fla(x3)
        return out

class CovNet(nn.Module):
    def __init__(self,num_class=2):
        super(CovNet, self).__init__()
        self.blk1=ConBlk()
        self.outlayer = nn.Sequential(
            nn.Linear(1764, 1024),
            nn.Dropout(0.5),
            nn.Linear(1024, num_class),
        )

    def forward(self, x):

        out=self.blk1(x)
        # print("out.shape:::",out.shape)
        out=self.outlayer(out)
        return out
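
For a 224x224 input (the default --imagesize), the feature map sizes work out as conv1(s=2) -> 112, pool1 -> 56, conv2(s=2) -> 28, pool2 -> 14, conv3(s=1) -> 14, pool3 -> 7, so the flattened vector has 36*7*7 = 1764 elements, matching nn.Linear(1764, 1024). A quick sanity check (a sketch; the batch size and class count are arbitrary):

import torch
net = CovNet(num_class=5)
x = torch.randn(2, 3, 224, 224)
print(net(x).shape)   # expected: torch.Size([2, 5])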

resnet_cbam.py

import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18_cbam', 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam',
           'resnet152_cbam']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)  # use the ratio argument instead of a hard-coded 16
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
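
# Editor's note: for an input of shape (N, C, H, W), ChannelAttention returns an (N, C, 1, 1)
# weight map and SpatialAttention an (N, 1, H, W) weight map; in the blocks below each is
# multiplied elementwise with the feature map and broadcast over the remaining dimensions,
# which is exactly CBAM's "channel first, then spatial" refinement.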

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

        self.ca = ChannelAttention(planes * 4)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet34_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet50_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet101_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model


def resnet152_cbam(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
        now_state_dict        = model.state_dict()
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict)
    return model
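
Putting the pieces together, a minimal sketch of building a CBAM ResNet-50 classifier the way train.py does (pretrained=True downloads the vanilla ResNet-50 weights, so the newly added ca/sa layers keep their random initialization; the 5-class head is an assumption matching train.py):

import torch
import torch.nn as nn
import model.resnet_cbam as resnet_cbam

model = resnet_cbam.resnet50_cbam(pretrained=True)  # needs network access for the weight download
model.fc = nn.Linear(2048, 5)                       # replace the 1000-class ImageNet head
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 5])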

 
