CNN项目1-ResNet模型图像分类之工业缺陷检测
1.类别文件
{
"0": "In",
"1": "Sc",
"2": "Cr",
"3": "PS",
"4": "RS",
"5": "Pa"
}
2.数据集
- my_dataset.py：创建自定义数据集，继承自 PyTorch 的 Dataset 类，必须实现 __len__ 和 __getitem__ 方法
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2 as cv
# Class-name prefixes used in image file names; the list position is the label.
defect_labels = ['In', 'Sc', 'Cr', 'PS', 'RS', 'Pa']


class SurfaceDefectDataset(Dataset):
    """Surface-defect image dataset whose labels are encoded in file names.

    Every entry of ``root_dir`` must be named ``<class>_<rest>`` where
    ``<class>`` is one of ``defect_labels``. The position of that class in
    ``defect_labels`` is used as the integer label of the sample.
    """

    def __init__(self, root_dir):
        """Index the files under ``root_dir`` and build the preprocessing pipeline.

        Args:
            root_dir: Directory that directly contains the image files.

        Raises:
            ValueError: If a file name does not start with a known class prefix.
        """
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.Resize((200, 200)),
        ])
        self.images, self.defect_types = self._scan(root_dir)

    @staticmethod
    def _scan(root_dir):
        """Return ``(image_paths, label_indices)`` for the entries of ``root_dir``."""
        images = []
        defect_types = []
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible between runs and machines.
        for file_name in sorted(os.listdir(root_dir)):
            defect_class = file_name.split('_')[0]
            try:
                defect_index = defect_labels.index(defect_class)
            except ValueError:
                # Same exception type as before, but it now names the file.
                raise ValueError(
                    f'{file_name!r} in {root_dir!r} does not start with one of '
                    f'the known defect classes {defect_labels}'
                ) from None
            images.append(os.path.join(root_dir, file_name))
            defect_types.append(defect_index)
        return images, defect_types

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{'image': transformed image, 'defect': label index}``.

        Raises:
            OSError: If OpenCV cannot read or decode the image file.
        """
        image_path = self.images[idx]
        img = cv.imread(image_path)
        if img is None:
            # cv.imread() signals failure by returning None instead of raising.
            raise OSError(f'failed to read image: {image_path}')
        # OpenCV decodes to BGR; the normalization constants are given in RGB order.
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        return {'image': self.transform(img), 'defect': self.defect_types[idx]}
if __name__ == '__main__':
    # Smoke test: index the training folder, print one sample, then pull a
    # single shuffled batch through a DataLoader and print its type and shape.
    dataset = SurfaceDefectDataset('./enu_surface_defect/train')
    print(len(dataset))
    print(dataset[0]['image'].shape, dataset[0]['defect'])

    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8)
    batch = next(iter(loader))
    print(type(batch))
    print(batch['image'].shape)
3.模型
import torch
import torchvision
class SurfaceDectectResNet(torch.nn.Module):
    """ResNet-18 classifier whose final fully connected layer outputs ``num_classes`` logits."""

    def __init__(self, num_classes=1000, pretrained=True):
        """Build the ResNet-18 backbone and replace its classification head.

        Args:
            num_classes: Number of output logits of the new ``fc`` layer.
            pretrained: If True (the default, matching the previous behavior),
                initialize the backbone with ImageNet weights. Pass False to
                skip the weight download, e.g. when a checkpoint is loaded
                right after construction.
        """
        super().__init__()
        # torchvision >= 0.13 replaced the ``pretrained`` flag with ``weights``;
        # prefer the new API when present and fall back for older releases.
        weights_enum = getattr(torchvision.models, 'ResNet18_Weights', None)
        if weights_enum is None:
            self.cnn_layers = torchvision.models.resnet18(pretrained=pretrained)
        else:
            # IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` used to load.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.cnn_layers = torchvision.models.resnet18(weights=weights)
        # Swap the ImageNet head for one sized to this task.
        in_features = self.cnn_layers.fc.in_features
        self.cnn_layers.fc = torch.nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.cnn_layers(x)
4.训练及校验
import os
import json
import sys
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from tqdm import tqdm
from my_dataset import SurfaceDefectDataset
from my_dataset import defect_labels
from model import SurfaceDectectResNet
def main():
    """Train the defect classifier and keep the best checkpoint.

    Trains for a fixed number of epochs on ``./enu_surface_defect/train``,
    evaluates on ``./enu_surface_defect/test`` after every epoch, writes the
    index -> class-name mapping to ``class_indices.json`` and saves the
    weights to ``./model.pth`` whenever validation accuracy improves.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f'using {device}')

    train_dataset = SurfaceDefectDataset('./enu_surface_defect/train')

    # Persist the label mapping so the prediction script can decode outputs.
    cla_dict = dict((i, label) for i, label in enumerate(defect_labels))
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # os.cpu_count() may return None when the count cannot be determined;
    # min() over a list containing None would raise TypeError.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print(f'using {nw} dataloader workers every process')

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw)

    validate_dataset = SurfaceDefectDataset('./enu_surface_defect/test')
    val_num = len(validate_dataset)
    # Accuracy does not depend on sample order, so validation is not shuffled.
    validate_loader = torch.utils.data.DataLoader(
        validate_dataset, batch_size=batch_size, shuffle=False, num_workers=nw)

    net = SurfaceDectectResNet(num_classes=6)
    net.to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)

    epochs = 10
    save_path = './model.pth'
    best_acc = 0.0
    train_steps = len(train_loader)

    for epoch in range(epochs):
        # ---- training pass ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data['image'], data['defect']
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_fn(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = f'train epoch[{epoch + 1}/{epochs}] loss:{loss:.3f}'

        # ---- validation pass ----
        net.eval()
        acc = 0.0  # number of correctly classified validation samples
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data['image'], val_data['defect']
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accuracy = acc / val_num
        # The original message lacked the closing bracket and a separator space.
        print(f'[epoch {epoch + 1}] train_loss: {running_loss / train_steps:.3f}, '
              f'val_accuracy: {val_accuracy:.3f}')

        # Keep only the weights of the best-performing epoch.
        if val_accuracy > best_acc:
            best_acc = val_accuracy
            torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()
5.预测
import os
import json
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import cv2 as cv
from model import SurfaceDectectResNet
def main():
    """Classify ``Cr_1.bmp`` with the trained model and show the result.

    Loads the image, applies the same preprocessing as training, restores the
    weights from ``./model.pth``, and displays the image titled with the
    predicted class name and its softmax probability.

    Raises:
        FileNotFoundError: If the image, the class-index JSON or the weights
            file does not exist.
        OSError: If OpenCV cannot decode the image.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.Resize((200, 200)),
    ])

    img_path = 'Cr_1.bmp'
    # Explicit raises instead of assert: asserts are stripped under ``python -O``.
    if not os.path.exists(img_path):
        raise FileNotFoundError(f'{img_path} does not exist')
    img = cv.imread(img_path)
    if img is None:
        # cv.imread() returns None on failure instead of raising.
        raise OSError(f'failed to read image: {img_path}')
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    plt.imshow(img)

    img = data_transform(img)
    # Add the batch dimension expected by the model.
    img = torch.unsqueeze(img, dim=0)

    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError(f'{json_path} does not exist')
    with open(json_path, 'r') as f:
        class_dict = json.load(f)

    model = SurfaceDectectResNet(num_classes=6).to(device)

    weights_path = './model.pth'
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f'{weights_path} does not exist')
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        output = model(img.to(device))
        output = torch.squeeze(output).cpu()
        predict = torch.softmax(output, dim=0)
        # Plain int so it can index the tensor and, via str(), the JSON keys.
        predict_class = int(torch.argmax(predict))

    print_res = f'class: {class_dict[str(predict_class)]}, prob:{predict[predict_class].numpy()}'
    plt.title(print_res)
    plt.show()


if __name__ == '__main__':
    main()