# 话不多说,上代码 (No more talk — here's the code.)
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Dot, Flatten, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras import mixed_precision

# Enable mixed-precision training: computations run in float16 while
# variables stay float32.  The original used the deprecated
# tf.keras.mixed_precision.experimental API (removed in TF 2.6);
# Policy + set_global_policy is the stable replacement (TF >= 2.4).
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)
# --- Hyperparameters / run configuration ---
embedding_size = 128  # dimensionality of the user and item embedding vectors
batch_size = 600000   # rows per chunk read from the ratings file per step
# Alternative dataset sizes from earlier runs, kept for reference:
# data_num = 23740550
# data_num = 900000000
data_num = 12617678   # total number of rating rows in the input file
lr = 0.01             # optimizer learning rate
epochs = 1000
save_interval = 10    # dump embeddings every N epochs (see SaveModelAndEmbeddings)
class CF(Model):
    """Matrix-factorization collaborative-filtering model.

    Looks up one embedding per user and per item and predicts the rating
    as their dot product.

    Args:
        num_users: size of the user vocabulary (rows of the user table).
        num_items: size of the item vocabulary (rows of the item table).
        embedding_size: dimensionality of both embedding spaces.
    """

    def __init__(self, num_users, num_items, embedding_size, **kwargs):
        super(CF, self).__init__(**kwargs)
        self.user_embedding = Embedding(num_users, embedding_size,
                                        embeddings_initializer='he_normal',
                                        embeddings_regularizer=l2(1e-6))
        self.item_embedding = Embedding(num_items, embedding_size,
                                        embeddings_initializer='he_normal',
                                        embeddings_regularizer=l2(1e-6))
        # Dot(axes=1) applied to the raw (batch, embedding_size) vectors
        # computes the batched inner product directly.  The original
        # Reshape to (embedding_size, 1) before the Dot was redundant: it
        # produced the same (batch, 1) result after Flatten.
        self.dot = Dot(axes=1)
        self.flatten = Flatten()

    def call(self, inputs):
        """Predict ratings for an integer tensor of shape (batch, 2).

        Column 0 holds user indices, column 1 holds item indices.
        Returns a (batch, 1) float32 tensor of predicted ratings.
        """
        user_vector = self.user_embedding(inputs[:, 0])
        item_vector = self.item_embedding(inputs[:, 1])
        dot_user_item = self.dot([user_vector, item_vector])
        # Cast the model output back to float32: under the mixed_float16
        # policy the TF mixed-precision guide recommends float32 outputs
        # for numerically stable loss computation.
        return tf.cast(self.flatten(dot_user_item), tf.float32)

    def get_user_embedding(self, user_id):
        """Return the (1, embedding_size) embedding for one user index."""
        user_vector = self.user_embedding(np.array([user_id]))
        return user_vector.numpy()

    def get_item_embedding(self, item_id):
        """Return the (1, embedding_size) embedding for one item index."""
        item_vector = self.item_embedding(np.array([item_id]))
        return item_vector.numpy()
# Load the full user and item vocabularies and build id -> row-index
# lookup tables for the embedding matrices.
with open('uids_small.txt') as f:
    uids = [line.strip() for line in f]
with open('items_small.txt') as f:
    items = [line.strip() for line in f]

# Dictionaries give O(1) lookups when mapping raw ids to embedding rows.
uid_to_index = {uid: idx for idx, uid in enumerate(uids)}
item_to_index = {item: idx for idx, item in enumerate(items)}
def data_generator(file_path, batch_size, uid_map=None, item_map=None):
    """Yield (features, labels) batches from a tab-separated ratings file.

    Each yielded tuple is an ndarray of shape (batch, 2) holding
    [user_index, item_index] rows and an ndarray of clipped ratings.
    The file is re-read forever (``while True``) so Keras can pull a
    fixed number of steps per epoch indefinitely.

    Args:
        file_path: path to a TSV file with columns user_id, item_id, rating.
        batch_size: number of rows per yielded batch.
        uid_map: optional {user_id: row_index} mapping; defaults to the
            module-level ``uid_to_index``.
        item_map: optional {item_id: row_index} mapping; defaults to the
            module-level ``item_to_index``.
    """
    if uid_map is None:
        uid_map = uid_to_index
    if item_map is None:
        item_map = item_to_index
    while True:
        # Read BOTH id columns as str: the lookup tables are keyed by
        # strings, so letting pandas parse a numeric item_id column as
        # int would make every .map() miss (NaN).  The original forced
        # only user_id to str.
        chunks = pd.read_csv(file_path, chunksize=batch_size, sep='\t',
                             header=None,
                             names=['user_id', 'item_id', 'rating'],
                             dtype={'user_id': str, 'item_id': str})
        for df in chunks:
            df['user_id'] = df['user_id'].map(uid_map)
            df['item_id'] = df['item_id'].map(item_map)
            # Clamp ratings to [0, 2].  NOTE(review): the original
            # comment said "set ratings greater than 1 to 1" but the
            # code clips at 2 — confirm which bound is intended.
            df['rating'] = np.clip(df['rating'].values, 0, 2)
            yield df[['user_id', 'item_id']].values, df['rating'].values
def tf_data_generator(file_path, batch_size):
    """Wrap data_generator in a prefetching tf.data.Dataset.

    Returns an infinite dataset of ((None, 2) int32 index pairs,
    (None,) float32 ratings) batches suitable for model.fit.
    """
    dataset = tf.data.Dataset.from_generator(
        lambda: data_generator(file_path, batch_size),
        output_signature=(
            tf.TensorSpec(shape=(None, 2), dtype=tf.int32),
            tf.TensorSpec(shape=(None,), dtype=tf.float32)
        )
    )
    # tf.data.AUTOTUNE is the stable name (TF >= 2.3) for the deprecated
    # tf.data.experimental.AUTOTUNE used originally.
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset
# --- Build and compile the model under a multi-GPU mirrored strategy ---
num_users = len(uids)
num_items = len(items)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = CF(num_users=num_users, num_items=num_items,
               embedding_size=embedding_size)
    # Subclassed models need an explicit build() before summary() works.
    model.build(input_shape=(None, 2))
    model.summary()
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss='mean_squared_error')
    # Earlier optimizer experiment, kept for reference:
    # model.compile(optimizer=tf.keras.optimizers.Ftrl(learning_rate=lr),
    #               loss='mean_squared_error')
class SaveModelAndEmbeddings(tf.keras.callbacks.Callback):
    """Periodically dump the embedding tables to .npz files.

    Every ``save_interval`` epochs the current user/item embedding
    matrices are written (overwriting the previous dump) together with
    their raw id lists, so downstream jobs can consume partial results
    while training is still running.
    """

    def __init__(self, model, uids, items, save_interval=10):
        super(SaveModelAndEmbeddings, self).__init__()
        self.model = model
        self.uids = uids
        self.items = items
        self.save_interval = save_interval

    def on_epoch_end(self, epoch, logs=None):
        # epoch is 0-based, so epoch + 1 counts completed epochs.
        if (epoch + 1) % self.save_interval != 0:
            return
        user_matrix = self.model.user_embedding.get_weights()[0]
        item_matrix = self.model.item_embedding.get_weights()[0]
        np.savez('user_embeddings.npz', uids=self.uids,
                 user_embeddings=user_matrix)
        np.savez('item_embeddings.npz', items=self.items,
                 item_embeddings=item_matrix)
# --- Train ---
steps_per_epoch = data_num // batch_size
save_callback = SaveModelAndEmbeddings(model, uids, items,
                                       save_interval=save_interval)
train_ds = tf_data_generator('ratings_small.txt', batch_size)
model.fit(train_ds,
          steps_per_epoch=steps_per_epoch,
          epochs=epochs,
          callbacks=[save_callback])

# --- Save the final embedding tables alongside their raw ids ---
user_embeddings = model.user_embedding.get_weights()[0]
item_embeddings = model.item_embedding.get_weights()[0]
np.savez('user_embeddings.npz', uids=uids, user_embeddings=user_embeddings)
print("user_embeddings saved")
np.savez('item_embeddings.npz', items=items, item_embeddings=item_embeddings)
print("item_embeddings saved")