1. Set up the Python environment and import the required libraries and modules
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import warnings
warnings.filterwarnings('ignore')
import re
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import tensorflow as tf
import keras
from keras import layers
from keras.applications import efficientnet
from keras.layers import TextVectorization
from keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm_notebook
from collections import Counter
import matplotlib
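Before moving on, it can be worth confirming that TensorFlow actually sees a GPU and pinning the random seeds so runs are repeatable. This is an optional sketch, not part of the original setup; the seed value 42 is an arbitrary choice.
# Optional sanity check: an empty list means TensorFlow will run on CPU only
print("GPUs visible to TensorFlow:", tf.config.list_physical_devices("GPU"))
# Fix seeds for (approximate) reproducibility; 42 is an arbitrary choice
np.random.seed(42)
tf.random.set_seed(42)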
2. Configure the environment and hyperparameters for the experiment
# Path to the images
IMAGES_PATH = "/hy-tmp/flickr30k_images/flickr30k_images"
# Path to the captions
CAPTIONS_PATH = "/hy-tmp/flickr30k_images/results.csv"
# Desired image dimensions
IMAGE_SIZE = (299, 299)
# Fixed length allowed for any sequence
SEQ_LENGTH = 24
# Vocabulary size
VOCAB_SIZE = 13000
# Dimension for the image embeddings and token embeddings
EMBED_DIM = 512
# Per-layer units in the feed-forward network
FF_DIM = 512
# Batch size
BATCH_SIZE = 512
# Number of epochs
EPOCHS = 50
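Because IMAGES_PATH and CAPTIONS_PATH are machine-specific, a quick existence check catches a wrong mount point before any compute time is spent. An optional sketch, not part of the original pipeline:
# Fail fast if the dataset is not where the constants say it is
assert os.path.isdir(IMAGES_PATH), f"Image directory not found: {IMAGES_PATH}"
assert os.path.isfile(CAPTIONS_PATH), f"Caption file not found: {CAPTIONS_PATH}"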
3. Load and process the image-caption dataset, and split it into training, validation, and test sets
# Loads captions (text) data and maps them to corresponding images.
def load_captions_data(filename):
    with open(filename) as caption_file:
        caption_data = caption_file.readlines()[1:]
        caption_mapping = {}
        text_data = []
        images_to_skip = set()
        for line in caption_data:
            line = line.rstrip("\n")
            # Each image is repeated five times, once per caption.
            # Image name, caption number, and caption are separated by "| "
            try:
                img_name, _, caption = line.split("| ")
            # One row in the dataset raises a ValueError when split this way.
            # Handling that row separately:
            except ValueError:
                img_name, caption = line.split("| ")
                caption = caption[4:]
            img_name = os.path.join(IMAGES_PATH, img_name.strip())
            # Removing captions that are either too short or too long
            tokens = caption.strip().split()
            if len(tokens) < 4 or len(tokens) > SEQ_LENGTH:
                images_to_skip.add(img_name)
                continue
            if img_name.endswith("jpg") and img_name not in images_to_skip:
                # A start and an end token must be added to each caption
                caption = "<start> " + caption.strip() + " <end>"
                text_data.append(caption)
                if img_name in caption_mapping:
                    caption_mapping[img_name].append(caption)
                else:
                    caption_mapping[img_name] = [caption]
        # Drop any image whose captions were filtered out above
        for img_name in images_to_skip:
            if img_name in caption_mapping:
                del caption_mapping[img_name]
        return caption_mapping, text_data
# Splits the dataset into training, validation, and test sets
def train_val_split(caption_data, validation_size=0.2, test_size=0.02, shuffle=True):
    # Getting the list of all image names
    all_images = list(caption_data.keys())
    # Shuffle if necessary
    if shuffle:
        np.random.shuffle(all_images)
    # First carve out the validation set, then split a small test set from it
    train_keys, validation_keys = train_test_split(all_images, test_size=validation_size, random_state=42)
    validation_keys, test_keys = train_test_split(validation_keys, test_size=test_size, random_state=42)
    training_data = {img_name: caption_data[img_name] for img_name in train_keys}
    validation_data = {img_name: caption_data[img_name] for img_name in validation_keys}
    test_data = {img_name: caption_data[img_name] for img_name in test_keys}
    # Return the splits
    return training_data, validation_data, test_data
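Before applying the splitter to the real data, a toy run makes the proportions concrete: 80% of the images go to training, and the remaining 20% is split again, 98% to validation and 2% (0.4% of the total) to test. The keys and captions below are purely illustrative.
# Hypothetical sanity check with 1,000 fake image keys
toy = {f"img_{i}.jpg": ["<start> fake caption here <end>"] for i in range(1000)}
tr, va, te = train_val_split(toy)
print(len(tr), len(va), len(te))  # 800 196 4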
# Loading the dataset
captions_mapping, text_data = load_captions_data(CAPTIONS_PATH)
# Splitting the dataset
train_data, validation_data, test_data = train_val_split(captions_mapping)
print(f"Total number of samples: {len(captions_mapping)}")
print(f"----> Number of training samples: {len(train_data)}")
print(f"----> Number of validation samples: {len(validation_data)}")
print(f"----> Number of test samples: {len(test_data)}")
4. Text standardization and data augmentation
def custom_standardization(input_string):
    # Lowercasing all of the captions
    lowercase = tf.strings.lower(input_string)
    # Characters to remove; "<" and ">" are kept so the <start>/<end> tokens survive
    strip_chars = "!\"#$%&'()*+,-./:;=?@[\\]^_`{|}~1234567890"
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")
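Since tf.strings ops execute eagerly here, the standardizer can be tested directly on a made-up caption before wiring it into the vectorizer:
# Punctuation and digits are stripped, but <start>/<end> are preserved
print(custom_standardization("<start> A man, riding a bike! <end>").numpy())
# b'<start> a man riding a bike <end>'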
# Defining the vectorizer
vectorization = TextVectorization(
# Number of tokens in the vocabulary