通过上篇文章,我已经获取了天猫商城某款马桶评论区的图片。但是,用户上传的评论图片并不都是马桶图片,如下所示:
其中有很多的脏数据。一个简单的笨方法是一个个去选,但是作为一个程序员,当然是应用程序自动去清洗脏数据。
这里使用ImageNet预训练好的Resnet50模型去识别图片,并将脏数据挑选出来。以下为源代码:
import resnet50
import os
import shutil
pathlist = os.listdir('data_grab')
count = 0
for i in pathlist:
pathlist[count] = os.path.join('data_grab', pathlist[count])
count+=1
predictions = resnet50.predict(pathlist)
count = 0
for i in predictions:
flag = 0
if i!=None:
for top in range(3):
name = i[top][1]
if name=='toilet_seat':
flag = 1
if flag==1:
shutil.copyfile(pathlist[count], 'cleandata//'+pathlist[count][-9:])
else:
shutil.copyfile(pathlist[count], 'dirtydata//'+pathlist[count][-9:])
count+=1
其中,resnet50的源代码如下:
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
def predict(img_paths):
model = ResNet50(weights='imagenet')
m = len(img_paths)
predictions = [None]*m
count = 0
for i in img_paths:
print("load file: ", str(i))
try:
img = image.load_img(i, target_size=(224, 224))
img = image.img_to_array(img)
x = np.expand_dims(img, axis=0)
x = preprocess_input(x)
preds = model.predict(x)
preds = decode_predictions(preds, top=3)[0]
predictions[count] = preds
count+=1
except:
count+=1
continue
return predictions
if __name__=="__main__":
import os
pathlist = os.listdir('data_grab')
count = 0
for i in pathlist:
pathlist[count] = os.path.join('data_grab', pathlist[count])
count+=1
predictions = predict(pathlist)
'''
predictions[0]
[('n04447861', 'toilet_seat', 0.87126261), ('n04179913', 'sewing_machine', 0.027522231), ('n15075141', 'toilet_tissue', 0.020023976)]
'''
通过上面的数据清洗,得到的干净数据如下:
相应的脏数据如下:
通过resnet50,基本上可以把脏数据清洗干净,可见深度学习的威力。但是,还是存在一些错误,这只能通过人工进一步清洗了。不过工作量已经少了很多。