**Multi-Label Classification of Movie Posters**
The code in this article is based on:
https://www.kaggle.com/koushikk/multilabel-classification-using-cnn
https://www.depends-on-the-definition.com/classifying-genres-of-movies-by-looking-at-the-poster-a-neural-approach/
https://blog.csdn.net/qq_34792438/article/details/97140512
Source code for this article: https://github.com/gj2528/Classify-Poster
Data sources
The data comes from Kaggle (IMDB has detailed information about each movie):
Original Kaggle poster dataset:
https://www.kaggle.com/neha1703/movie-genre-from-its-poster?tdsourcetag=s_pctim_aiomsg
(997 posters)
Large poster dataset:
https://www.kaggle.com/neha1703/movie-genre-from-its-poster/discussion/35485
(39,515 posters)
Number of genres: 29
Number of posters per genre
Genre counts for the 997-poster dataset:
Genre counts for the 39,515-poster dataset:
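Counts like these can be computed by splitting each entry of the Genre column on "|" and tallying the labels. The following is a minimal sketch, not the author's code: the helper name `count_genres` and the sample rows are made up for illustration, and only the "A|B|C" format is taken from the article's own code.

```python
from collections import Counter

def count_genres(genre_strings):
    """Count how many posters carry each genre label."""
    counts = Counter()
    for genres in genre_strings:
        # A single poster can list several genres, separated by "|".
        counts.update(g for g in str(genres).split("|") if g)
    return counts

# Made-up rows in the same "A|B|C" format as the Genre column.
sample_genres = ["Animation|Adventure|Comedy", "Action|Adventure", "Comedy"]
genre_counts = count_genres(sample_genres)
for genre, count in genre_counts.most_common():
    print(genre, count)
```

With the real CSV loaded into a DataFrame, passing the Genre column to `count_genres` should presumably give the per-genre totals. Note that `str()` turns a missing value into the string "nan", which is then counted as its own label.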
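Some global variables
# Global variables; adjust them as needed
BATCH_SIZE = 128
SIZE = (150, 101)  # target image size as (height, width)
n = 900  # using the small dataset as an example
n_test = 90  # using the small dataset as an example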
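Data processing
Load the data:
import glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.misc  # imread/imresize were deprecated in SciPy 1.0 and removed in later releases

path = 'posters'
data = pd.read_csv("MovieGenre.csv", encoding="ISO-8859-1")
print(data.head())
# Next, we load the movie posters.
image_glob = glob.glob(path + "/" + "*.jpg")
print(image_glob)
img_dict = {}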
Write the image paths to a txt file:
txtName = "dataset/SampleMoviePoster.txt"
# One image path per line; the dataset/ directory must already exist.
with open(txtName, "w") as f:
    for image_single in image_glob:
        f.write(image_single + '\n')
print("writefileDONE")
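Get the image ID:
def get_id(filename):
    # Strip the directory and the ".jpg" extension, e.g. posters\3772.jpg -> "3772".
    # Assumes Windows-style paths with a backslash as the separator.
    index_s = filename.rfind("\\") + 1
    index_f = filename.rfind(".jpg")
    return filename[index_s:index_f]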
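Searching for a backslash only works for Windows-style paths such as `posters\3772.jpg`. On systems where `glob` returns forward slashes, no backslash is found and the directory name stays in the ID. A possible portable variant is sketched below. The name `get_id_portable` is hypothetical and not part of the original code.

```python
from pathlib import PureWindowsPath

def get_id_portable(filename):
    """Return the file name without its directory or extension."""
    # PureWindowsPath accepts both "\\" and "/" as separators on any OS.
    return PureWindowsPath(filename).stem

print(get_id_portable("posters\\3772.jpg"))
print(get_id_portable("posters/3772.jpg"))
```

Both calls should yield the same ID string, so the later `int(id)` lookups against the `imdbId` column behave identically on Windows and on other platforms.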
Load the images into a dictionary keyed by image ID:
for fn in image_glob:
    try:
        img_dict[get_id(fn)] = scipy.misc.imread(fn)
    except Exception:
        # Skip files that cannot be read as images.
        pass
Display an image:
def show_img(id):
    # Look up the title and genre of the movie with this IMDB ID.
    title = data[data["imdbId"] == int(id)]["Title"].values[0]
    genre = data[data["imdbId"] == int(id)]["Genre"].values[0]
    plt.imshow(img_dict[id])
    plt.title("{} \n {}".format(title, genre))
    plt.show()

# show_img('3772')
A small preprocessing function to resize and scale the images:
def preprocess(img, size=(150, 101)):
    # Resize to `size`, then map pixel values from [0, 255] to [-1, 1].
    img = scipy.misc.imresize(img, size)
    img = img.astype(np.float32)
    img = (img / 127.5) - 1.
    return img
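Dividing by 127.5 and subtracting 1 maps a pixel value of 0 to -1 and a value of 255 to 1. To display a preprocessed image again, the mapping has to be inverted first. The sketch below shows only the scaling step and its inverse on a tiny made-up array. The function names `scale_to_signed_unit` and `restore_uint8` are hypothetical, and the resizing step is left out.

```python
import numpy as np

def scale_to_signed_unit(img):
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return img.astype(np.float32) / 127.5 - 1.0

def restore_uint8(scaled):
    """Invert the scaling so the image can be displayed again."""
    pixels = np.rint((scaled + 1.0) * 127.5)
    return np.clip(pixels, 0, 255).astype(np.uint8)

# Made-up 2x2 grayscale "image" covering both ends of the value range.
sample = np.array([[0, 64], [128, 255]], dtype=np.uint8)
scaled = scale_to_signed_unit(sample)
restored = restore_uint8(scaled)
print(scaled.min(), scaled.max())
```

Rounding before the cast back to `uint8` guards against small floating-point errors, so the round trip should return the original pixel values.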
Prepare the data:
def prepare_data(data, img_dict, size=(150, 101)):
    print("Generating dataset...")
    dataset = []
    y = []
    ids = []
    label_dict = {"word2idx": {}, "idx2word": []}
    idx = 0
    # Each movie's Genre field lists its genres separated by "|".
    genre_per_movie = data["Genre"].apply(lambda x: str(x).split("|"))
    # Assign every distinct genre an integer index.
    for l in [g for d in genre_per_movie for g in d]:
        if l in label_dict["idx2word"]:
            pass
        else:
            label_dict["idx2word"].append(l)
            label_dict["word2idx"][l] = idx
            idx += 1
    n_classes = len(label_dict["idx2word"])
    print("identified {} classes".format(n_classes))
    n_samples = len(img_dict)
    print("got {} samples".format(n_samples))
    for k in img_dict:
        try:
            g = data[data["imdbId"] == int(k)]["Genre"]