# Source tutorial video:
# https://www.bilibili.com/video/BV1eq4y1H75J/?spm_id_from=333.999.0.0&vd_source=5652a3d62a700fbd74b050faab8a17f5
# ===== data.py =====
import os.path
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Image dataset whose annotations are encoded in the filename.

    Expected dot-separated filename format:
        <stem>.<label>.<x1>.<y1>.<x2>.<y2>.<class>.<ext>
    where <label> is 1/0 (target present or not), the four coordinates are
    pixel values on a 300x300 image, and <class> is a 1-based class id
    (0 meaning "no target").
    """

    def __init__(self, root, is_train=True):
        """Collect the full path of every image under <root>/train or <root>/test."""
        self.dataset = []
        split = 'train' if is_train else "test"  # renamed from `dir` (shadowed a builtin)
        sub_dir = os.path.join(root, split)
        for name in os.listdir(sub_dir):
            self.dataset.append(os.path.join(sub_dir, name))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return (CHW float image in [0,1], label, normalized box, class index)."""
        path = self.dataset[index]
        img = cv2.imread(path) / 255  # HWC uint8 -> float in [0, 1]
        img = torch.tensor(img).permute(2, 0, 1)  # HWC -> CHW for nn.Conv2d
        parts = path.split('.')
        label = int(parts[1])
        # Box coordinates normalized by the (assumed) 300-pixel image size.
        position = [int(p) / 300 for p in parts[2:6]]
        # Filenames store 1-based class ids; shift so "no target" becomes -1.
        sort = int(parts[6]) - 1
        # float32 throughout: the default dtype torch layers expect.
        return np.float32(img), np.float32(label), np.float32(position), int(sort)
if __name__ == '__main__':
    # Smoke-test: load the test split and dump every sample.
    test_set = MyDataset('yellow_data', is_train=False)
    for sample in test_set:
        print(sample)
# ===== net.py =====
import torch
from torch import nn
class MyNet(nn.Module):
    """Single-object detector for 300x300 images.

    Returns, for a (N, 3, 300, 300) input:
        label:    (N,)   presence logit (consume with BCEWithLogitsLoss)
        position: (N, 4) normalized box coordinates
        sort:     (N, 20) class logits (consume with CrossEntropyLoss)
    The backbone shrinks the input to a (N, 128, 19, 19) feature map; each
    head uses a 19x19 convolution to collapse the spatial dims to 1x1.
    """

    def __init__(self):
        super(MyNet, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 11, 3),    # 300 -> 298
            nn.LeakyReLU(),
            nn.MaxPool2d(3),        # 298 -> 99
            nn.Conv2d(11, 22, 3),   # 99 -> 97
            nn.LeakyReLU(),
            nn.MaxPool2d(2),        # 97 -> 48
            nn.Conv2d(22, 32, 3),   # 48 -> 46
            nn.LeakyReLU(),
            nn.MaxPool2d(2),        # 46 -> 23
            nn.Conv2d(32, 64, 3),   # 23 -> 21
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, 3),  # 21 -> 19
            nn.LeakyReLU(),
        )
        # Presence head emits a RAW logit: training uses BCEWithLogitsLoss.
        # (The original put nn.ReLU here, clamping logits to >= 0 so the
        # sigmoid could never drop below 0.5 and "no object" was unlearnable.)
        self.label_layers = nn.Sequential(
            nn.Conv2d(128, 1, 19),
        )
        # Box regression head: 4 normalized coordinates.
        self.position_layers = nn.Sequential(
            nn.Conv2d(128, 4, 19)
        )
        # Class head: 20 class logits.
        # NOTE(review): a LeakyReLU before CrossEntropyLoss is unusual
        # (losses with logits expect unbounded outputs) but is kept for
        # checkpoint compatibility.
        self.sort_layer = nn.Sequential(
            nn.Conv2d(128, 20, 19),
            nn.LeakyReLU()
        )

    def forward(self, x):
        out = self.layers(x)
        # Each head emits (N, C, 1, 1); squeeze away the unit spatial dims.
        label = self.label_layers(out).squeeze(2).squeeze(2).squeeze(1)  # (N,)
        position = self.position_layers(out).squeeze(2).squeeze(2)       # (N, 4)
        sort = self.sort_layer(out).squeeze(2).squeeze(2)                # (N, 20)
        return label, position, sort
if __name__ == '__main__':
    # Shape sanity check with a random 3-image batch.
    net = MyNet()
    batch = torch.randn(3, 3, 300, 300)
    label, position, sort = net(batch)
    print(label.shape)     # presence logits
    print(position.shape)  # box coordinates
    print(sort.shape)      # class logits
# ===== train.py =====
import os.path
from datetime import datetime
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import optim
from net import MyNet
from data import MyDataset
from torch import nn
# Fall back to CPU when CUDA is unavailable so the script still runs.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class Train:
    """Train and evaluate MyNet on MyDataset, logging losses to TensorBoard.

    Args:
        root: dataset root containing ``train``/``test`` sub-directories.
        weight_path: optional checkpoint to resume from ('' to start fresh).
    """

    def __init__(self, root, weight_path):
        self.summaryWriter = SummaryWriter('logs')
        self.train_dataset = MyDataset(root=root, is_train=True)
        self.test_dataset = MyDataset(root=root, is_train=False)
        self.train_dataloader = DataLoader(self.train_dataset, batch_size=64, shuffle=True)
        self.test_dataloader = DataLoader(self.test_dataset, batch_size=64, shuffle=True)
        self.net = MyNet().to(DEVICE)
        if os.path.exists(weight_path):
            self.net.load_state_dict(torch.load(weight_path))
        self.optimizer = optim.Adam(self.net.parameters())
        # BCEWithLogitsLoss: the label head emits raw logits (no sigmoid in the net).
        self.label_loss_fun = nn.BCEWithLogitsLoss()
        self.position_loss_fun = nn.MSELoss()
        self.sort_loss_fun = nn.CrossEntropyLoss()
        self.train = True
        self.test = True

    def __call__(self):
        """Run 100 epochs of training and evaluation."""
        index1, index2 = 0, 0
        for epoch in range(100):
            if self.train:
                self.net.train()
                for i, (img, label, position, sort) in enumerate(self.train_dataloader):
                    img, label, position, sort = (img.to(DEVICE), label.to(DEVICE),
                                                  position.to(DEVICE), sort.to(DEVICE))
                    out_label, out_position, out_sort = self.net(img)
                    out_label_loss = self.label_loss_fun(out_label, label)
                    out_position_loss = self.position_loss_fun(out_position, position)
                    # sort == -1 marks "no object": build the mask ONCE and apply
                    # it to both tensors.  (The original filtered `sort` first and
                    # then re-masked `out_sort` with the already-filtered tensor,
                    # mis-aligning predictions and targets.)
                    mask = sort >= 0
                    out_sort_loss = self.sort_loss_fun(out_sort[mask], sort[mask])
                    # Optimize the sum of all three losses jointly.
                    train_loss = out_label_loss + out_position_loss + out_sort_loss
                    self.optimizer.zero_grad()
                    train_loss.backward()
                    self.optimizer.step()
                    if i % 10 == 0:
                        print(f'train_loss{i}==================================>>', train_loss.item())
                        # Log the actual loss at a monotonically increasing step
                        # (the original logged the step counter as the value).
                        self.summaryWriter.add_scalar('train_loss', train_loss.item(), index1)
                        index1 += 1
                # Checkpoint every epoch; ':' and '.' are illegal/awkward in filenames.
                data_time = str(datetime.now()).replace(':', '-').replace('.', '-')
                os.makedirs('param', exist_ok=True)
                torch.save(self.net.state_dict(), f'param/{data_time}-{epoch}.pt')
            if self.test:
                self.net.eval()
                sum_sort_acc, sum_label_acc = 0.0, 0.0
                batches = 0
                with torch.no_grad():
                    for i, (img, label, position, sort) in enumerate(self.test_dataloader):
                        img, label, position, sort = (img.to(DEVICE), label.to(DEVICE),
                                                      position.to(DEVICE), sort.to(DEVICE))
                        out_label, out_position, out_sort = self.net(img)
                        out_label_loss = self.label_loss_fun(out_label, label)
                        out_position_loss = self.position_loss_fun(out_position, position)
                        mask = sort >= 0
                        out_sort = out_sort[mask]
                        sort = sort[mask]
                        out_sort_loss = self.sort_loss_fun(out_sort, sort)
                        test_loss = out_label_loss + out_position_loss + out_sort_loss
                        # Presence accuracy: threshold the sigmoid at 0.5.
                        pred_label = (torch.sigmoid(out_label) >= 0.5).float()
                        sum_label_acc += torch.mean(torch.eq(pred_label, label).float()).item()
                        # Class accuracy: argmax over the class dimension
                        # (the original argmax'ed the flattened batch, yielding
                        # a single scalar that broadcast against every target).
                        pred_sort = torch.argmax(out_sort, dim=1)
                        sum_sort_acc += torch.mean(torch.eq(pred_sort, sort).float()).item()
                        batches += 1
                        if i % 10 == 0:
                            print(f'test_loss{i}=======================================>>', test_loss.item())
                            self.summaryWriter.add_scalar('test_loss', test_loss.item(), index2)
                            index2 += 1
                # Average over the number of batches (the original divided the
                # LAST batch's sort accuracy by the last batch index).
                avg_label_acc = sum_label_acc / max(batches, 1)
                avg_sort_acc = sum_sort_acc / max(batches, 1)
                print(f'avg_label_acc {epoch}========================================>>', avg_label_acc)
                self.summaryWriter.add_scalar('avg_label_acc', avg_label_acc, epoch)
                print(f'avg_sort_acc {epoch}========================================>>', avg_sort_acc)
                self.summaryWriter.add_scalar('avg_sort_acc', avg_sort_acc, epoch)
if __name__ == '__main__':
    # '' never exists on disk, so training starts from random weights.
    trainer = Train('yellow_data', weight_path='')
    trainer()
# ===== predict.py =====
import os
import torch
import cv2
from net import MyNet
if __name__ == '__main__':
    # Load the trained network ONCE, outside the image loop
    # (the original reloaded the checkpoint for every single image).
    Model = MyNet()
    Model.load_state_dict(torch.load('param/2024-04-09 15-33-17-230540-99.pt'))
    Model.eval()

    img_name = os.listdir('yellow_data/test')
    for i in img_name:
        img_dir = os.path.join('yellow_data/test', i)
        img = cv2.imread(img_dir)

        # Run inference on the CLEAN image before drawing anything on it
        # (the original drew the ground-truth box first, so the network
        # saw the annotations as part of its input).
        new_img = torch.tensor(img).permute(2, 0, 1)     # HWC -> CHW
        new_img = torch.unsqueeze(new_img, dim=0) / 255  # add batch dim, scale to [0, 1]
        with torch.no_grad():
            out_label, out_position, out_sort = Model(new_img)
        out_position = [int(v) for v in out_position[0] * 300]  # de-normalize to pixels
        out_label = torch.sigmoid(out_label)
        out_sort = torch.argmax(out_sort, dim=1)

        # Ground truth from the filename: <stem>.<label>.<x1>.<y1>.<x2>.<y2>.<class>.<ext>
        parts = i.split('.')
        position = [int(j) for j in parts[2:6]]
        sort = parts[6]
        # Ground truth: green box, blue text.
        cv2.rectangle(img, (position[0], position[1]), (position[2], position[3]), (0, 255, 0), thickness=3)
        cv2.putText(img, sort, (position[0], position[1] - 1), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1)

        if out_label > 0.5:
            # Prediction in red; text anchored to the PREDICTED box
            # (the original anchored it to the ground-truth box).
            cv2.rectangle(img, (out_position[0], out_position[1]), (out_position[2], out_position[3]), (0, 0, 255), thickness=3)
            cv2.putText(img, str(out_sort.item()), (out_position[0], out_position[1] + 1), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=1)

        cv2.imshow('img', img)
        cv2.waitKey(500)
        cv2.destroyAllWindows()