Pytorch实现猫狗大战(三)

定义网络模型: network.py

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 23 22:15:15 2019

@author: ZQQ

network.py:存放网络模型

"""

#from torch import nn
import torch.nn as nn
import torchvision.models as models # 导入预训练模型(训练好的)

class feature_net(nn.Module):
    def __init__(self,model,dim,n_classes):
        super(feature_net,self).__init__()
        if model == 'vgg19':
            vgg19 = models.vgg19(pretrained=True)
            self.feature = nn.Sequential(*list(vgg19.children())[:-1])
            self.feature.add_module('gloabl average',nn.AvgPool2d(9))
        elif model == 'inception_v3':
            inception_v3 = models.inception_v3(pretrained=True)
            self.feature = nn.Sequential(*list(inception_v3.children())[:-1])
            self.feature._modules.pop('13')
            self.feature.add_module('global average',nn.AvgPool2d(35))
        elif model == 'resnet152':
            resnet152 = models.resnet152(pretrained=True)
            self.feature = nn.Sequential(*list(resnet152.children())[:-1])
    
        self.classifier = nn.Sequential(nn.Linear(dim,4096),
                                   nn.ReLU(inplace=True),
                                   nn.Dropout(0.5),
                                   nn.Linear(4096,4096),
                                   nn.ReLU(inplace=True),
                                   nn.Dropout(0.5),
                                   nn.Linear(4096,n_classes)
                                   )
        
    def forward(self,x):
            x = self.feature(x)
            x = x.view(x.size(0),-1)
            x = self.classifier(x)
            return x
  
# 查看其中一个模型结构,训练好的模型需要下载      
#model = feature_net('vgg19',10,2)
#print(model)    
          

运行并测试,run.py

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 23 22:46:21 2019

@author: ZQQ

参考:https://blog.csdn.net/qq_36556893/article/details/88943162
"""

import torch
from torch.autograd import Variable
import torchvision
from torchvision import datasets,transforms,models
import matplotlib.pyplot as plt
import numpy as np
import os
import time

import argparse # 命令行参数解析模块

from tensorboardX import SummaryWriter

from network import feature_net

# 参数设置
parser = argparse.ArgumentParser(description='cifar10')
parser.add_argument('--pre_epoch',default=0,help='begin epoch')
parser.add_argument('--total_epoch',default=1,help='time for ergodic') 
parser.add_argument('--model',default='vgg19',help='model for training')
parser.add_argument('--outf',default='./model',help='folder to output images and checkpoints') # 输出结果保存路径
parser.add_argument('--pre_model',default=False,help='use pre_model') # 恢复训练时模型的路径
args = parser.parse_args()

# 定义使用模型
model = args.model

# 如果有gpu资源使用gpu,否则使用cpu
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# 图片导入
path = 'data'

# 图片预处理操作组合在一起
transform = transforms.Compose([transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])

data_image = {x:datasets.ImageFolder(root = os.path.join(path,x),transform = transform) for x in ["train","val"]}


data_loader_image = {x:torch.utils.data.DataLoader(dataset=data_image[x],
                                                   batch_size = 4,
                                                   shuffle = True) for x in ["train","val"]}

classes = data_image["train"].classes # 按文件夹名字分类
class_index = data_image["train"].class_to_idx # 文件夹类名所对应的链值
print(classes) # 打印类别
print(class_index)

# 打印训练集,验证集大小
print("train data set:",len(data_image["train"]))
print("val data set:",len(data_image["val"]))

image_train,label_train = next(iter(data_loader_image["train"]))
mean = [0.5,0.5,0.5]
std = [0.5,0.5,0.5]
img = torchvision.utils.make_grid(image_train) # 把batch_size张图片拼成一个图片
print(img.shape) # torch.Size([3,228,906])
img = img.numpy().transpose((1,2,0)) # 本来是(0,1,2),相当于把第一维变为第三维,其他两维前移
print(img.shape)
img = img*std + mean # (228,906,3)范围由(-1,1)变为(0,1)

print([classes[i] for i in label_train]) # 打印image_train图像中对应的label_train,也就是图像的类型
plt.imshow(img) # 显示数据归一化到0到1的图像
#plt.show()

# 构建网络
use_model = feature_net(model,dim=512,n_classes=2)
for parma in use_model.feature.parameters():
    parma.requires_grad = False
    
for index,parma in enumerate(use_model.classifier.parameters()):
    if index == 6:
        parma.requires_grad = True
        
if use_cuda:
    use_model = use_model.to(device)
    
# 定义损失函数(代价函数)
loss = torch.nn.CrossEntropyLoss()

# 定义优化器
optimizer = torch.optim.Adam(use_model.classifier.parameters())

print(use_model)

# 使用预训练模型
if args.pre_model:
    print("Resume from checkpoint...")
    assert os.path.isdir('checkpoint','Error:no checkpoint directory found')
    state = torch.load('./checkpoint/ckpt.t7')
    use_model.load_state_dict(state['state_dict'])
    best_test_acc = state['acc']
    pre_epoch = state['epoch']

else:
    # 定义最优化的测试准确率
    best_test_acc = 0
    pre_epoch = args.pre_epoch
    
if __name__ == '__main__':
    total_epoch = args.total_epoch
    writer = SummaryWriter(logdir='./log')
    print("Start Training,{}...".format(model))
    with open("acc.txt","w") as acc_f:
        with open("log.txt","w") as log_f:
            start_time = time.time()
            
            for epoch in range(pre_epoch,total_epoch):
                
                print("epoch{}/{}".format(epoch,total_epoch))
                print("_"*10)
                # 开始训练
                sum_loss = 0.0
                accuracy = 0.0
                total = 0
                for i, data in enumerate(data_loader_image["train"]):
                    image,label = data
                    if use_cuda:
                        image,label = Variable(image.to(device)),Variable(label.to(device))
                    else:
                        image,label = Variable(image),Variable(label)
                        
                    # 前向传播
                    label_prediction = use_model(image)
                    
                    _,prediction = torch.max(label_prediction.data,1)
                    total += label.size(0)
                    current_loss = loss(label_prediction,label)
                    # 后向传播
                    optimizer.zero_grad()
                    current_loss.backward()
                    optimizer.step()
                    
                    sum_loss += current_loss.item()
                    accuracy += torch.sum(prediction == label.data)
                    
                    if total % 5 ==0:
                        print("total {},train loss:{:.4f},train accuracy:{:.4f}".format(total,sum_loss/total,100*accuracy/total))
                        # 写入日志
                        log_f.write("total {},train loss:{:.4f},train accuracy:{:.4f}".format(total,sum_loss/total,100*accuracy/total))
                        log_f.write('\n')
                        log_f.flush()
                        
                # 写入tensorboard
                writer.add_scalar('loss/train',sum_loss / (i+1),epoch)
                writer.add_scalar('accuracy/train',100.*accuracy / total ,epoch)
                # 每一个epoch测试准确率
                print("waiting for test...")
                # 在上下文环境中切断梯度计算,在此模式下,每一步的计算结果中requires_grad都是False,即使input设置为requires_grad = True
                # 固定卷积层的参数,只更新全连接层的参数
                with torch.no_grad():
                    accuracy = 0
                    total = 0
                    for data in data_loader_image["val"]:
                        use_model.eval()
                        image,label = data
                        if use_cuda:
                            image,label = Variable(image.to(device)),Variable(label.to(device))
                        else:
                            image,label = Variable(image),Variable(label)
                            
                        label_prediction = use_model(image)
                        _,prediction = torch.max(label_prediction.data,1)
                        total += label.size(0)
                        accuracy += torch.sum(prediction==label.data)
                        
                    # 输出测试准确率
                    #print('测试准确率为: %.3f%%' % (100*accuracy / total))
                    #print('test accuracy: %.3f%%' % (100*accuracy / total)) # 服务器上无法识别中文
                    print("test accuracy:{:.4f}%".format(100*accuracy/total))
                    acc = 100.*accuracy / total
                    
                    # 写入tensorboard
                    writer.add_scalar('accuracy/test',acc,epoch)
                    
                    # 将测试结果写入文件
                    print('saing model...')
                    torch.save(use_model.state_dict(),'%s/net_%3d.pth' % (args.outf,epoch + 1))
                    acc_f.write("epoch = %03d,accuracy = %.3f%%" % (epoch + 1,acc))
                    acc_f.write('\n')
                    acc_f.flush()
                    
                    # 记录最佳的测试准确率
                    if acc > best_test_acc:
                        print('saving best model...')
                        # 存储状态
                        state = {'state_dict':use_model.state_dict(),
                                 'acc':acc,
                                 'epoch':epoch +1, }
                        
                        # 没有就创建checkpoint文件夹
                        if not os.path.isdir('checkpoint'):
                            os.mkdir('checkpoint')
                            
                            torch.save(state,'./checkpoint/ckpt.t7')
                            best_test_acc = acc
                            # 写入tensorboard
                            writer.add_scalar('best_accuracy/test',best_test_acc,epoch)
                            
    end_time = time.time - start_time
    print("training time is: {:.0f}m {:.0f}s ".format(end_time // 60,end_time % 60))
    writer.close()
                    
                    

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

机器不学习我学习

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值