PyTorch 训练图像分类网络(代码讲解)

整个工程文件已放到Github上
https://github.com/yaoyi30/PyTorch_Image_Classification

一、训练图像分类网络主要流程

  1. 构建数据集
  2. 数据预处理、包括数据增强和数据标准化和归一化
  3. 构建网络模型
  4. 设置学习率、优化器、损失函数等超参数
  5. 训练和验证

二、各个流程简要说明

1. 构建数据集

本文使用kaggle上的10种猴子分类数据集,网址为https://www.kaggle.com/datasets/slothkong/10-monkey-species
在这里插入图片描述
在工程目录下,新建datasets文件夹,在文件夹内分别新建train和val文件夹,用来放训练和验证数据,train和val文件夹下分别放有十种猴子图像数据,分别以该类别的名称命名,结构如下:

datasets/
  train/   # train images
     n0/
        img1.jpg
        img2.jpg
         .
         .
         .
     n1/
        .
        .
        .
  val/     # val images
     n0/
        img1.jpg
        img2.jpg
         .
         .
         .
     n1/
        .
        .
        .

2. 数据预处理

将图像resize到统一大小,之后转为tensor格式再进行标准化,预处理之后的图片可以正常输入网络,对于训练集可以采取一些数据增强手段来增强网络的泛化能力,验证集不做数据增强。

    #训练数据预处理、数据增强设置
    train_transform =  transforms.Compose([
                                       transforms.Resize(args.input_size), #图像resize到统一大小
                                       transforms.RandomHorizontalFlip(), #数据增强,水平翻转
                                       transforms.ToTensor(), #转为tensor格式,值变为0-1之间
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) #标准化
                                    ])
    #验证数据预处理
    val_transform =  transforms.Compose([
                                       transforms.Resize(args.input_size), #图像resize到统一大小
                                       transforms.ToTensor(), #转为tensor格式,值变为0-1之间
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) #标准化
                                    ])

3. 构建网络模型

本文搭建了一个三层卷积神经网路,命名为Simplify_Net。

    model = Simplify_Net(args.nb_classes)

4. 设置学习率、优化器、损失函数等超参数

    #定义损失函数,选用交叉熵损失函数
    loss_function = nn.CrossEntropyLoss()
    #定义优化器(初始学习率和权重衰减值)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    #定义学习率类型,此处选用余弦退火学习率,设置最大值
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.max_lr, total_steps=args.epochs, verbose=True)

5. 训练和验证

	#训练和验证模型,具体函写在了utils.py文件中
    history = train_and_val(args.epochs, model, train_loader, len_train,val_loader, len_val,loss_function, optimizer,scheduler,args.output_dir,device)   

三、工程代码文件详细讲解

train.py

定义训练的入口函数,以及训练所需要的流程

1. 导入相应的库和文件

import os
import torch
from torchvision import transforms, datasets
import torch.nn as nn
from models.Simplify_Net import Simplify_Net
from utils import train_and_val,plot_acc,plot_loss,plot_lr
import argparse
import numpy as np

2. 训练参数设置

def get_args_parser():
    parser = argparse.ArgumentParser('Image Classification Train', add_help=False)
    #批次大小设置
    parser.add_argument('--batch_size', default=32, type=int,help='Batch size for training')
    #训练轮数设置
    parser.add_argument('--epochs', default=100, type=int)
    #网络输入图像大小设置
    parser.add_argument('--input_size', default=[224,224],nargs='+',type=int,help='images input size')
    #数据集路径设置
    parser.add_argument('--data_path', default='./datasets/', type=str,help='dataset path')
    #初始学习率大小设置(采用余弦退火学习率)
    parser.add_argument('--init_lr', default=1e-5, type=float,help='intial lr')
    #最大学习率大小设置(采用余弦退火学习率)
    parser.add_argument('--max_lr', default=1e-3, type=float,help='max lr')
    #权重衰减值设置(是一个正则化技术,作用是抑制模型的过拟合,以此来提高模型的泛化性)
    parser.add_argument('--weight_decay', default=1e-5, type=float,help='weight decay')
    #类别设置
    parser.add_argument('--nb_classes', default=10, type=int,help='number of the classification types')
    #模型保存路径设置
    parser.add_argument('--output_dir', default='./output_dir',help='path where to save, empty for no saving')
    #训练设备设置(gpu或者cpu)
    parser.add_argument('--device', default='cuda',help='device to use for training / testing')
    #加载数据子进程的数量
    parser.add_argument('--num_workers', default=4, type=int)

    return parser

3. 定义主函数

def main(args):
    #使用cpu或者gpu训练
    device = torch.device(args.device)
    #创建模型保存文件夹
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    #训练数据预处理、数据增强设置
    train_transform =  transforms.Compose([
                                       transforms.Resize(args.input_size),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                    ])
    #验证数据预处理
    val_transform =  transforms.Compose([
                                       transforms.Resize(args.input_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                    ])
    #根据文件夹读取训练数据
    train_dataset = datasets.ImageFolder(os.path.join(args.data_path,'train'), transform=train_transform)
    #加载训练集(图像顺序打乱)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.num_workers)
    #训练集图像数量
    len_train = len(train_dataset)
    #根据文件夹读取验证数据
    val_dataset = datasets.ImageFolder(os.path.join(args.data_path,'val'), transform=val_transform)
    #加载验证集(图像顺序不打乱)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.num_workers)
    #验证集图像数量
    len_val = len(val_dataset)
    #定义分类网络,输入类别数
    model = Simplify_Net(args.nb_classes)
    #定义损失函数
    loss_function = nn.CrossEntropyLoss()
    #定义优化器(初始学习率和权重衰减值)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
    #定义学习率类型,此处选用余弦退火学习率,设置最大值
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, args.max_lr, total_steps=args.epochs, verbose=True)
    #训练和验证模型,具体函写在了utils.py文件中
    history = train_and_val(args.epochs, model, train_loader, len_train,val_loader, len_val,loss_function, optimizer,scheduler,args.output_dir,device)
    #打印损失值曲线,具体函写在了utils.py文件中
    plot_loss(np.arange(0,args.epochs),args.output_dir, history)
    #打印准确率曲线,具体函写在了utils.py文件中
    plot_acc(np.arange(0,args.epochs),args.output_dir, history)
    #打印学习率曲线,具体函写在了utils.py文件中
    plot_lr(np.arange(0,args.epochs),args.output_dir, history)

4. 开始执行

if __name__ == '__main__':
    #获取训练参数
    args = get_args_parser()
    #解析训练参数
    args = args.parse_args()
    #训练参数传入主函数
    main(args)

运行train.py,训练时打印的信息,包括每一轮的学习率,训练集和验证集指标,运行时间等
在这里插入图片描述

Simplify_Net.py

定义网络结构,本文定义一个简单的三层卷积神经网络

import torch
import torch.nn as nn

class Simplify_Net(nn.Module):
    def __init__(self, num_classes=2):
        super(Simplify_Net, self).__init__()
        #卷积层1
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=16,kernel_size=3,stride=2)
        #批归一化层1
        self.bn1 = nn.BatchNorm2d(16)
        #激活函数层1
        self.relu1 = nn.ReLU(inplace=True)
        #最大池化层1
        self.maxpool1 = nn.MaxPool2d(kernel_size=2,stride=2)

        #卷积层2
        self.conv2 = nn.Conv2d(in_channels=16,out_channels=16,kernel_size=3,stride=2)
        #批归一化层2
        self.bn2 = nn.BatchNorm2d(16)
        #激活函数层2
        self.relu2 = nn.ReLU(inplace=True)
        #最大池化层  
        self.maxpool2 = nn.MaxPool2d(kernel_size=2,stride=2)

        #卷积层3
        self.conv3 = nn.Conv2d(in_channels=16,out_channels=16,kernel_size=3,stride=2)
        #批归一化层3
        self.bn3 = nn.BatchNorm2d(16)
        #激活函数层3
        self.relu3 = nn.ReLU(inplace=True)
 
        #全局平均池化层
        self.avgpool = nn.AvgPool2d(6)
        #全连接层
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):

        x = self.maxpool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.maxpool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.relu3(self.bn3(self.conv3(x)))

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

utils.py

定义具体的训练、验证以及绘制指标曲线的函数

1. 导入相应的库和文件

import os
import torch
import time
from tqdm import tqdm
import matplotlib.pyplot as plt

2. 训练验证函数

def train_and_val(epochs, model, train_loader, len_train,val_loader, len_val,criterion, optimizer,scheduler,output_dir,device):
    #定义训练集损失值列表
    train_loss = []
    #定义验证集损失值列表
    val_loss = []
    #定义训练集准确率列表
    train_acc = []
    #定义验证集准确率列表
    val_acc = []
    #定义学习率列表
    learning_rate = []
    #定义验证集最佳准确率变量
    best_acc = 0
    #将模型加载到设备中(cpu or gpu)
    model.to(device)
    #开始计时,主要记录整个训练过程
    fit_time = time.time()
    #开始训练
    for e in range(epochs):
        #内存释放
        torch.cuda.empty_cache()
        #开始计时,主要记录每一轮训练的时间
        since = time.time()
        training_loss = 0
        training_acc = 0
        #把模型调整成为训练模式
        model.train()
        with tqdm(total=len(train_loader)) as pbar:
            #遍历训练数据
            for image, label in train_loader:
                #将训练数据中的图像以及标签加载到设备中(cpu or gpu),设备类型必须和模型一致
                image = image.to(device)
                label = label.to(device)
                #模型推理                
                output = model(image)
                #损失值计算  
                loss = criterion(output, label)
                #获取预测结果  
                _,predicted = torch.max(output, dim=1)
                #进行反向传播,更新模型参数
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #计算每一个batch的准确率和损失值并相加,用来计算一整轮的准确率和损失值
                training_loss += loss.item()
                training_acc += torch.eq(predicted, label).sum().item()
                pbar.update(1)
        #把模型调整成为验证模式
        model.eval()
        validation_loss = 0
        validation_acc = 0

        with torch.no_grad():
            with tqdm(total=len(val_loader)) as pb:
                #同上
                for image, label in val_loader:
                    image = image.to(device)
                    label = label.to(device)
                    output = model(image)

                    # loss
                    loss = criterion(output, label)
                    _, predicted = torch.max(output, dim=1)

                    validation_loss += loss.item()
                    validation_acc += torch.eq(predicted, label).sum().item()
                    pb.update(1)
        #列表中加入每一轮的损失值
        train_loss.append(training_loss / len(train_loader))
        val_loss.append(validation_loss / len(val_loader))
        #列表中加入每一轮的准确率
        train_acc.append(training_acc / len_train)
        val_acc.append(validation_acc / len_val)
        #列表中加入每一轮的学习率
        learning_rate.append(scheduler.get_last_lr())
        #保存两个模型,一种是最新的模型,一种是指标最好的模型,通过验证集准确率来判断
        torch.save(model.state_dict(), os.path.join(output_dir,'last.pth'))
        if best_acc <(validation_acc / len_val):
            torch.save(model.state_dict(), os.path.join(output_dir,'best.pth'))

        #打印每一轮的指标
        print("Epoch:{}/{}..".format(e + 1, epochs),
              "Train Acc: {:.3f}..".format(training_acc / len_train),
              "Val Acc: {:.3f}..".format(validation_acc / len_val),
              "Train Loss: {:.3f}..".format(training_loss / len(train_loader)),
              "Val Loss: {:.3f}..".format(validation_loss / len(val_loader)),
              "Time: {:.2f}s".format((time.time() - since)))
        #每一轮训练完毕更新学习率的值
        scheduler.step()
    #返回记录的所有参数、指标列表
    history = {'train_loss': train_loss, 'val_loss': val_loss ,'train_acc': train_acc, 'val_acc': val_acc,'lr':learning_rate}
    #整个训练过程结束时记录此刻时间,并计算用时
    print('Total time: {:.2f} m'.format((time.time() - fit_time) / 60))

    return history

3. 打印损失值曲线

def plot_loss(x,output_dir, history):
    plt.plot(x, history['val_loss'], label='val', marker='o')
    plt.plot(x, history['train_loss'], label='train', marker='o')
    plt.title('Loss per epoch')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.savefig(os.path.join(output_dir,'loss.png'))
    plt.clf()

在这里插入图片描述

4. 打印准确率曲线

def plot_acc(x,output_dir, history):
    plt.plot(x, history['train_acc'], label='train_acc', marker='x')
    plt.plot(x, history['val_acc'], label='val_acc', marker='x')
    plt.title('Acc per epoch')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.savefig(os.path.join(output_dir,'acc.png'))
    plt.clf()

网络结构较为简单,因此准确率不是特别的高
在这里插入图片描述

5. 打印学习率曲线

def plot_lr(x,output_dir,  history):
    plt.plot(x, history['lr'], label='learning_rate', marker='x')
    plt.title('learning rate per epoch')
    plt.ylabel('Learning_rate')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.savefig(os.path.join(output_dir,'learning_rate.png'))
    plt.clf()

从学习率曲线可以看出,约前30轮为warmup阶段,最大学习率为0.001
在这里插入图片描述

predict.py

进行单张图片预测

1. 导入相应的库和文件

import argparse
import torch
import torch.nn as nn
import torchvision.transforms as T
from models.Simplify_Net import Simplify_Net
from PIL import Image

2. 单张预测参数设置

def get_args_parser():
    parser = argparse.ArgumentParser('Predict Image', add_help=False)
    #需要预测的图像路径
    parser.add_argument('--image_path', default='./n6040.jpg', type=str, metavar='MODEL',help='Name of model to train')
    #输入图像大小,与训练一致
    parser.add_argument('--input_size', default=[224,224],nargs='+',type=int,help='images input size')
    #选择训练好的模型
    parser.add_argument('--weights', default='./output_dir/last.pth', type=str,help='dataset path')
    #类别
    parser.add_argument('--nb_classes', default=10, type=int,help='number of the classification types')
    #运行设备
    parser.add_argument('--device', default='cuda',help='device to use for training / testing')

    return parser

3. 定义主函数

def main(args):
    #使用cpu或者gpu预测
    device = torch.device(args.device)
    #读取图像
    image = Image.open(args.image_path).convert('RGB')
    #图像预处理
    transforms = T.Compose([
        T.Resize(args.input_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])
    #类别名称
    labels_name = ['n0','n1','n2','n3','n4','n5','n6','n7','n8','n9']
    #定义网络
    model = Simplify_Net(args.nb_classes)
    #加载权重
    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)
    #将网络及其权重加载到指定设备上(cpu or gpu)
    model.to(device)
    #设置为验证模式
    model.eval()
    #定义归一化指数函数,将概率值变为0-1之间
    act = nn.Softmax(dim=1)
    #定将图像处理为网络输入需要的的tensor
    input_tensor = transforms(image).unsqueeze(0).to(device)
    with torch.no_grad():
        #模型推理
        outputs = act(model(input_tensor))
        #获取预测结果以及概率值
        _, predicted = torch.max(outputs, 1)
        predicted = predicted.cpu().numpy()[0]
        print('name is: ' + labels_name[predicted])
        print('prob is: ' + str(outputs.cpu().numpy()[0][predicted]))

4. 开始执行

if __name__ == '__main__':
    #获取训练参数
    args = get_args_parser()
    #解析训练参数
    args = args.parse_args()
    #训练参数传入主函数
    main(args)

运行predict.py,打印模型预测的结果
在这里插入图片描述

eval.py

进行模型整体指标评价

1. 导入相应的库和文件

import argparse
from sklearn.metrics import confusion_matrix, classification_report,accuracy_score
from torchvision import transforms, datasets
import torch
import os
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
from models.Simplify_Net import Simplify_Net
import matplotlib.pyplot as plt
import seaborn as sns

2. 模型评价参数设置

def get_args_parser():
    parser = argparse.ArgumentParser('Eval Model', add_help=False)
    #批次大小设置
    parser.add_argument('--batch_size', default=8, type=int,help='Batch size for training')
    #输入图像大小
    parser.add_argument('--input_size', default=[224,224],nargs='+',type=int,help='images input size')
    #验证集文件夹所在地址
    parser.add_argument('--data_path', default='./datasets/', type=str,help='dataset path')
    #选择训练好的模型
    parser.add_argument('--weights', default='./output_dir/best.pth', type=str,help='dataset path')
    #类别
    parser.add_argument('--nb_classes', default=10, type=int,help='number of the classification types')
    #运行设备
    parser.add_argument('--device', default='cuda',help='device to use for training / testing')
    #加载数据子进程的数量
    parser.add_argument('--num_workers', default=4, type=int)

    return parser

3. 定义主函数

def main(args):
    #同上
    device = torch.device(args.device)

    val_transform =  transforms.Compose([
                                       transforms.Resize(args.input_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                    ])

    val_dataset = datasets.ImageFolder(os.path.join(args.data_path,'val'), transform=val_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.num_workers)

    model = Simplify_Net(args.nb_classes)

    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)

    model.to(device)
    model.eval()

    classes = val_dataset.classes

    act = nn.Softmax(dim=1)

    y_true, y_pred = [], []
    with torch.no_grad():
        with tqdm(total=len(val_loader)) as pbar:
            for images, labels in val_loader:
                outputs = act(model(images.to(device)))
                _, predicted = torch.max(outputs, 1)
                predicted = predicted.cpu()
                y_pred.extend(predicted.numpy())
                y_true.extend(labels.cpu().numpy())
                pbar.update(1)
    #计算总体准确率
    ac = accuracy_score(y_true, y_pred)
    #计算每一类的准确率、召回率以及F1值
    cr = classification_report(y_true, y_pred, target_names=classes, output_dict=True)
    #将结果保存在csv文件中
    df = pd.DataFrame(cr).transpose()
    df.to_csv("result.csv", index=True)
    print("Accuracy is :", ac)
    #生成混淆矩阵并可视化
    cm = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, xticklabels=classes, yticklabels=classes, cmap='Blues', fmt="d")

    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig('confusion_matrix.png')
    plt.clf()

4. 开始执行

if __name__ == '__main__':
    #获取训练参数
    args = get_args_parser()
    #解析训练参数
    args = args.parse_args()
    #训练参数传入主函数
    main(args)

运行eval.py,打印模型在验证集上的准确率,同时会生成混淆矩阵以及保存每个类别准召率、F1值的csv文件
在这里插入图片描述
通过将模型的预测结果与真实标签进行比较,可以得出混淆矩阵(Confusion Matrix),以帮助我们了解模型在不同类别上的分类情况。
在这里插入图片描述生成的result.csv文件,里面详细记录了每一类的准确率、召回率以及F1值
在这里插入图片描述

export_onnx.py

将训练好的模型转onnx格式,以进行后续应用

1. 导入相应的库和文件

import torch
from models.Simplify_Net import Simplify_Net
import argparse

2. 转onnx模型参数设置

def get_args_parser():
    parser = argparse.ArgumentParser('Export Onnx', add_help=False)
    #输入图像大小
    parser.add_argument('--input_size', default=[224,224],nargs='+',type=int,help='images input size')
    #选择训练好的模型
    parser.add_argument('--weights', default='./output_dir/best.pth', type=str,help='dataset path')
    #类别
    parser.add_argument('--nb_classes', default=10, type=int,help='number of the classification types')

    return parser

3. 定义主函数

def main(args):
    #定义一个输入tensor
    x = torch.randn(1, 3, args.input_size[0],args.input_size[1])
    #定义输入名字
    input_names = ["input"]
    #定义输出名字
    out_names = ["output"]
    #定义网络
    model = Simplify_Net(args.nb_classes)
    #加载权重
    checkpoint = torch.load(args.weights, map_location='cpu')
    msg = model.load_state_dict(checkpoint, strict=True)
    print(msg)
    #将模型设置为验证模式
    model.eval()
    #转onnx模型
    torch.onnx.export(model, x, args.weights.replace('pth','onnx'), export_params=True, training=False, input_names=input_names, output_names=out_names)
    print('please run: python -m onnxsim test.onnx test_sim.onnx\n')

4. 开始执行

if __name__ == '__main__':
    #获取训练参数
    args = get_args_parser()
    #解析训练参数
    args = args.parse_args()
    #训练参数传入主函数
    main(args)

运行export_onnx.py,之后进行模型的简化
在这里插入图片描述
简化之前(左)和之后(右)的onnx模型结构对比
在这里插入图片描述

  • 15
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
srcnn超分辨率pytorch代码是用于实现图像超分辨率(Super Resolution)的一种深度学习模型。下面我将逐行讲解这个代码。 首先,代码导入了需要的库和模块,包括torch、torchvision等,以及一些辅助函数。 接下来,定义了一个名为SRCNN的类。这个类继承自nn.Module类,用来构建SRCNN模型。在这个类的构造函数中,首先调用父类的构造函数初始化模型;然后定义了三个卷积层,分别是nn.Conv2d,并且设置输入通道数、输出通道数、卷积核大小和步长;接着定义了ReLU激活函数;最后定义了一个反卷积层nn.ConvTranspose2d,用于得到最终的超分辨率图像。 在类的前面还定义了两个辅助函数,即adjust_scale和normalize,分别用于将图像缩放到指定尺寸和对图像进行归一化处理。 接下来,定义了一个名为train的函数,该函数用于训练模型。在函数中,首先根据指定的超参数设置模型的训练参数,如学习率、损失函数、优化器等;然后加载训练数据集和验证数据集,采用DataLoader进行批量加载和预处理;随后,利用模型进行迭代训练,通过计算输出图像与标签图像之间的损失来更新模型参数;最后将训练得到的模型保存到指定路径。 最后,定义了一个名为test的函数,用于测试模型。在函数中,首先加载测试图像,并通过模型进行超分辨率处理;然后将超分辨率图像与原始图像进行比较,计算并打印出PSNR指标,评估超分辨率效果。 总结一下,这个SRCNN的pytorch代码包括了模型的构建、训练和测试三个主要部分,通过迭代训练和测试来实现图像的超分辨率。通过调整超参数、数据集和函数的调用,可以适应不同的超分辨率任务。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

姚先生97

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值