代码总体架构
代码总体分为6个文件,分别为:dataset.py, main.py, config, utils.py, prediction.py,Network.py。每个文件对应以下功能
文件 | 功能 |
---|---|
dataset.py | 用来对数据进行处理 |
main.py | 用来进行训练模型 |
Network.py | 用来写网络 |
config | 内含init.yaml, 用来写参数 |
utils.py | 写一些辅助函数,例如获得config的函数等 |
prediction.py | 用来检验模型 |
具体文件介绍
对六个部分进行具体介绍
1.Network.py
代码如下:(*需要查阅相关论文,进行进一步改进)
class MyModel(nn.Module):
def __init__(self, classnum=1000, feature=False):
# classnum = dataloader.dataset.class_num = 227
super(MyModel, self).__init__()
self.classnum = classnum
self.feature = feature
# input = B*3*112*96
self.layer1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,stride= 2, padding=1) ,
nn.BatchNorm2d(64),
nn.PReLU(64)
) # =>B*64*56*48
self.layer2 = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,stride=1, padding=1) ,
nn.BatchNorm2d(64),
nn.PReLU(64)
) # =>懒得算了
self.layer3 = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,stride=1, padding=1) ,
nn.BatchNorm2d(64),
nn.PReLU(64)
)
self.layer4 = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3,stride=2, padding=1) ,
nn.BatchNorm2d(128),
nn.PReLU(128)
)# =>B*128*28*24
def forward(self, x, target=None):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.view(x.size(0), -1)
2.config
内含init.yaml
包含了各种参数以及出现频率较高的路径
3.utils.py
内含Get_Last_Model函数以及Get_Config函数,用以辅助,具体功能如下:
函数 | 功能 |
---|---|
Get_Last_Model | 获取保存在模型保存文件内的最后一个模型 |
Get_Config | 获取Config内参数 |
4.dataset.py
主要用来获取数据并对数据进行各种处理,并包含一些与数据有关的简单操作。主要有三个部分: init、len、getitem, 主要对应以下功能:
函数 | 功能 |
---|---|
init | 初始化,主要用来导入数据,并设置 了一些属性,包括对应的图片数量、图片对应人名、图片处理方式 |
len | 获取长度的方法 |
getitem | 获取索引对应的处理过后的图片以及对应标签的方法 |
5.main.py
主要用来训练模型。所定义的函数以及对应功能为:
函数 | 功能 |
---|---|
Get_Model | 用来获得在Network.py中所设置的模型 |
Get_Optimizer | 用来获得参数优化器 |
Get_Scheduler | 用来获得Schedule从而获得动态学习率 |
Train | 用来训练模型 |
总体的思路为:
- 创建parser,设置各种参数
- 设置模型保存的各种路径
- 设置Dataloader
- 如果意外中断,从中断处继续训练,否则,设置optimizer、scheduler,进行训练
6.prediction.py
主要用来进行模型准确度测试。所定义的函数以及对应功能为:
函数 | 功能 |
---|---|
Unresticted_Iter | 迭代器,用来获得所需要的两个图片及其对应的相似度(0/1) |
Predict | 加载模型,利用余弦距离计算迭代器所获得的两个图片的相似度,返回预测相似度以及实际相似度 |
KFold | 按KFold方法把图片分好 |
Eval_Threshold | 通过精度评价阈值的好坏 |
Find_Best_Threshold | 寻找判断Predict的最佳阈值 |
总体的思路为:
- 创建parser,设置各种参数
- 设置各种路径
- Predict函数进行预测
- Eval函数对结果进行评价