------------HOOK团队出品
1.导入所需的python库
from collections import Counter
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision
from torch.utils.data import WeightedRandomSampler
from torchvision import transforms
from tqdm import tqdm
2.超参数的设定
# Hyperparameters and run configuration.
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else "cpu")
LR = 0.005  # learning rate passed to the SGD optimizer
EPOCH = 50  # number of train/validate rounds
BATCH_SIZE = 32  # mini-batch size used by the data loaders
# Original (misspelled) name, kept as an alias because the data loaders
# reference it.
BTACH_SIZE = BATCH_SIZE
train_root = "you_train_path"  # placeholder: root folder of the image dataset
# NOTE(review): nothing in the visible code reads `batch_size`; the loaders use
# BTACH_SIZE (32). Kept only so existing references keep working.
batch_size = 8
3.数据的加载和处理
# Data loading and preprocessing.
# Training-time augmentation: random crop and flip plus brightness and contrast
# jitter, then conversion to a tensor normalized with mean 0.5 / std 0.5.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0),
    torchvision.transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0, hue=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

# Read the images from disk and apply the transform on access.
# NOTE(review): the validation subset below shares this dataset object, so
# validation images also go through the random training augmentations.
all_data = torchvision.datasets.ImageFolder(
    root=train_root,
    transform=train_transform
)
# Debug aid: print(all_data.class_to_idx) shows the class-name -> index mapping.

# Number of samples in each class.
class_counts = Counter(all_data.targets)
# Per-sample weight = 1 / (size of the sample's class), so samples from small
# classes are drawn more often than samples from large ones.
weights = [1.0 / class_counts[class_idx] for class_idx in all_data.targets]
# Draws len(all_data) indices, with replacement, according to those weights.
sampler = WeightedRandomSampler(weights, len(all_data), replacement=True)

# Iterating the sampler produces a NEW random draw every time, so it is
# materialized exactly once. Building the training subset and the "not sampled"
# list from two separate draws would let training images leak into validation.
sampler_indices = list(sampler)
train_data = torch.utils.data.Subset(all_data, sampler_indices)

# Every index that was never drawn forms the validation set. A set keeps the
# membership test O(1) instead of scanning the whole index list per element.
_sampled_index_set = set(sampler_indices)
valid_indices = [idx for idx in range(len(all_data)) if idx not in _sampled_index_set]
valid_data = torch.utils.data.Subset(all_data, valid_indices)

# Training loader.
train_set = torch.utils.data.DataLoader(
    train_data,
    batch_size=BTACH_SIZE,
    shuffle=True
)
# Validation loader (named test_set in this script).
test_set = torch.utils.data.DataLoader(
    valid_data,
    batch_size=BTACH_SIZE,
    shuffle=True
)
4.模型的训练和预测
# Training
def train(model1, device, dataset, optimizer1, epoch1):
    """Run one training epoch and print the accuracy observed during it.

    Accuracy is accumulated from each batch's predictions before that batch's
    optimizer step, so it is a running figure rather than a separate evaluation.

    Args:
        model1: Network to optimize; it is switched to training mode.
        device: Device every batch is moved to before the forward pass.
        dataset: Iterable of ``(inputs, labels)`` batches.
        optimizer1: Optimizer that is zeroed and stepped once per batch.
        epoch1: Epoch number, used only in the printed summary.
    """
    model1.train()
    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch
    correct = 0
    all_len = 0
    # tqdm wraps the loader itself (not enumerate(...)) so it can read the
    # loader's length and display a progress bar with a total.
    for x, y in tqdm(dataset):
        x, y = x.to(device), y.to(device)
        optimizer1.zero_grad()
        output = model1(x)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(y.view_as(pred)).sum().item()
        all_len += len(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer1.step()
    # An empty loader would otherwise divide by zero.
    accuracy = 100. * correct / all_len if all_len else 0.0
    print(f"第 {epoch1} 次训练的Train真实:{accuracy:.2f}%")
# Validation on the held-out set
def vaild(model, device, dataset):
    """Evaluate *model* on *dataset*, print the accuracy and return it.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        device: Device every batch is moved to before the forward pass.
        dataset: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Accuracy as a percentage (float). 0.0 when *dataset* yields no samples.
    """
    model.eval()
    correct = 0
    all_len = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for x, target in dataset:
            x, target = x.to(device), target.to(device)
            pred = model(x).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_len += len(x)
    # An empty loader would otherwise divide by zero.
    accuracy = 100. * correct / all_len if all_len else 0.0
    print(f"Test 真实:{accuracy:.2f}%")
    return accuracy
5.ResNet50迁移学习
使用pretrained=True的方式得到预训练模型,并更改全连接层的输出维度。
也可以使用weights='ResNet50_Weights.DEFAULT'。
两者的区别在于加载的预训练权重版本不同:resnet50(pretrained=True)
是旧版写法(自torchvision 0.13起已弃用),等价于weights='ResNet50_Weights.IMAGENET1K_V1';而weights='ResNet50_Weights.DEFAULT'
加载的是当前最优的一组官方预训练权重(目前为IMAGENET1K_V2)。两者都是在ImageNet数据集上训练得到的官方权重。
# ImageNet-pretrained ResNet-50 backbone. The `weights=` argument replaces the
# `pretrained=True` flag, which is deprecated since torchvision 0.13;
# IMAGENET1K_V1 is the weight set that `pretrained=True` used to load.
model_1 = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
)
# Replace the classification head with a 12-way linear layer. The nn.Sequential
# wrapper is kept so the state_dict keys stay "fc.0.*".
model_1.fc = nn.Sequential(
    nn.Linear(2048, 12)
)
model_1.to(DEVICE)
# NOTE(review): momentum=0.09 is unusually small (0.9 is the common choice) —
# confirm it is intentional.
optimizer = optim.SGD(model_1.parameters(), lr=LR, momentum=0.09)
模型的训练和保存
import copy  # used only by this script section to snapshot the best weights

# Train and validate every epoch; checkpoint the weights whenever validation
# accuracy reaches a new best above the threshold.
max_accuracy = 90.0  # a model is saved only once accuracy exceeds this value
best_model = None
best_model_path = r"E:\日常练习\pytorch_Project\best_model_train1.pth"
for epoch in range(1, EPOCH + 1):
    train(model_1, DEVICE, train_set, optimizer, epoch)
    accu = vaild(model_1, DEVICE, test_set)
    if accu > max_accuracy:
        max_accuracy = accu
        # state_dict() returns references to the live parameter tensors, which
        # later epochs keep updating in place; deep-copy to freeze a snapshot.
        best_model = copy.deepcopy(model_1.state_dict())
        # Save right away so the best checkpoint survives an interrupted run
        # and so that None is never written when no epoch beats the threshold.
        torch.save(best_model, best_model_path)
if best_model is None:
    print(f"验证准确率未超过阈值 {max_accuracy:.2f}%,未保存任何模型。")
基于gradio做的UI界面
1.所需python库
# 导入必要的库
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import tempfile
from torch import nn
2.模型的加载
# Load the model
def load_pretrained_resnet():
    """Build the 12-class ResNet-50 and load the fine-tuned checkpoint into it.

    Returns:
        The model, placed on the GPU when CUDA is available (otherwise the CPU)
        and switched to evaluation mode.
    """
    # Pick the device: GPU if available, else CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    # No ImageNet weights are requested: load_state_dict below is strict and
    # overwrites every parameter, so downloading them would be wasted work.
    model_1 = torchvision.models.resnet50()
    model_1.fc = nn.Sequential(nn.Linear(2048, 12))
    model_1.to(device)
    # map_location lets a checkpoint that was saved from a GPU be loaded on a
    # machine that only has a CPU.
    state_dict = torch.load(
        r"C:\Users\Acer\Desktop\Cat-12\best_model_train99.71.pth",
        map_location=device,
    )
    model_1.load_state_dict(state_dict)
    # Inference only: put the model in evaluation mode.
    model_1.eval()
    return model_1
3.模型预测结果
# Class-name lookup used when reporting a prediction
def load_imagenet_labels(filename=r"C:\Users\Acer\Desktop\Cat-12\cat12.txt"):
    """Read class labels from a UTF-8 text file, one label per line.

    Args:
        filename: Path of the label file.

    Returns:
        A list with one entry per line of the file, in file order, each with
        surrounding whitespace stripped (blank lines become empty strings).
    """
    with open(filename, mode='r', encoding="utf-8") as label_file:
        return list(map(str.strip, label_file))
# Objects that are expensive to build; created on first use and then reused by
# every later prediction.
_SERVE_CACHE = {}


def _cached(key, factory):
    """Return the cached object for *key*, calling *factory* once to create it."""
    if key not in _SERVE_CACHE:
        _SERVE_CACHE[key] = factory()
    return _SERVE_CACHE[key]


def serve_chicken(image_path):
    """Classify the image at *image_path* and return the predicted class name.

    Args:
        image_path: Path of the input image file.

    Returns:
        The label whose line index equals the index of the model's
        highest-scoring output.

    Raises:
        IndexError: If the label file has fewer entries than the predicted index.
    """
    # Pick the device: GPU if available, else CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Building the model reads the checkpoint from disk and the labels come
    # from a file, so both are built once and reused across calls.
    model = _cached("model", load_pretrained_resnet)
    labels = _cached("labels", load_imagenet_labels)

    # NOTE(review): the training pipeline in this file crops to 224 and
    # normalizes with mean/std 0.5, while this uses a 256 crop and ImageNet
    # statistics — confirm which preprocessing the loaded checkpoint expects.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # The context manager closes the underlying file once the pixels have been
    # converted into an independent RGB image.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")
    # Add the batch dimension and move the tensor to the model's device.
    batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)
    predicted_index = outputs.argmax(dim=1).item()
    return labels[predicted_index]
4.gradio页面(简单版)
# Gradio adapter
def gradio_wrapper(image_array):
    """Classify an image that Gradio delivers as a numpy array.

    The array is written to a temporary JPEG file, passed to ``serve_chicken``
    by path, and the temporary file is removed afterwards.

    Args:
        image_array: Image as a numpy array, or None when nothing was uploaded.

    Returns:
        The value returned by ``serve_chicken``, or None when *image_array*
        is None.
    """
    import os  # local import: this file has no top-level `import os`

    if image_array is None:
        # Nothing to classify; stop here instead of falling through to an
        # image left over from an earlier call (or to no image at all).
        print("Error: image_array is None!")
        return None

    # JPEG cannot store an alpha channel, so normalize to RGB before saving.
    image = Image.fromarray(image_array.astype('uint8')).convert("RGB")

    # mkstemp creates the file atomically; the deprecated mktemp only returns a
    # name that another process could claim first.
    fd, temp_filename = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        image.save(temp_filename)
        return serve_chicken(temp_filename)
    finally:
        try:
            os.remove(temp_filename)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not mask the
            # prediction result or the original error.
            pass
# Minimal UI page
def gradio_interface():
    """Create a simple Gradio interface around ``gradio_wrapper`` and launch it."""
    # One image input (handed to the callback as a numpy array) and one label
    # output; live=True re-runs the callback when the input changes.
    demo_app = gr.Interface(
        fn=gradio_wrapper,
        inputs=gr.components.Image(shape=(256, 256), type="numpy"),
        outputs=gr.components.Label(num_top_classes=3),
        live=True,
    )
    # Start serving the interface.
    demo_app.launch()
5.主函数
if __name__ == "__main__":
    # Load the checkpoint once up front so a bad model path fails immediately.
    resnet_model = load_pretrained_resnet()
    print("ResNet模型已加载!")

    # Classify one local sample image and report the predicted breed.
    sample_image_path = r"D:\桌面\cat12\cat_12_train\A.jpg"
    result = serve_chicken(sample_image_path)
    print(f"这个猫的种类是【{result}】!")

    # Then hand over to the interactive page.
    gradio_interface()
6.UI界面展示
7.优化好的UI页面
如果想要实现这个优化好的UI界面(获取12类猫的简介及其特征),需要接入文心一言的token。
获取方法:飞桨星河社区个人中心-->访问令牌(100万次免费使用机会)
# ERNIE Bot (Wenxin Yiyan)
def wenxin(question):
    """Ask ERNIE Bot to describe the cat breed named in *question*.

    Args:
        question: Text inserted into the prompt template (the breed name).

    Returns:
        The ``result`` attribute of the chat-completion response.
    """
    import erniebot

    # Backend selection and credentials; the token is a placeholder that has
    # to be replaced with a real access token.
    erniebot.api_type = 'aistudio'
    erniebot.access_token = "<you_token>"

    prompt = "请给我讲解{}的具体特征,越详细越好,开头词为亲爱的用户您查询的猫品种简介如下:".format(question)
    chat_messages = [{'role': 'user', 'content': prompt}]
    reply = erniebot.ChatCompletion.create(
        model='ernie-bot',
        messages=chat_messages,
    )
    return reply.result
# Build and return the Gradio Blocks page used by the improved UI. It declares
# two gr.Column panels: one with the breed-name list, a sample-image gallery,
# an image upload, a submit button and a prediction textbox; the other with a
# breed-name textbox, a query button and a 10-line description textbox.
# The upload button calls gradio_wrapper and the query button calls wenxin.
# NOTE(review): get_img_list and get_selected_image are not defined in the
# visible part of this file — confirm where they come from.
# NOTE(review): the source has lost its indentation; the nesting of the `with`
# blocks below must be restored before this function can run.
def teac_math():
with gr.Blocks() as demo:
image_result_list = get_img_list() # sample images shown in the gallery
state_image_list = gr.State(value=image_result_list) # same list, passed to the gallery's select handler
with gr.Row(equal_height=False):
with gr.Column(variant='panel'):
# Names of the twelve breeds, shown as Markdown text.
gr.Markdown('''"伯曼猫","俄罗斯蓝猫","埃及猫","孟买猫","孟加拉豹猫","布偶猫"<br/>"无毛猫","暹罗猫","波斯猫","缅因猫","英国短毛猫","阿比西亚猫"''')
image_results = gr.Gallery(value=image_result_list, label='样例图', allow_preview=False,
columns=6, height=250)
image_input = gr.Image(label="传入需要预测的猫的图片")
text_button_img = gr.Button("确定上传")
text_output = gr.Textbox(label='上传的猫的图片的预测结果为')
with gr.Column(variant='panel'):
with gr.Box():
with gr.Row():
gr.Markdown("请输入需要查询猫品种的名字,获得相应的简介")
text_input = gr.Textbox(label="请输入要查询猫品种的名字")
text_button = gr.Button("确定查询")
text_outputs = gr.Textbox(label="您查询的猫的品种的简介如下:",lines=10)
# right
image_results.select(get_selected_image, state_image_list, queue=False) # sample-image gallery selection
# Upload button: classify the uploaded image and show the result.
text_button_img.click(fn=gradio_wrapper, inputs=image_input, outputs=text_output)
# left
# Query button: send the breed name to wenxin and show its reply.
text_button.click(fn=wenxin, inputs=text_input ,outputs=text_outputs)
return demo
8.优化后的主函数
if __name__ == "__main__":
    # Top-level page: custom stylesheet, a centered title, and a single tab
    # that hosts the classifier UI built by teac_math().
    with gr.Blocks(css='style.css') as demo:
        page_title = "# <center> \N{fire} 基于ResNet50的猫十二分类 </center>"
        gr.Markdown(page_title)
        with gr.Tabs():
            with gr.TabItem('\N{clapper board} 猫十二分类预测'):
                teac_math()
    # Start serving the page.
    demo.launch()
9.优化后的页面效果展示
--------------HOOK团队出品