基于ResNet实现猫十二分类--附加UI界面(96%正确率)

 ------------HOOK团队出品

1.导入所需的python库

from collections import Counter
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision
from torch.utils.data import WeightedRandomSampler
from torchvision import transforms
from tqdm import tqdm

2.超参数的设定

# 超参
DEVICE = torch.device('cuda' if torch.cuda.is_available() else "cpu")
LR = 0.005
EPOCH = 50
BTACH_SIZE = 32
train_root ="you_train_path"
batch_size = 8

3.数据的加载和处理

# 数据加载及处理
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

# 图像读取转换
all_data = torchvision.datasets.ImageFolder(
    root=train_root,
    transform=train_transform
)

#print(all_data.class_to_idx)

# 计算每个类别的样本数量
class_counts = Counter(all_data.targets)  # 假设类别信息在 all_data.targets 中

# 计算每个类别的样本权重
weights = [1.0 / class_counts[class_idx] for class_idx in all_data.targets]

# 创建了一个WeightedRandomSampler对象。这个对象是PyTorch的数据加载工具中的一个功能,可以用于从带有权重的列表中进行采样
# replacement=True表示可以进行有放回的采样
sampler = WeightedRandomSampler(weights, len(all_data), replacement=True)

# 使用采样器进行数据集划分,从整体数据集all_data中根据采样器sampler的采样结果选取部分数据,得到训练集train_data。
train_data = torch.utils.data.Subset(all_data, list(sampler))

# 将采样器sampler转化为一个列表,其中的元素是采样得到的样本索引。
sampler_indices = list(sampler)

# 生成一个未被采样到的样本索引的列表valid_indices。
valid_indices = [idx for idx in range(len(all_data)) if idx not in sampler_indices]

# 使用未被采样的样本索引生成验证集valid_data。
valid_data = torch.utils.data.Subset(all_data, valid_indices)


# 训练数据集加载
train_set = torch.utils.data.DataLoader(
    train_data,
    batch_size=BTACH_SIZE,
    shuffle=True
)

# 测试集加载
test_set = torch.utils.data.DataLoader(
    valid_data,
    batch_size=BTACH_SIZE,
    shuffle=True
)

4.模型的训练和预测

# 训练
def train(model1, device, dataset, optimizer1, epoch1):
    global loss
    model1.train()

    correct = 0
    all_len = 0
    # 'tqdm'是一个用于显示进度条的库,它接受任何可迭代对象,并在遍历这个可迭代对象时显示一个进度条。
    for i, (x, y) in tqdm(enumerate(dataset)):
        x, y = x.to(device), y.to(device)
        optimizer1.zero_grad()
        output = model1(x)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(y.view_as(pred)).sum().item()
        all_len += len(x)
        loss = nn.CrossEntropyLoss()(output, y)
        loss.backward()
        optimizer1.step()


    print(f"第 {epoch1} 次训练的Train真实:{100.  * correct / all_len:.2f}%")

# 测试机验证
def vaild(model, device, dataset):
    model.eval()
    global loss
    correct = 0
    test_loss = 0
    all_len = 0
    with torch.no_grad():
        for i, (x, target) in enumerate(dataset):
            x, target = x.to(device), target.to(device)

            output = model(x)
            loss = nn.CrossEntropyLoss()(output, target)
            test_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_len += len(x)
    print(f"Test 真实:{100. * correct / all_len:.2f}%")
    return 100. * correct / all_len

5.ResNet50迁移学习

使用pretrain = True的方式得到预训练模型,更改全连接层的输出维度。

或者weights='ResNet50_Weights.DEFAULT'

两者的区别是:区别在于是否使用默认的预训练权重。resnet50(pretrained=True)使用的是ImageNet数据集上训练得到的默认权重,而weights='ResNet50_Weights.DEFAULT'则是使用在ImageNet数据集上训练得到的自定义权重。

model_1 = torchvision.models.resnet50(pretrained=True) #weights='ResNet50_Weights.DEFAULT'
model_1.fc = nn.Sequential(
    nn.Linear(2048, 12)
)

model_1.to(DEVICE)
optimizer = optim.SGD(model_1.parameters(), lr=LR, momentum=0.09)

模型的训练和保存

max_accuracy = 90.0  # 设定保存模型的阈值
best_model = None

for epoch in range(1, EPOCH + 1):
    train(model_1, DEVICE, train_set, optimizer, epoch)
    accu = vaild(model_1, DEVICE, test_set)
    if accu > max_accuracy:
        max_accuracy = accu
        best_model = model_1.state_dict()  # 或者使用 torch.save() 保存整个模型

# 保存最优模型
torch.save(best_model, r"E:\日常练习\pytorch_Project\best_model_train1.pth")

基于gradio做的UI界面

1.所需python库

# 导入必要的库
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import tempfile
from torch import nn

2.模型的加载

#加载模型
def load_pretrained_resnet():
    # 确定使用的设备: GPU (如果可用) 或 CPU
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    model_1 = torchvision.models.resnet50(pretrained=True)
    model_1.fc = nn.Sequential(nn.Linear(2048, 12))
    # model_1.load_state_dict(torch.load(r'E:\日常练习\pytorch_Project\best_model_train.pth'))
    model_1.to(device)
    model_1.load_state_dict(torch.load(r"C:\Users\Acer\Desktop\Cat-12\best_model_train99.71.pth"))
    # 确保模型处于评估模式
    model_1.eval()
    return model_1

3.模型预测结果

#模型预测结果
def load_imagenet_labels(filename=r"C:\Users\Acer\Desktop\Cat-12\cat12.txt"):
    """从给定文件名读取ImageNet的标签"""
    with open(filename, 'r',encoding="utf-8") as f:
        labels = [line.strip() for line in f.readlines()]
    return labels


def serve_chicken(image_path):
    """
    接收一个图像路径,用ResNet模型进行预测,然后返回预测结果。

    参数:
    - image_path: 输入图像的路径

    返回:
    - prediction: 预测的类别名
    """
    # 确定使用的设备: GPU (如果可用) 或 CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = load_pretrained_resnet()

    # (transforms)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # 打开图像文件
    image = Image.open(image_path).convert("RGB")
    image = preprocess(image).unsqueeze(0).to(device)  # 确保图像在同一设备上

    # 预测图像)
    with torch.no_grad():
        outputs = model(image)
        _, predicted = outputs.max(1)
        # 假设你有一个标签到名字的映射,这里只是一个示例
        labels = load_imagenet_labels()
        prediction = labels[predicted[0]]

    return prediction

4.gradio页面(简单版)

#gradio
def gradio_wrapper(image_array):
    """
    一个包装函数,接收Gradio传入的numpy图像,保存为临时文件,然后传给serve_chicken函数。
    """
    # 使用PIL将numpy数组保存为临时图片文件
    global img
    temp_filename = tempfile.mktemp(suffix=".jpg")
    if image_array is not None:
        img = Image.fromarray(image_array.astype('uint8'))
    else:
        print("Error: image_array is None!")

    img.save(temp_filename)
    # 调用原来的serve_chicken函数
    result = serve_chicken(temp_filename)

    return result

# 下面这段是简短的UI页面
def gradio_interface():
    """使用Gradio创建一个交互式界面并展示模型的预测结果。"""

    # 使用Gradio的新API来定义输入和输出组件
    image_input = gr.components.Image(shape=(256, 256), type="numpy")
    label_output = gr.components.Label(num_top_classes=3)

    # 创建Gradio界面
    interface = gr.Interface(fn=gradio_wrapper, inputs=image_input, outputs=label_output, live=True)

    # 启动Gradio界面
    interface.launch()

5.主函数

if __name__ == "__main__":
    resnet_model = load_pretrained_resnet()
    print("ResNet模型已加载!")
    result = serve_chicken(r"D:\桌面\cat12\cat_12_train\A.jpg")
    print(f"这个猫的种类是【{result}】!")
    gradio_interface()

6.UI界面展示

 7.优化好的UI页面

如果想要实现这个优化好的ui界面,12类猫的简介及其特征的获取,需要接入文心一言的token

获取方法:飞浆星河社区个人中心-->访问令牌(100万次免费使用机会)

#文心一言
def wenxin(question):
    import erniebot
    erniebot.api_type = 'aistudio'
    erniebot.access_token = "<you_token>"

    response = erniebot.ChatCompletion.create(
        model='ernie-bot',
        messages=[{'role': 'user', 'content': "请给我讲解{}的具体特征,越详细越好,开头词为亲爱的用户您查询的猫品种简介如下:".format(question)}],
    )
    return response.result

def teac_math():
    with gr.Blocks() as demo:
        image_result_list = get_img_list() #样例图展示
        state_image_list = gr.State(value=image_result_list) #样例图展示
        with gr.Row(equal_height=False):
            with gr.Column(variant='panel'):
                gr.Markdown('''"伯曼猫","俄罗斯蓝猫","埃及猫","孟买猫","孟加拉豹猫","布偶猫"<br/>"无毛猫","暹罗猫","波斯猫","缅因猫","英国短毛猫","阿比西亚猫"''')
                image_results = gr.Gallery(value=image_result_list, label='样例图', allow_preview=False,
                                           columns=6, height=250)
                image_input = gr.Image(label="传入需要预测的猫的图片")
                text_button_img = gr.Button("确定上传")
                text_output = gr.Textbox(label='上传的猫的图片的预测结果为')

            with gr.Column(variant='panel'):
                with gr.Box():
                    with gr.Row():
                        gr.Markdown("请输入需要查询猫品种的名字,获得相应的简介")
                    text_input = gr.Textbox(label="请输入要查询猫品种的名字")
                    text_button = gr.Button("确定查询")
                text_outputs = gr.Textbox(label="您查询的猫的品种的简介如下:",lines=10)
        #right
        image_results.select(get_selected_image, state_image_list, queue=False)   #样例图展示
        text_button_img.click(fn=gradio_wrapper, inputs=image_input, outputs=text_output)
        #left
        text_button.click(fn=wenxin, inputs=text_input ,outputs=text_outputs)
    return demo

8.优化后的主函数

if __name__ == "__main__":

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(
            "# <center> \N{fire} 基于ResNet50的猫十二分类 </center>")
        with gr.Tabs():
            with gr.TabItem('\N{clapper board} 猫十二分类预测'):
                teac_math()
    demo.launch()

9.优化后的页面效果展示

--------------HOOK团队出品 

ResNet是一种深度残差神经网络模型,适用于图像分类任务。而针对12分类任务,可以使用ResNet模型来进行训练与预测。 首先,需要收集一定数量的的图像数据集,包含12个不同类别的的图片。可以从互联网上搜集这些图片,确保每个类别都有足够数量的样本。 然后,将图像数据集分为训练集和测试集,通常可以按照70%的比例将数据分为训练集和30%的比例分为测试集。这样可以用训练集来训练模型,用测试集来评估模型的性能。 接下来,使用一个预训练的ResNet模型(如ResNet-50或ResNet-101)作为基础模型。该预训练模型已经在大量图像数据上进行了训练,因此可以提取出图像的特征。 然后,根据数据集的标签信息,调整预训练模型以适应12分类任务。可以使用全连接层或者其他分类器对模型的输出进行调整,确保可以对每个的类别进行准确分类。 在训练过程中,通过反向传播和梯度下降算法对模型的参数进行优化,以使得模型在训练集上的表现越来越好。 训练完成后,使用测试集对模型进行评估,可以计算模型在测试集上的分类准确率、精度、召回率等指标,来评估模型在12分类任务上的性能。 最终,当有新的的图片需要分类时,使用已经训练好的ResNet模型对这些图片进行预测,即可得到图片所属的的类别。 总之,通过使用ResNet模型进行训练和预测,可以有效地进行12分类任务,并且在合适的数据集和参数调整下,可以获得较好的分类效果。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值