PyTorch 分类任务训练模板

简介

想用 PyTorch 做分类任务的模型训练,于是找到一个通用模板,稍加调整并附上我的理解。

1. 数据准备

在这个阶段,传入batch_size, 传入训练样本的存储路径(image_path),数据储存格式如下:

Data
    ----class1
        -----image01.png
        -----image02.png
        ……
    ----class2
        -----image11.png
        -----image12.png
        ……
    ----class3
        -----image21.png
        -----image22.png
        ……

接下来就采用torch.utils.data.DataLoader将数据按照train 和 val 打包(这个函数的用法放在最后), 同时也使用了数据增强。

# 传入 batch_size
def train_val_data_process(batch_size: int, image_path: str):
    """Build training and validation DataLoaders from an ImageFolder layout.

    Args:
        batch_size: number of samples per mini-batch.
        image_path: dataset root containing "train" and "val" subdirectories,
            each organized as one folder per class (see the layout above).

    Returns:
        (train_dataloader, val_dataloader): shuffled training loader and
        deterministic validation loader.

    Side effects:
        Writes "class_indices.json" (index -> class name) to the working
        directory and prints dataset statistics.
    """
    # Standard ImageNet-style augmentation for training; deterministic
    # resize + center-crop (no augmentation) for validation.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # Fail fast if the dataset root is missing.
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx maps class name -> integer label assigned by ImageFolder.
    cl_list = train_dataset.class_to_idx
    num_classes = len(cl_list)
    print("Number of classes:", num_classes)

    # Invert the mapping (index -> class name) and persist it so inference
    # code can translate predicted indices back to class names.
    cla_dict = dict((val, key) for key, val in cl_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    # Worker count bounded by CPU cores, the batch size, and a cap of 8;
    # tune this to the capabilities of your own machine.
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    val_dataloader = torch.utils.data.DataLoader(validate_dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))
    return train_dataloader, val_dataloader

2. 模型训练

def train_model_process(model, train_dataloader, val_dataloader, num_epochs,
                        save_path="MobileNetV2_trim.pth"):
    """Train a classifier, track per-epoch metrics, and save the best weights.

    Args:
        model: an nn.Module producing class logits of shape (batch, num_classes).
        train_dataloader: DataLoader yielding (inputs, integer labels) batches.
        val_dataloader: DataLoader for validation, same batch format.
        num_epochs: number of full passes over the training data.
        save_path: destination file for the best validation-accuracy weights.
            Defaults to a relative path (the original hard-coded an absolute,
            machine-specific path with a stray trailing space in the name).

    Returns:
        pandas.DataFrame with columns "epoch", "train_loss_all",
        "val_loss_all", "train_acc_all", "val_acc_all" — one row per epoch.

    Side effects:
        Writes the best state_dict to `save_path` and prints per-epoch stats.
    """
    # Use the GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Optimizer and loss function for multi-class classification
    # (CrossEntropyLoss expects raw logits and integer class targets).
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    model = model.to(device)

    # Deep-copy the weights so later training steps cannot mutate the
    # snapshot we consider "best so far".
    best_model_wts = copy.deepcopy(model.state_dict())

    best_acc = 0.0
    # Per-epoch metric histories.
    train_loss_all = []
    val_loss_all = []
    train_acc_all = []
    val_acc_all = []

    # Wall-clock timer for the whole run.
    since = time.time()

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # Running sums for this epoch.
        train_loss = 0.0
        train_corrects = 0
        val_loss = 0.0
        val_corrects = 0
        train_num = 0  # samples seen in the training phase
        val_num = 0    # samples seen in the validation phase

        # ---- training phase ----
        # Set train mode once per epoch (the original reset it every batch).
        model.train()
        for step, (b_x, b_y) in enumerate(train_dataloader):
            b_x = b_x.to(device)
            b_y = b_y.to(device)

            output = model(b_x)                    # forward pass
            pre_lab = torch.argmax(output, dim=1)  # predicted class per sample
            loss = criterion(output, b_y)          # mean loss for this batch

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight the (mean) batch loss by batch size so the epoch figure
            # is a true per-sample average even with a ragged final batch.
            train_loss += loss.item() * b_x.size(0)
            train_corrects += torch.sum(pre_lab == b_y.data)
            train_num += b_x.size(0)

        # ---- validation phase ----
        model.eval()
        # no_grad avoids building autograd graphs during evaluation
        # (the original computed gradients here needlessly).
        with torch.no_grad():
            for step, (b_x, b_y) in enumerate(val_dataloader):
                b_x = b_x.to(device)
                b_y = b_y.to(device)

                output = model(b_x)
                pre_lab = torch.argmax(output, dim=1)
                loss = criterion(output, b_y)

                val_loss += loss.item() * b_x.size(0)
                val_corrects += torch.sum(pre_lab == b_y.data)
                val_num += b_x.size(0)

        # Record per-sample average loss and accuracy for this epoch.
        train_loss_all.append(train_loss / train_num)
        train_acc_all.append(train_corrects.double().item() / train_num)
        val_loss_all.append(val_loss / val_num)
        val_acc_all.append(val_corrects.double().item() / val_num)

        print("{} Train Loss:{:.4f} Train Acc:{:.4f}".format(epoch, train_loss_all[-1], train_acc_all[-1]))
        print("{} Val   Loss:{:.4f} Val   Acc:{:.4f}".format(epoch, val_loss_all[-1], val_acc_all[-1]))

        # Keep a snapshot of the weights with the best validation accuracy.
        if val_acc_all[-1] > best_acc:
            best_acc = val_acc_all[-1]
            best_model_wts = copy.deepcopy(model.state_dict())

        # Cumulative time since training started.
        time_use = time.time() - since
        print("train and validation time consumption:{:.0f}m{:.0f}s".format(time_use // 60, time_use % 60))

    # Persist the best weights (.pth is the conventional weights extension).
    torch.save(best_model_wts, save_path)

    # Collect the histories into a DataFrame for plotting/analysis.
    train_process = pd.DataFrame(data={"epoch": range(num_epochs),
                                       "train_loss_all": train_loss_all,
                                       "val_loss_all": val_loss_all,
                                       "train_acc_all": train_acc_all,
                                       "val_acc_all": val_acc_all, })
    return train_process
    

3. 训练过程及结果打印

def matplot_acc_loss(train_process):
    """Plot training/validation loss and accuracy curves side by side.

    Args:
        train_process: DataFrame returned by train_model_process, with
            columns "epoch", "train_loss_all", "val_loss_all",
            "train_acc_all", "val_acc_all".
    """
    plt.figure(figsize=(12, 4))

    # Left panel: loss curves.
    plt.subplot(1, 2, 1)
    plt.plot(train_process["epoch"], train_process.train_loss_all, "ro-", label="train_loss")
    plt.plot(train_process["epoch"], train_process.val_loss_all, "bs-", label="val_loss")
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel("loss")

    # Right panel: accuracy curves.
    plt.subplot(1, 2, 2)
    plt.plot(train_process["epoch"], train_process.train_acc_all, "ro-", label="train_acc")
    plt.plot(train_process["epoch"], train_process.val_acc_all, "bs-", label="val_acc")
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel("acc")

    # (Removed a redundant third plt.legend() call — each subplot already
    # has its legend attached above.)
    plt.show()

4. 调用

if __name__ == "__main__":
    # Build the model once with explicit hyperparameters (the original
    # constructed it twice and discarded the first instance).
    model = yourmodel(num_classes=1000, alpha=1.0, round_nearest=8)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Print the architecture as a sanity check before training.
    print(summary(model, (3, 224, 224)))

    # train_val_data_process requires batch_size and image_path — the
    # original called it with no arguments, which raises a TypeError.
    train_dataloader, val_dataloader = train_val_data_process(batch_size=32,
                                                              image_path="./Data")
    train_process = train_model_process(model, train_dataloader, val_dataloader, num_epochs=20)
    matplot_acc_loss(train_process)

补充

1) data loader

链接: torch.utils.data.DataLoader

# Minimal DataLoader usage example: wrap a Dataset and iterate it in
# mini-batches of 64 samples.
from torch.utils.data import DataLoader

# NOTE(review): `training_data` / `test_data` are assumed to be Dataset
# instances defined elsewhere — confirm against the caller. Shuffling the
# test loader is unusual for evaluation; kept as in the original snippet.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

2) 创建类别索引文件

# Build the train dataset and persist a human-readable class-index mapping.
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])
# class_to_idx maps class name -> integer label.
class_list = train_dataset.class_to_idx
# Invert to index -> class name for use at inference time.
# Bug fix: the original iterated an undefined name `flower_list`
# (a leftover from the flower-classification tutorial this template
# was adapted from); it must be the `class_list` defined above.
cla_dict = dict((val, key) for key, val in class_list.items())

# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

运行后会生成类似这样的类别索引文件(class_indices.json),保存了“索引 → 类别名”的映射,供推理阶段使用。

总结

万能模板,后面有需要再补充。

  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值