点赞再看,养成习惯!觉得不过瘾的童鞋,欢迎关注公众号《机器学习算法工程师》,有非常多大神的干货文章可供学习噢…
前言
这篇文章汇总这个系列之前的博客的所有代码(没有看过的童鞋,最好去看看),小编按照自己构思的比较不错的设计架构来组织代码,把数据预处理模块(data_precessing)、模型模块(model)、工具模块(utils)等等分开,各自封装成类,然后处理任务的流程在main模块中实现。如下图所示:
正文
数据预处理模块
这个模块存放的是自定义的数据集类
class SVHNDataset(Dataset):
    """Dataset of street-character images with fixed-length digit labels.

    Each sample is an RGB image plus its digit sequence padded to a fixed
    length of 6 with the sentinel class 10, which cannot collide with the
    valid digit classes 0-9.
    """

    # Fixed-length recognition strategy: 6 positions, class 10 = padding.
    SEQ_LEN = 6
    PAD_CLASS = 10

    def __init__(self, img_path, img_label, transform=None):
        """
        Args:
            img_path: list of image file paths.
            img_label: list of digit sequences (iterables of ints 0-9).
            transform: optional torchvision transform applied per image.
        """
        self.img_path = img_path
        self.img_label = img_label
        # The original if/else was redundant: storing the argument directly
        # (including None) is equivalent.
        self.transform = transform

    def __getitem__(self, index):
        # Handle exactly one sample.
        img = Image.open(self.img_path[index]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # BUG FIX: np.int was removed in NumPy 1.24; the builtin int dtype
        # is the documented replacement.
        lbl = list(np.array(self.img_label[index], dtype=int))
        # Pad with class 10 so every target tensor has exactly 6 entries.
        lbl = lbl + (self.SEQ_LEN - len(lbl)) * [self.PAD_CLASS]
        return img, torch.from_numpy(np.array(lbl))

    def __len__(self):
        return len(self.img_path)
模型模块
这个模块存放的是机器学习模型,这里有两个:自定义的神经网络(用于学习目的);继承预训练模型的神经网络(实操用的)。这里给出后者的代码,包括模型结构、训练、验证以及预测的功能。
class SVHN_Model2(nn.Module):
    """Fixed-length street-character recognizer on a pretrained resnet18.

    The resnet18 backbone (fc layer removed, final pool replaced by global
    average pooling) feeds six parallel 11-way heads, one per character
    position (classes: digits 0-9 plus padding class 10).
    """

    def __init__(self):
        super(SVHN_Model2, self).__init__()
        # Start from an ImageNet-pretrained resnet18 backbone.
        model_conv = models.resnet18(pretrained=True)
        # Replace the last pooling layer with adaptive global average pooling
        # so the backbone tolerates non-standard input sizes.
        model_conv.avgpool = nn.AdaptiveAvgPool2d(1)
        # Fine-tuning: drop resnet's own fc layer; custom heads follow.
        model_conv = nn.Sequential(*list(model_conv.children())[:-1])
        self.cnn = model_conv
        # One 11-way classification head per character position.
        self.fc1 = nn.Linear(512, 11)
        self.fc2 = nn.Linear(512, 11)
        self.fc3 = nn.Linear(512, 11)
        self.fc4 = nn.Linear(512, 11)
        self.fc5 = nn.Linear(512, 11)
        self.fc6 = nn.Linear(512, 11)

    def forward(self, img):
        feat = self.cnn(img)
        feat = feat.view(feat.shape[0], -1)
        # Six independent per-position logit tensors, each (batch, 11).
        return (self.fc1(feat), self.fc2(feat), self.fc3(feat),
                self.fc4(feat), self.fc5(feat), self.fc6(feat))

    def mytraining(self, train_loader, criterion, optimizer, device=torch.device('cpu')):
        """Run one training epoch; return mean batch loss rounded to 4 places."""
        # Switch the model to training mode.
        self.train()
        train_loss = []
        for data, label in train_loader:
            outputs = self(data.to(device))
            label = label.long().to(device)
            # Average the per-position losses over the 6 heads.
            loss = sum(criterion(out, label[:, j])
                       for j, out in enumerate(outputs)) / 6
            train_loss.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return round(np.mean(train_loss), 4)

    def myvalidating(self, val_loader, criterion, device=torch.device('cpu')):
        """Evaluate on val_loader; return mean batch loss rounded to 4 places."""
        # Switch the model to evaluation mode.
        self.eval()
        val_loss = []
        # Do not track gradients during evaluation.
        with torch.no_grad():
            for data, label in val_loader:
                outputs = self(data.to(device))
                label = label.long().to(device)
                loss = sum(criterion(out, label[:, j])
                           for j, out in enumerate(outputs)) / 6
                val_loss.append(loss.item())
        return round(np.mean(val_loss), 4)

    def myPredicting(self, test_loader, device=torch.device('cpu')):
        """Predict padded label sequences; returns an (N, 6) integer ndarray."""
        # Switch the model to evaluation mode.
        self.eval()
        batches = []
        # Do not track gradients during prediction.
        with torch.no_grad():
            for data, _ in test_loader:
                # BUG FIX: the original never moved `data` to `device` (the
                # parameter was ignored) and called .numpy() directly, which
                # fails on GPU tensors; move inputs in and results back out.
                outputs = self(data.to(device))
                # argmax over the 11 classes per head: (batch, 11) -> (batch, 1)
                cols = [out.cpu().numpy().argmax(axis=1).reshape(-1, 1)
                        for out in outputs]
                # Merge the 6 position columns into a (batch, 6) block.
                batches.append(np.concatenate(cols, axis=1))
        # BUG FIX: the original raised UnboundLocalError on an empty loader.
        if not batches:
            return np.empty((0, 6), dtype=np.int64)
        return np.concatenate(batches, axis=0)
工具模块
这个模块包含一些零零散散的工具方法,都是些静态方法,有数据导入、结果保存等功能。
class Tools:
    """Static helpers: data loading, accuracy, progress logging, submission."""

    @staticmethod
    def dataFromPath(img_path, label_path=None):
        """Collect sorted image paths and their labels.

        When label_path is None (test set), fake single-padding labels [10]
        are produced so the Dataset interface stays uniform.
        """
        imgs = glob.glob(img_path)
        imgs.sort()
        if label_path:
            label_json = json.load(open(label_path))
            labels = [label_json[x]['label'] for x in label_json]
        else:
            labels = [[10]] * len(imgs)
        return imgs, labels

    @staticmethod
    def calAcc(pred_label, true_label):
        """Sequence-level accuracy (%) of padded predictions.

        A sample is correct only when every position matches after both
        sequences are padded to length 6 with the padding class 10.

        BUG FIX: the original loop only incremented `count` when it reached
        a padding char (10) in the true label, so a true label using all 6
        positions could never be counted correct; it also accepted
        predictions carrying extra characters past the true padding.
        """
        length = len(true_label)
        # Guard against division by zero on empty input.
        if length == 0:
            return 0.0
        count = 0
        for pred, true in zip(pred_label, true_label):
            padded_true = list(true) + [10] * (6 - len(true))
            if list(pred) == padded_true:
                count += 1
        return round(count / length * 100, 4)

    @staticmethod
    def printInfo(epoch, train_loss, val_loss,
                  best_epoch, best_val_loss,
                  train_acc='--', val_acc='--', best_val_acc='--'):
        """Print one epoch's metrics plus the best checkpoint seen so far."""
        print("epoch {}: train_loss {}, train_acc {}; val_loss {}, val_acc {}; "
              "(best_epoch,best_val_loss,best_val_acc):({},{},{})".format(
                  epoch, train_loss, train_acc, val_loss, val_acc,
                  best_epoch, best_val_loss, best_val_acc))

    @staticmethod
    def submit(demo_submit_path, pred_labels, out_path='Submit_files/'):
        """Decode padded predictions and write a submission csv.

        Each row keeps the digits before the first padding char (10); a row
        with no valid character falls back to "0".
        """
        submit = pd.read_csv(demo_submit_path)
        pred_result = []
        for label in pred_labels:
            tmp = []
            for char in label:
                if char != 10:
                    tmp.append(char)
                else:
                    break
            # Edge case: no valid character at all -> default to 0.
            if not tmp:
                tmp.append(0)
            pred_result.append("".join(map(str, tmp)))
        # Fill the decoded strings into the sample submission table.
        submit['file_code'] = pred_result
        # Persist as submit.csv under the output directory.
        out_path += "submit.csv"
        submit.to_csv(out_path, index=False)
主模块
这个模块是程序入口,显式实现了整个字符识别任务的处理逻辑。
if __name__ == '__main__':
    # Configure the runtime environment: use GPU when available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Dataset / submission locations.
    train_img_path = r'E:\Datas\StreetCharsRecognition\mchar_train\*.png'
    train_label_path = r'E:\Datas\StreetCharsRecognition\mchar_train.json'
    val_img_path = r'E:\Datas\StreetCharsRecognition\mchar_val\*.png'
    val_label_path = r'E:\Datas\StreetCharsRecognition\mchar_val.json'
    test_img_path = r'E:\Datas\StreetCharsRecognition\mchar_test_a\*.png'
    demo_submit_path = r'E:\Datas\StreetCharsRecognition\mchar_sample_submit_A.csv'
    # Hyper-parameters.
    batch_size = 100
    epochs = 20
    lr = .001
    is_predicting = False  # False: training phase (default); True: prediction phase

    # Training transform: resize + random augmentation + normalization.
    train_transform = transforms.Compose([
        transforms.Resize((64, 128)),
        transforms.ColorJitter(0.3, 0.3, 0.2),
        transforms.RandomRotation(5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # BUG FIX: the original applied the random augmentations (ColorJitter,
    # RandomRotation) to validation and test images as well, which skews
    # validation loss and test predictions; evaluation must be deterministic.
    eval_transform = transforms.Compose([
        transforms.Resize((64, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    if not is_predicting:
        # ---- training phase ----
        train_path, train_label = Tools.dataFromPath(train_img_path, train_label_path)
        train_dataset = SVHNDataset(train_path, train_label, train_transform)
        val_path, val_label = Tools.dataFromPath(val_img_path, val_label_path)
        val_dataset = SVHNDataset(val_path, val_label, eval_transform)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=5,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=5,
        )
        # BUG FIX: the original instantiated SVHN_Model1, which is not
        # defined in this file; SVHN_Model2 (the pretrained-backbone model)
        # is the one implemented above.
        model = SVHN_Model2().to(device)
        criterion = nn.CrossEntropyLoss(reduction='sum')
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        best_epoch, best_loss = -1, float('inf')
        # Train, checkpointing the parameters with the lowest validation loss.
        for epoch in range(epochs):
            train_loss = model.mytraining(train_loader, criterion, optimizer, device)
            val_loss = model.myvalidating(val_loader, criterion, device)
            if val_loss < best_loss:
                best_epoch, best_loss = epoch, val_loss
                # Save the learnable parameters of the best model so far.
                torch.save(model.state_dict(), 'Model/model.pt')
            # Report progress for this epoch.
            Tools.printInfo(epoch, train_loss, val_loss,
                            best_epoch, best_loss)
    else:
        # ---- prediction phase ----
        test_path, test_label = Tools.dataFromPath(test_img_path)
        test_dataset = SVHNDataset(test_path, test_label, eval_transform)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=5,
        )
        model = SVHN_Model2()
        model.load_state_dict(torch.load("Model/model.pt", map_location='cpu'))
        pred_labels = model.myPredicting(test_loader)
        Tools.submit(demo_submit_path, pred_labels)
其他
一些自定义的文件目录,Model目录存放训练过程中最优的模型参数,Submit_files目录存放满足可提交格式的预测结果csv文件。
结语
这篇文章贴的代码已经是一份完整的代码啦,有需要的可以去参考文献中的github链接下载代码。这份代码,小编近期将会持续更新,还有很多没讲到的知识点噢。
参考文献
- https://github.com/Ggmatch/CV_StreetCharsRecognition
童鞋们,让小编听见你们的声音,点赞评论,一起加油。