![c1587860982782753c35d71212b07370.png](https://i-blog.csdnimg.cn/blog_migrate/089562d0e5e2e121e1262ee484fadf3e.jpeg)
源码地址:
taki0112/SENet-Tensorflowgithub.com![08b191d284a0348d6da52cbbe5dc7b53.png](https://i-blog.csdnimg.cn/blog_migrate/7dbac46c89cd8626185215821c7538ac.jpeg)
源码中的训练例子是用cifar10 dataset,你自己开始可以跑一下试试看效果如何。
接下来如果是要训练你自己的数据集,我这里用的是最笨的办法,就是直接把自己的数据集换成cifar10的格式,在把相关的参数改成自己的数据就可实现训练。
首先,把自己的数据集构造成cifar10格式。
第一步,构建train、val的lst
'''
creat_list.py
'''
import os
import shutil
import random
woodtrain = open('data/wood_train.lst', 'w')
woodtest = open('data/wood_test.lst', 'w')
savedir = 'data/floor'
dirpath = 'F:/2019_DL_Nets/SENet-Tensorflow/dataset'
filelist = []
for parent, dirs, filenames in os.walk(dirpath):
for subdir in dirs:
labels = dirs.index(subdir)
subfilelist = []
for filename in os.listdir(os.path.join(parent, subdir)):
objfile = os.path.join(parent, subdir, filename)
desfile = os.path.join(savedir, str(labels) + "_" + filename)
shutil.copyfile(objfile, desfile)
subfilelist.append(str(labels) + "_"+ filename)
random.shuffle(subfilelist)
filelist.append(subfilelist)
ratio = 0.7
Tr_list = []
Te_list = []
for namelist in filelist:
train_num = int(ratio*len(namelist))
trainlist = namelist[0:train_num]
for info in trainlist:
Tr_list.append(info)
vallist = namelist[train_num:]
for info in vallist:
Te_list.append(info)
random.shuffle(Tr_list)
for inf in Tr_list:
woodtrain.write(inf + " " + inf[0][0] + "n")
random.shuffle(Te_list)
for inf in Te_list:
woodtest.write(inf + " " + inf[0][0] + "n")
woodtest.close()
woodtrain.close()
我的原始数据格式是
![f86dcfa30ded7e69e28b5bd73d734905.png](https://i-blog.csdnimg.cn/blog_migrate/e6c8fc8c3c640bded5a91625b49930d4.png)
![45927521c7ec6514304e2e16400a5018.png](https://i-blog.csdnimg.cn/blog_migrate/523cce03a9cc8e15e8a902746097e5aa.jpeg)
![34b06b73a1321933e0ab9f48c6d83a39.png](https://i-blog.csdnimg.cn/blog_migrate/d3ea6bb2bac9ae42956ca7a4dbdd7476.jpeg)
在每一类下都是按照顺序命名的