Paper Code Reading: Understanding the Training-Stage Code of TGN


Paper information

Paper: https://arxiv.org/abs/2006.10637

GitHub: https://github.com/twitter-research/tgn?tab=readme-ov-file

Year: 2020

Hand-drawn sketches of the code flow

[Two hand-drawn diagrams of the training code flow; images not reproduced here]

The training call

pos_prob, neg_prob = tgn.compute_edge_probabilities(sources_batch, destinations_batch, negatives_batch,timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS)

Function compute_edge_probabilities

def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                               edge_idxs, n_neighbors=20):
    """
    Compute probabilities for edges between sources and destination and between sources and
    negatives by first computing temporal embeddings using the TGN encoder and then feeding them
    into the MLP decoder.
    :param source_nodes [batch_size]: source ids
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destination
    :param edge_times [batch_size]: timestamp of interaction, i.e. when source_nodes[i]
    interacted with destination_nodes[i]
    :param edge_idxs [batch_size]: index of interaction (edge id)
    :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
    layer
    :return: Probabilities for both the positive and negative edges
    """
    n_samples = len(source_nodes)
    # compute_temporal_embeddings
    source_node_embedding, destination_node_embedding, negative_node_embedding = self.compute_temporal_embeddings(source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors)

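    # Score sources against destinations and against negatives in a single call: the source
    # embeddings are stacked twice to pair with [destinations; negatives]. affinity_score returns
    # one score per row, shape [2 * n_samples, 1], so squeeze(dim=0) is a no-op here.
    score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
                                torch.cat([destination_node_embedding,
                                           negative_node_embedding])).squeeze(dim=0)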
    pos_score = score[:n_samples]
    neg_score = score[n_samples:]

    return pos_score.sigmoid(), neg_score.sigmoid()
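compute_edge_probabilities scores positive and negative pairs with one affinity_score call: the source embeddings are stacked twice so the first half pairs with the destinations and the second half with the negatives, and the scores are split at n_samples before the sigmoid. The pure-Python sketch below shows only that batching pattern. It assumes a plain dot product in place of affinity_score (a learned layer in the repository), and edge_probabilities is a hypothetical name.

```python
import math


def dot(u, v):
    """Stand-in affinity score: a plain dot product (the repository's affinity_score is learned)."""
    return sum(a * b for a, b in zip(u, v))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def edge_probabilities(src_emb, dst_emb, neg_emb):
    """Score [src; src] against [dst; neg] in one pass, then split the scores at n_samples."""
    n_samples = len(src_emb)
    left = src_emb + src_emb    # source embeddings stacked twice
    right = dst_emb + neg_emb   # destinations first, then negatives
    scores = [dot(u, v) for u, v in zip(left, right)]
    pos_score, neg_score = scores[:n_samples], scores[n_samples:]
    return [sigmoid(s) for s in pos_score], [sigmoid(s) for s in neg_score]


src = [[1.0, 0.0], [0.0, 1.0]]
dst = [[2.0, 0.0], [0.0, 3.0]]    # aligned with their sources
neg = [[-2.0, 0.0], [0.0, -3.0]]  # pointing away from their sources
pos_prob, neg_prob = edge_probabilities(src, dst, neg)
print(pos_prob, neg_prob)
```

Because the split happens at n_samples, pos_prob[i] and neg_prob[i] always refer to the same source node.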
compute_temporal_embeddings

This method computes the temporal embeddings of the source, destination and negative nodes.

def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                edge_idxs, n_neighbors=20):
    """
    Compute temporal embeddings for sources, destinations, and negatively sampled destinations.
    :param source_nodes [batch_size]: source ids
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destination
    :param edge_times [batch_size]: timestamp of interaction
    :param edge_idxs [batch_size]: index of interaction
    :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
    layer
    :return: Temporal embeddings for sources, destinations and negatives
    """
	
    # n_samples 表示源节点有多少个
    n_samples = len(source_nodes)
    # nodes是所有的节点这个batch_size中所有的节点id, size=200*3=600
    nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
    # positives 是将源节点和目标节点和在一切,前200个是源节点的node_id, 后200个是目标节点的node_id
    positives = np.concatenate([source_nodes, destination_nodes])
    # timestamps shape=200*3 edge_times 是发生交互的时间
    timestamps = np.concatenate([edge_times, edge_times, edge_times])
	# edge_times shape = batch_size 是源节点和目的节点发生的时间
    memory = None
    time_diffs = None
    if self.use_memory:
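        if self.memory_update_at_start:  # update the memory at the start of the batch from stored messages
            # n_nodes: total number of nodes in the graph (9228 in this run)
            # self.memory.messages (the per-node message store) is still empty on the first batch
            # The memory returned here is up to date: it is computed from each node's stored messages,
            # and with the "last" aggregator only the most recent message of each node is used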
            memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                          self.memory.messages)  # memory shape = [n_nodes(9228), memory_dimension(172)] last_update shape [n_nodes(9228)]
        else:
            memory = self.memory.get_memory(list(range(self.n_nodes)))
            last_update = self.memory.last_update

        # ================================ Per-node time differences ================================
        # Difference between the time at which we want each node's embedding and the time of
        # its last memory update
        # source_time_diffs shape [batch_size]
        source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[source_nodes].long() 
        # Standardize with the precomputed source time-shift statistics
        source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src 
        # destination_time_diffs shape [batch_size]
        destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[destination_nodes].long()
        # Standardize with the precomputed destination time-shift statistics
        destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

        negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
            negative_nodes].long()
        negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

        # Time differences for all 3 * batch_size nodes
        time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
                               dim=0)

    # Compute the embeddings using the embedding module
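    # self.embedding_module is shown further below; it is created by
    # self.embedding_module = get_embedding_module(...)
    """
    memory: memory tensor, shape [n_nodes, memory_dimension]
    nodes: concatenated ids of source, destination and negative nodes
    timestamps: time list of length 3 * 200
    self.n_layers: recursion depth, i.e. the number of temporal attention layers
    n_neighbors: number of neighbors sampled per node (10 here)
    time_diffs: normalized time differences (GraphEmbedding.compute_embedding does not use them)
    """
    # node_embedding shape [600, 172]: fuses each node's features and memory with its temporal
    # neighbors' embeddings and the features of the connecting edges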
    node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                             source_nodes=nodes,
                                                             timestamps=timestamps,
                                                             n_layers=self.n_layers,
                                                             n_neighbors=n_neighbors,
                                                             time_diffs=time_diffs)
    # Split the embeddings back into the three node groups
    source_node_embedding = node_embedding[:n_samples]
    destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
    negative_node_embedding = node_embedding[2 * n_samples:]

    if self.use_memory:
        # Memory update
        if self.memory_update_at_start:
            # Persist the updates to the memory only for sources and destinations (since now we have
            # new messages for them)
            self.update_memory(positives, self.memory.messages)

            assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
                "Something wrong in how the memory was updated"

            # Remove messages for the positives since we have already updated the memory using them
            # The memory now reflects these messages, so empty the positives' message lists
            self.memory.clear_messages(positives)

        unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                      source_node_embedding,
                                                                      destination_nodes,
                                                                      destination_node_embedding,
                                                                      edge_times, edge_idxs)
        unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                                destination_node_embedding,
                                                                                source_nodes,
                                                                                source_node_embedding,
                                                                                edge_times, edge_idxs)
        if self.memory_update_at_start:
            # Store the raw messages; they update the memory at the start of the next batch
            self.memory.store_raw_messages(unique_sources, source_id_to_messages)
            self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
        else:
            self.update_memory(unique_sources, source_id_to_messages)
            self.update_memory(unique_destinations, destination_id_to_messages)

        if self.dyrep:
            source_node_embedding = memory[source_nodes]
            destination_node_embedding = memory[destination_nodes]
            negative_node_embedding = memory[negative_nodes]

    return source_node_embedding, destination_node_embedding, negative_node_embedding
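With memory_update_at_start=True the order inside compute_temporal_embeddings is: apply the messages stored by earlier batches, compute the embeddings with that memory, and only then store the current batch's raw messages. The memory used to embed a batch therefore never contains that batch's own interactions. The hypothetical ToyMemory class below sketches that ordering with scalar memories and an additive update rule. As a simplification, it persists the update for every node with pending messages, whereas the repository persists only the batch's positives and leaves the other nodes' messages pending.

```python
from collections import defaultdict


class ToyMemory:
    """Hypothetical scalar memory that mimics the memory_update_at_start=True ordering."""

    def __init__(self):
        self.memory = defaultdict(float)   # node id -> scalar memory
        self.messages = defaultdict(list)  # node id -> pending raw messages

    def process_batch(self, edges):
        # (1) Apply the messages stored by the previous batch ("last" aggregation), then clear them.
        for node, msgs in self.messages.items():
            if msgs:
                self.memory[node] += msgs[-1]  # toy update rule: add the last message
        self.messages.clear()
        # (2) "Embed" the batch with the memory as it is now.
        embeddings = [(self.memory[s], self.memory[d]) for s, d, _ in edges]
        # (3) Store this batch's raw messages; they are applied at the start of the next batch.
        for s, d, value in edges:
            self.messages[s].append(value)
            self.messages[d].append(value)
        return embeddings


mem = ToyMemory()
first = mem.process_batch([(1, 2, 5.0)])
second = mem.process_batch([(1, 3, 7.0)])
print(first, second)  # [(0.0, 0.0)] [(5.0, 0.0)]
```

The interaction (1, 2, 5.0) only shows up in node 1's memory from the second batch on.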
update_memory
def update_memory(self, nodes, messages):
    # Aggregate messages for the same nodes
    # self.message_aggregator -> LastMessageAggregator
    unique_nodes, unique_messages, unique_timestamps = \
        self.message_aggregator.aggregate(
            nodes,
            messages)

    if len(unique_nodes) > 0:
        unique_messages = self.message_function.compute_message(unique_messages)

    # Update the memory with the aggregated messages
    # Once aggregated, apply the update
    self.memory_updater.update_memory(unique_nodes, unique_messages,
                                      timestamps=unique_timestamps)
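update_memory first reduces each node's messages to one (the "last" aggregator by default), then writes the updated memory and last_update back. The memory updater also refuses to move a node's memory backwards in time. The sketch below shows that aggregate-then-update flow and the time check with scalar dicts. The blend rule is a hypothetical stand-in for the GRU, and the function is a simplified standalone version, not the repository method.

```python
def update_memory(memory, last_update, nodes, messages, rate=0.5):
    """Aggregate each node's messages (keep the last one), then write the update back in place.

    memory / last_update: dicts node -> float; messages: dict node -> list of (message, timestamp).
    The blend below is a hypothetical stand-in for the GRU used by the repository.
    """
    for node in sorted(set(nodes)):                 # aggregate: one entry per node
        if not messages.get(node):
            continue                                # nothing to apply for this node
        msg, ts = messages[node][-1]                # "last" aggregation
        assert last_update[node] <= ts, "Trying to update memory to time in the past"
        memory[node] = (1 - rate) * memory[node] + rate * msg
        last_update[node] = ts


memory = {0: 0.0, 1: 0.0}
last_update = {0: 0.0, 1: 0.0}
update_memory(memory, last_update, [0, 0, 1], {0: [(4.0, 1.0), (8.0, 2.0)], 1: []})
print(memory, last_update)  # {0: 4.0, 1: 0.0} {0: 2.0, 1: 0.0}

try:
    update_memory(memory, last_update, [0], {0: [(1.0, 1.5)]})  # older than last_update[0] = 2.0
except AssertionError as err:
    print("rejected:", err)
```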
get_raw_messages
def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                     destination_node_embedding, edge_times, edge_idxs):
    # edge_times shape is [200, ]
    edge_times = torch.from_numpy(edge_times).float().to(self.device)
    # edge_features shape is [200, 172]
    edge_features = self.edge_raw_features[edge_idxs]

    source_memory = self.memory.get_memory(source_nodes) if not \
        self.use_source_embedding_in_message else source_node_embedding
    destination_memory = self.memory.get_memory(destination_nodes) if \
        not self.use_destination_embedding_in_message else destination_node_embedding

    source_time_delta = edge_times - self.memory.last_update[source_nodes]
    # source_time_delta_encoding [200, 172]
    source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(
        source_nodes), -1)
    # source_message shape [200, 688] = 4 * 172 (source memory, destination memory, edge features, time encoding)
    source_message = torch.cat([source_memory, destination_memory, edge_features,
                                source_time_delta_encoding],
                               dim=1)
    messages = defaultdict(list)
    unique_sources = np.unique(source_nodes)

    for i in range(len(source_nodes)):
        messages[source_nodes[i]].append((source_message[i], edge_times[i]))

    return unique_sources, messages
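get_raw_messages builds one raw message per interaction. Each message concatenates the source memory (or embedding), the destination memory (or embedding), the edge features and an encoding of the time since the source's last update. The messages are then grouped by node, and with four 172-dimensional parts this gives the 688-dimensional messages noted in the comments. The sketch below reproduces just the concatenate-and-group step with plain lists. build_raw_messages is a hypothetical name, and the time encodings are passed in precomputed.

```python
from collections import defaultdict


def build_raw_messages(src_nodes, src_memory, dst_memory, edge_feats, time_encs, edge_times):
    """Concatenate [source memory | destination memory | edge features | time encoding] per
    interaction and group the resulting (message, timestamp) pairs by source node."""
    messages = defaultdict(list)
    for i, node in enumerate(src_nodes):
        raw = src_memory[i] + dst_memory[i] + edge_feats[i] + time_encs[i]  # list concatenation
        messages[node].append((raw, edge_times[i]))
    return sorted(set(src_nodes)), messages


d = 2                                  # 172 in the walkthrough, giving 4 * 172 = 688
src_nodes = [7, 7, 9]
zeros = [[0.0] * d for _ in src_nodes]
unique_sources, msgs = build_raw_messages(src_nodes, zeros, zeros, zeros, zeros, [1.0, 2.0, 3.0])
print(unique_sources, len(msgs[7]), len(msgs[7][0][0]))  # [7, 9] 2 8
```

For the destination-side messages, compute_temporal_embeddings calls the same function with the roles of sources and destinations swapped.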
get_updated_memory
def get_updated_memory(self, nodes, messages):
    # Aggregate messages for the same nodes
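    # nodes: list(range(n_nodes))
    # messages: the per-node message store (node id -> list of (raw message, timestamp))
    # First aggregate the messages, then compute the updated memory
    # On the very first call every message list is empty, so the aggregator returns empty lists
    unique_nodes, unique_messages, unique_timestamps = \
        self.message_aggregator.aggregate(
            nodes,  # list(range(n_nodes))
            messages  # per-node message store
        )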

    if len(unique_nodes) > 0:
        # Two message functions are available (quoted here for reference):
        """
        class MLPMessageFunction(MessageFunction):
            def __init__(self, raw_message_dimension, message_dimension):
                super(MLPMessageFunction, self).__init__()
                self.mlp = self.layers = nn.Sequential(
                    nn.Linear(raw_message_dimension, raw_message_dimension // 2),
                    nn.ReLU(),
                    nn.Linear(raw_message_dimension // 2, message_dimension),
                )

            def compute_message(self, raw_messages):
                messages = self.mlp(raw_messages)
                return messages

        class IdentityMessageFunction(MessageFunction):
            def compute_message(self, raw_messages):  # used in this run: returns the raw messages unchanged
                return raw_messages
        """
        unique_messages = self.message_function.compute_message(unique_messages)

    # On the first training batch (freshly initialized memory) this returns all-zero tensors
    # of shapes [n_nodes, memory_dimension] and [n_nodes]
    updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes,
                                                                                 unique_messages,
                                                                                 timestamps=unique_timestamps)

    return updated_memory, updated_last_update
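get_updated_memory and update_memory run the same aggregate, message-function and updater pipeline. The difference is where the result goes: SequenceMemoryUpdater.get_updated_memory works on clones and returns new tensors, leaving the stored memory untouched, while update_memory writes into it. The hypothetical scalar class below mirrors that split. Its additive update is a stand-in for the GRU.

```python
class ToyUpdater:
    """Hypothetical scalar memory illustrating copy-based vs in-place updates."""

    def __init__(self, n_nodes):
        self.memory = [0.0] * n_nodes
        self.last_update = [0.0] * n_nodes

    def get_updated_memory(self, ids, msgs, times):
        # Work on copies (like .data.clone()): the stored memory is left untouched.
        memory = list(self.memory)
        last_update = list(self.last_update)
        for node, msg, ts in zip(ids, msgs, times):
            memory[node] += msg          # toy update rule (stand-in for the GRU)
            last_update[node] = ts
        return memory, last_update

    def update_memory(self, ids, msgs, times):
        # Same computation, but the result replaces the stored memory.
        self.memory, self.last_update = self.get_updated_memory(ids, msgs, times)


updater = ToyUpdater(3)
preview, _ = updater.get_updated_memory([1], [2.0], [5.0])
stored_before = list(updater.memory)   # still all zeros
updater.update_memory([1], [2.0], [5.0])
print(preview, stored_before, updater.memory)  # [0.0, 2.0, 0.0] [0.0, 0.0, 0.0] [0.0, 2.0, 0.0]
```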
self.message_aggregator.aggregate

The code uses the "last" aggregator by default.

def get_message_aggregator(aggregator_type, device):
    if aggregator_type == "last":
        return LastMessageAggregator(device=device)
    elif aggregator_type == "mean":
        return MeanMessageAggregator(device=device)
    else:
        raise ValueError("Message aggregator {} not implemented".format(aggregator_type))

LastMessageAggregator code:

class LastMessageAggregator(MessageAggregator):
    def __init__(self, device):
        super(LastMessageAggregator, self).__init__(device)

    def aggregate(self, node_ids, messages):
        """Only keep the last message for each node"""
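        # Deduplicate. Ids coming from range(n_nodes) are already unique, but update_memory can
        # pass repeated ids (e.g. positives, where a node may appear several times in a batch)
        unique_node_ids = np.unique(node_ids)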
        unique_messages = []
        unique_timestamps = []

        to_update_node_ids = []

        for node_id in unique_node_ids:  # up to n_nodes (9228) iterations when called from get_updated_memory
            if len(messages[node_id]) > 0:
                """
                上一步结束每个节点存储的信息以及对应的(时间?)
                source_message = torch.cat([source_memory, destination_memory, edge_features,
                                    source_time_delta_encoding], dim=1)
                source_message, edge_times
                """
                to_update_node_ids.append(node_id)
                unique_messages.append(messages[node_id][-1][0])
                unique_timestamps.append(messages[node_id][-1][1])

        unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
        unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []

        return to_update_node_ids, unique_messages, unique_timestamps
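LastMessageAggregator keeps only each node's newest (message, timestamp) pair; the factory above can also return a MeanMessageAggregator. The pure-Python sketch below contrasts the two reductions on the same message store. The mean branch is an assumption about the mean aggregator's behaviour (average the message vectors, keep the newest timestamp) rather than a copy of its code, and aggregate is a hypothetical function name.

```python
def aggregate(node_ids, messages, how="last"):
    """Reduce each node's message list to one message; messages[node] is a list of (vector, timestamp)."""
    to_update, out_msgs, out_times = [], [], []
    for node in sorted(set(node_ids)):          # duplicates in node_ids are harmless
        entries = messages.get(node, [])
        if not entries:
            continue                             # nodes without messages are skipped
        if how == "last":
            msg = entries[-1][0]
        elif how == "mean":
            dim = len(entries[0][0])
            msg = [sum(e[0][k] for e in entries) / len(entries) for k in range(dim)]
        else:
            raise ValueError(f"Message aggregator {how} not implemented")
        to_update.append(node)
        out_msgs.append(msg)
        out_times.append(entries[-1][1])         # newest timestamp in both modes
    return to_update, out_msgs, out_times


messages = {3: [([1.0, 0.0], 1.0), ([3.0, 2.0], 4.0)], 5: []}
print(aggregate([3, 3, 5], messages, "last"))  # ([3], [[3.0, 2.0]], [4.0])
print(aggregate([3, 3, 5], messages, "mean"))  # ([3], [[2.0, 1.0]], [4.0])
```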
self.memory_updater.get_updated_memory

By default the code updates the memory with a GRU (GRUMemoryUpdater).

class SequenceMemoryUpdater(MemoryUpdater):
    def __init__(self, memory, message_dimension, memory_dimension, device):
        super(SequenceMemoryUpdater, self).__init__()
        self.memory = memory
        self.layer_norm = torch.nn.LayerNorm(memory_dimension)
        self.message_dimension = message_dimension
        self.device = device

    def update_memory(self, unique_node_ids, unique_messages, timestamps):
        if len(unique_node_ids) <= 0:
            return

        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \
                                                                                          "update memory to time in the past"

        memory = self.memory.get_memory(unique_node_ids)
        self.memory.last_update[unique_node_ids] = timestamps

        updated_memory = self.memory_updater(unique_messages, memory)

        self.memory.set_memory(unique_node_ids, updated_memory)

    def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
        if len(unique_node_ids) <= 0:
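            # self.memory is the Memory object defined below
            # self.memory.memory is initialized to zeros, shape [n_nodes, memory_dimension], requires_grad=False
            # self.memory.last_update is initialized to zeros, shape [n_nodes], requires_grad=False
            # .data.clone() returns an independent copy, so the stored values are not affected
            # From the second batch on there are normally stored messages, so this branch is usually skipped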
            return self.memory.memory.data.clone(), self.memory.last_update.data.clone()

        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \
                                                                                          "update memory to time in the past"

        updated_memory = self.memory.memory.data.clone()
        updated_memory[unique_node_ids] = self.memory_updater(unique_messages, updated_memory[unique_node_ids])

        updated_last_update = self.memory.last_update.data.clone()
        updated_last_update[unique_node_ids] = timestamps

        return updated_memory, updated_last_update


class GRUMemoryUpdater(SequenceMemoryUpdater):
    def __init__(self, memory, message_dimension, memory_dimension, device):
        super(GRUMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device)

        self.memory_updater = nn.GRUCell(input_size=message_dimension,
                                         hidden_size=memory_dimension)
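GRUMemoryUpdater wraps nn.GRUCell(input_size=message_dimension, hidden_size=memory_dimension) and calls it as memory_updater(unique_messages, memory), so the aggregated message is the GRU input and the node's current memory is the hidden state. The pure-Python function below re-implements one GRUCell step from the gate equations in the PyTorch documentation, for illustration only. gru_cell, matvec and zero_params are hypothetical helpers, and the parameter names (W_ir, b_hn, ...) follow the documented equations rather than PyTorch's packed weight layout.

```python
import math


def matvec(w, v):
    """Matrix (list of rows) times vector."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gru_cell(x, h, p):
    """One GRUCell step: x is the aggregated message, h the node's current memory.

    p holds W_i* (hidden x input), W_h* (hidden x hidden) and bias vectors b_i*, b_h* for gates r, z, n.
    """
    gi = {g: [a + b for a, b in zip(matvec(p["W_i" + g], x), p["b_i" + g])] for g in "rzn"}
    gh = {g: [a + b for a, b in zip(matvec(p["W_h" + g], h), p["b_h" + g])] for g in "rzn"}
    r = [sigmoid(a + b) for a, b in zip(gi["r"], gh["r"])]                  # reset gate
    z = [sigmoid(a + b) for a, b in zip(gi["z"], gh["z"])]                  # update gate
    n = [math.tanh(a + ri * b) for a, ri, b in zip(gi["n"], r, gh["n"])]   # candidate memory
    return [(1 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h)]


def zero_params(input_size, hidden_size):
    p = {}
    for g in "rzn":
        p["W_i" + g] = [[0.0] * input_size for _ in range(hidden_size)]
        p["W_h" + g] = [[0.0] * hidden_size for _ in range(hidden_size)]
        p["b_i" + g] = [0.0] * hidden_size
        p["b_h" + g] = [0.0] * hidden_size
    return p


# With all-zero parameters: z = sigmoid(0) = 0.5 and n = tanh(0) = 0, so h' = 0.5 * h.
print(gru_cell([1.0, 2.0, 3.0], [4.0, -2.0], zero_params(3, 2)))  # [2.0, -1.0]
```

The update gate z interpolates between keeping the old memory (z close to 1) and replacing it with the candidate n (z close to 0).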
Memory
class Memory(nn.Module):

    def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
                 device="cpu", combination_method='sum'):
        super(Memory, self).__init__()
        self.n_nodes = n_nodes
        self.memory_dimension = memory_dimension
        self.input_dimension = input_dimension
        self.message_dimension = message_dimension
        self.device = device

        self.combination_method = combination_method

        self.__init_memory__()

    # Initialization
    def __init_memory__(self):
        """
        Initializes the memory to all zeros. It should be called at the start of each epoch.
        """
        # Treat memory as parameter so that it is saved and loaded together with the model
        # self.memory_dimension = 172
        # self.n_nodes = 9228
        # self.memory has shape [9228, 172]: one 172-dimensional memory vector per node,
        # initialized to all zeros
        self.memory = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
                                   requires_grad=False)

        # last_update shape = [9228]
        self.last_update = nn.Parameter(torch.zeros(self.n_nodes).to(self.device),
                                        requires_grad=False)

        self.messages = defaultdict(list)

    def store_raw_messages(self, nodes, node_id_to_messages):
        for node in nodes:
            self.messages[node].extend(node_id_to_messages[node])

    def get_memory(self, node_idxs):
        return self.memory[node_idxs, :]

    def set_memory(self, node_idxs, values):
        self.memory[node_idxs, :] = values

    def get_last_update(self, node_idxs):
        return self.last_update[node_idxs]

    def backup_memory(self):
        messages_clone = {}
        for k, v in self.messages.items():
            messages_clone[k] = [(x[0].clone(), x[1].clone()) for x in v]

        return self.memory.data.clone(), self.last_update.data.clone(), messages_clone

    def restore_memory(self, memory_backup):
        self.memory.data, self.last_update.data = memory_backup[0].clone(), memory_backup[1].clone()

        self.messages = defaultdict(list)
        for k, v in memory_backup[2].items():
            self.messages[k] = [(x[0].clone(), x[1].clone()) for x in v]

    def detach_memory(self):
        self.memory.detach_()

        # Detach all stored messages
        for k, v in self.messages.items():
            new_node_messages = []
            for message in v:
                new_node_messages.append((message[0].detach(), message[1]))

            self.messages[k] = new_node_messages

    def clear_messages(self, nodes):
        for node in nodes:
            self.messages[node] = []
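Memory holds three pieces of state: the memory matrix, last_update and the per-node message lists. backup_memory deep-copies all three and restore_memory copies them back, so a caller can take a snapshot (for instance before an evaluation pass) and roll back afterwards. The hypothetical ToyMemoryStore below reproduces those semantics with Python lists and copy.deepcopy.

```python
import copy
from collections import defaultdict


class ToyMemoryStore:
    """Hypothetical pure-Python analogue of Memory: per-node vectors, last-update times and message lists."""

    def __init__(self, n_nodes, dim):
        self.memory = [[0.0] * dim for _ in range(n_nodes)]
        self.last_update = [0.0] * n_nodes
        self.messages = defaultdict(list)

    def store_raw_messages(self, nodes, node_id_to_messages):
        for node in nodes:
            self.messages[node].extend(node_id_to_messages[node])

    def clear_messages(self, nodes):
        for node in nodes:
            self.messages[node] = []

    def backup_memory(self):
        # Deep copies, so later updates cannot leak into the backup.
        return copy.deepcopy((self.memory, self.last_update, dict(self.messages)))

    def restore_memory(self, backup):
        # Copy again, so later updates cannot leak back into the backup either.
        memory, last_update, messages = copy.deepcopy(backup)
        self.memory, self.last_update = memory, last_update
        self.messages = defaultdict(list, messages)


store = ToyMemoryStore(n_nodes=3, dim=2)
store.store_raw_messages([1], {1: [([1.0, 1.0], 2.0)]})
backup = store.backup_memory()
store.memory[1] = [9.0, 9.0]
store.clear_messages([1])
store.restore_memory(backup)
print(store.memory[1], store.messages[1])  # [0.0, 0.0] [([1.0, 1.0], 2.0)]
```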
get_embedding_module

Here module_type = "graph_attention".

def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                         time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device,
                         n_heads=2, dropout=0.1, n_neighbors=None,
                         use_memory=True):
    # This is the embedding module used here
    if module_type == "graph_attention":
        return GraphAttentionEmbedding(node_features=node_features,
                                       edge_features=edge_features,
                                       memory=memory,
                                       neighbor_finder=neighbor_finder,
                                       time_encoder=time_encoder,
                                       n_layers=n_layers,
                                       n_node_features=n_node_features,
                                       n_edge_features=n_edge_features,
                                       n_time_features=n_time_features,
                                       embedding_dimension=embedding_dimension,
                                       device=device,
                                       n_heads=n_heads, dropout=dropout, use_memory=use_memory)
    elif module_type == "graph_sum":
        return GraphSumEmbedding(node_features=node_features,
                                 edge_features=edge_features,
                                 memory=memory,
                                 neighbor_finder=neighbor_finder,
                                 time_encoder=time_encoder,
                                 n_layers=n_layers,
                                 n_node_features=n_node_features,
                                 n_edge_features=n_edge_features,
                                 n_time_features=n_time_features,
                                 embedding_dimension=embedding_dimension,
                                 device=device,
                                 n_heads=n_heads, dropout=dropout, use_memory=use_memory)

    elif module_type == "identity":
        return IdentityEmbedding(node_features=node_features,
                                 edge_features=edge_features,
                                 memory=memory,
                                 neighbor_finder=neighbor_finder,
                                 time_encoder=time_encoder,
                                 n_layers=n_layers,
                                 n_node_features=n_node_features,
                                 n_edge_features=n_edge_features,
                                 n_time_features=n_time_features,
                                 embedding_dimension=embedding_dimension,
                                 device=device,
                                 dropout=dropout)
    elif module_type == "time":
        return TimeEmbedding(node_features=node_features,
                             edge_features=edge_features,
                             memory=memory,
                             neighbor_finder=neighbor_finder,
                             time_encoder=time_encoder,
                             n_layers=n_layers,
                             n_node_features=n_node_features,
                             n_edge_features=n_edge_features,
                             n_time_features=n_time_features,
                             embedding_dimension=embedding_dimension,
                             device=device,
                             dropout=dropout,
                             n_neighbors=n_neighbors)
    else:
        raise ValueError("Embedding Module {} not supported".format(module_type))
GraphAttentionEmbedding
class GraphEmbedding(EmbeddingModule):
    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        super(GraphEmbedding, self).__init__(node_features, edge_features, memory,
                                             neighbor_finder, time_encoder, n_layers,
                                             n_node_features, n_edge_features, n_time_features,
                                             embedding_dimension, device, dropout)

        self.use_memory = use_memory
        self.device = device

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """Recursive implementation of curr_layers temporal graph attention layers.
        Stacks temporal graph attention layers recursively.
        src_idx_l [batch_size]: users / items input ids.
        cut_time_l [batch_size]: scalar representing the instant of the time where we want to extract the user / item representation.
        curr_layers [scalar]: number of temporal convolutional layers to stack.
        num_neighbors [scalar]: number of temporal neighbor to consider in each convolutional layer.
        """
		"""
            memory 记忆对象
            source_nodes 是一个结合了源节点目的节点和负采样节点的node_id列表(一开始是,后面不是)
            timestamps 200*3的时间列表
            self.n_layers 递归的层数 这里为2
            n_neighbors 选取多少个邻居节点 这里是10
            time_diffs 标准化过后的时间差
        """
        assert (n_layers >= 0)
        # source_nodes_torch shape = [len(source_nodes)] (600 on the top-level call)
        source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
        # timestamps_torch shape = [3*200, 1]
        timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)

        # query node always has the start time -> time span == 0
        # time_encoder computes cos(Linear(t)); its code (TimeEncode) is shown further below
        # torch.zeros_like(timestamps_torch) is all zeros, shape [3*200, 1]
        # source_nodes_time_embedding shape = [3*200, 1, 172]
        source_nodes_time_embedding = self.time_encoder(torch.zeros_like(timestamps_torch))
        # self.node_features is all zeros for this dataset
        # self.node_features shape is [n_nodes, node_dim] = [9228, 172]
        # source_node_features: raw features of the queried nodes, shape [600, 172]
        source_node_features = self.node_features[source_nodes_torch, :]

        if self.use_memory:
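            # Add each node's memory to its raw features
            source_node_features = memory[source_nodes, :] + source_node_features

        # ===================================== Recursive step =====================================
        # With n_layers = 1, the recursive calls below run with n_layers = 0 and simply return
        # the raw features plus memory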
        if n_layers == 0:
            return source_node_features
        else:
            # Recursive call on the same nodes; returns embeddings of shape [600, 172]
            source_node_conv_embeddings = self.compute_embedding(memory,
                                                                 source_nodes,
                                                                 timestamps,
                                                                 n_layers=n_layers - 1,
                                                                 n_neighbors=n_neighbors)
            # For each of the 3*200 nodes, take up to n_neighbors (10) neighbors from the
            # interactions that happened before that node's timestamp
            """
            neighbors shape is [3*200, n_neighbors]
            edge_idxs shape is [3*200, n_neighbors]
            edge_times shape is [3*200, n_neighbors] 
            """
            neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
                source_nodes,
                timestamps,
                n_neighbors=n_neighbors)

            # Neighbor ids of every node in source_nodes, as a torch tensor
            neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)

            edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

            # Time elapsed between each node's query time and each neighbor interaction, shape [600, 10]
            edge_deltas = timestamps[:, np.newaxis] - edge_times
            edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)

            # Flatten to 600 * 10 = 6000 neighbor ids
            neighbors = neighbors.flatten()
            # neighbor_embeddings shape = [600*10, 172]
            neighbor_embeddings = self.compute_embedding(memory,
                                                         neighbors,  # 6000 ids
                                                         np.repeat(timestamps, n_neighbors),  # 6000 timestamps
                                                         n_layers=n_layers - 1,
                                                         n_neighbors=n_neighbors)

            effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
            # neighbor_embeddings shape = [600, 10, 172]
            neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
            # edge_time_embeddings shape is [600, 10, 172]
            edge_time_embeddings = self.time_encoder(edge_deltas_torch)

            # self.edge_features shape [157475, 172]
            # edge_idxs shape [600, 10]
            # edge_features shape [600, 10, 172]
            edge_features = self.edge_features[edge_idxs, :]

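            mask = neighbors_torch == 0

            # aggregate is defined below (GraphAttentionEmbedding.aggregate)
            """
            n_layers: 1
            source_node_conv_embeddings: embeddings of the original 600 nodes from the recursive call
            source_nodes_time_embedding: time encoding of a zero time delta, shape [3*200, 1, 172]
                                         (time_encoder applied to zeros, so not a zero matrix)
            neighbor_embeddings: embeddings of the neighbors each of the 600 nodes interacted with
            edge_time_embeddings: encoding of the time deltas to those neighbors
            edge_features: features of the edges between each of the 600 nodes and its 10 neighbors
            mask: shape [600, 10], True where the neighbor slot is padding (id 0)
            """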
            source_embedding = self.aggregate(n_layers, source_node_conv_embeddings,
                                              source_nodes_time_embedding,
                                              neighbor_embeddings,
                                              edge_time_embeddings,
                                              edge_features,
                                              mask)

            return source_embedding

    def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        return NotImplemented


class GraphAttentionEmbedding(GraphEmbedding):
    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        super(GraphAttentionEmbedding, self).__init__(node_features, edge_features, memory,
                                                      neighbor_finder, time_encoder, n_layers,
                                                      n_node_features, n_edge_features,
                                                      n_time_features,
                                                      embedding_dimension, device,
                                                      n_heads, dropout,
                                                      use_memory)

        self.attention_models = torch.nn.ModuleList([TemporalAttentionLayer(
            n_node_features=n_node_features,
            n_neighbors_features=n_node_features,
            n_edge_features=n_edge_features,
            time_dim=n_time_features,
            n_head=n_heads,
            dropout=dropout,
            output_dimension=n_node_features)
            for _ in range(n_layers)])

    def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        attention_model = self.attention_models[n_layer - 1]

        source_embedding, _ = attention_model(source_node_features,
                                              source_nodes_time_embedding,
                                              neighbor_embeddings,
                                              edge_time_embeddings,
                                              edge_features,
                                              mask)

        return source_embedding
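compute_embedding is recursive. At depth n_layers, a node's embedding combines its own depth n_layers - 1 embedding with the depth n_layers - 1 embeddings of its temporal neighbors, all evaluated at the node's query time; depth 0 returns features plus memory. The pure-Python sketch below keeps that recursion but replaces the temporal attention layer with a simple average, a swapped-in technique rather than the TGN aggregator. The toy graph ADJ, the BASE values and the helper names are hypothetical. Neighbors are restricted to interactions strictly before the query time, as in find_before.

```python
from bisect import bisect_left

# Hypothetical toy graph: node -> time-sorted list of (neighbor, timestamp)
ADJ = {
    0: [(1, 1.0), (2, 3.0)],
    1: [(0, 1.0)],
    2: [(0, 3.0)],
}
BASE = {0: 1.0, 1: 10.0, 2: 100.0}  # stands in for node features + memory (the depth-0 output)


def neighbors_before(node, t, n_neighbors):
    """Most recent n_neighbors interactions strictly before t."""
    if n_neighbors <= 0:
        return []
    events = ADJ.get(node, [])
    i = bisect_left([ts for _, ts in events], t)
    return events[:i][-n_neighbors:]


def embed(node, t, n_layers, n_neighbors=2):
    """Recursive temporal embedding with a mean in place of temporal attention."""
    if n_layers == 0:
        return BASE[node]
    own = embed(node, t, n_layers - 1, n_neighbors)       # like source_node_conv_embeddings
    nbrs = neighbors_before(node, t, n_neighbors)
    if not nbrs:
        return own                                       # simplification: no neighbors to aggregate
    # Neighbors are embedded at the query time t, like np.repeat(timestamps, n_neighbors).
    nbr_embs = [embed(v, t, n_layers - 1, n_neighbors) for v, _ in nbrs]
    return (own + sum(nbr_embs) / len(nbr_embs)) / 2


print(embed(0, 5.0, 1), embed(0, 2.0, 1), embed(0, 1.0, 1))  # 28.0 5.5 1.0
```

Each extra layer multiplies the number of embedding lookups by roughly n_neighbors + 1, which is why the 600 query nodes turn into 6000 neighbor lookups per layer.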
TimeEncode
class TimeEncode(torch.nn.Module):
    # Time Encoding proposed by TGAT
    def __init__(self, dimension):
        super(TimeEncode, self).__init__()

        self.dimension = dimension # 172
        self.w = torch.nn.Linear(1, dimension)

        # todo
        self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
                                           .float().reshape(dimension, -1))
        self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

    def forward(self, t): # -> [batch_size, seq_len, dimension]
        # t has shape [batch_size, seq_len]
        # Add dimension at the end to apply linear layer --> [batch_size, seq_len, 1]
        t = t.unsqueeze(dim=2)

        # output has shape [batch_size, seq_len, dimension]
        output = torch.cos(self.w(t))

        return output
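TimeEncode maps a time delta t to cos(w * t + b) at dimension different frequencies. w is initialized to 1 / 10**linspace(0, 9, dimension) and b to zero, and both are nn.Parameters, so these values are only the starting point for training. The pure-Python sketch below evaluates that initial encoding (time_encode is a hypothetical name). One consequence: a zero delta encodes to a vector of ones at initialization, so source_nodes_time_embedding in compute_embedding is not a zero matrix.

```python
import math


def time_encode(t, dimension):
    """Initial TimeEncode output for a scalar delta t: [cos(w_i * t + b_i) for i in range(dimension)].

    w_i = 1 / 10 ** linspace(0, 9, dimension)[i] and b_i = 0, matching TimeEncode's initialization.
    """
    if dimension == 1:
        exponents = [0.0]
    else:
        exponents = [9.0 * i / (dimension - 1) for i in range(dimension)]  # like np.linspace(0, 9, dimension)
    weights = [1.0 / 10 ** e for e in exponents]
    bias = 0.0
    return [math.cos(w * t + bias) for w in weights]


zero_delta = time_encode(0.0, 172)
print(set(zero_delta))          # {1.0}: a zero delta encodes to all ones, not zeros
print(time_encode(math.pi, 2))  # the first frequency is 1, the last is 1e-9
```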
NeighborFinder
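The NeighborFinder class below stores, for every node, its interactions sorted by timestamp. find_before uses np.searchsorted (left side by default) to cut off everything at or after cut_time. The pure-Python sketch here mirrors that with bisect.bisect_left. Its most_recent method (hypothetical) keeps the latest n_neighbors interactions and left-pads with id 0, which fits the mask neighbors_torch == 0 in compute_embedding. The exact padding layout of get_temporal_neighbor is not shown in this excerpt, so treat that part as an assumption.

```python
from bisect import bisect_left


class ToyNeighborFinder:
    """Pure-Python sketch of NeighborFinder's find_before plus most-recent sampling with zero padding."""

    def __init__(self, adj_list):
        # adj_list[node] is a list of (neighbor, edge_idx, timestamp); sort each by timestamp.
        self.adj = [sorted(events, key=lambda e: e[2]) for events in adj_list]

    def find_before(self, node, cut_time):
        events = self.adj[node]
        i = bisect_left([e[2] for e in events], cut_time)  # same as np.searchsorted(..., side="left")
        return events[:i]

    def most_recent(self, node, cut_time, n_neighbors):
        """Return n_neighbors (neighbor, edge_idx, timestamp) rows; missing slots are padded with 0."""
        recent = self.find_before(node, cut_time)[-n_neighbors:] if n_neighbors > 0 else []
        padding = [(0, 0, 0.0)] * (n_neighbors - len(recent))
        return padding + recent


finder = ToyNeighborFinder([
    [],                                        # node 0 is reserved as the padding id
    [(2, 1, 5.0), (3, 2, 1.0), (4, 3, 3.0)],
])
print(finder.find_before(1, 5.0))     # [(3, 2, 1.0), (4, 3, 3.0)]  (strictly before 5.0)
print(finder.most_recent(1, 5.0, 3))  # [(0, 0, 0.0), (3, 2, 1.0), (4, 3, 3.0)]
```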
class NeighborFinder:
    def __init__(self, adj_list, uniform=False, seed=None):
        self.node_to_neighbors = []
        self.node_to_edge_idxs = []
        self.node_to_edge_timestamps = []

        for neighbors in adj_list:
            # Neighbors is a list of tuples (neighbor, edge_idx, timestamp)
            # We sort the list based on timestamp
            sorted_neighhbors = sorted(neighbors, key=lambda x: x[2])
            self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighhbors]))  # a list indexed by node id: entry i is an array of the nodes that node i interacted with, sorted by time
            self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighhbors]))
            self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighhbors]))

        self.uniform = uniform

        if seed is not None:
            self.seed = seed
            self.random_state = np.random.RandomState(self.seed)

    def find_before(self, src_idx, cut_time):
        """
        Extracts all the interactions happening before cut_time for user src_idx in the overall interaction graph. The returned interactions are sorted by time.

        Returns 3 lists: neighbors, edge_idxs, timestamps

        """
        i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)

        return (self.node_to_neighbors[src_idx][:i],
                self.node_to_edge_idxs[src_idx][:i],
                self.node_to_edge_timestamps[src_idx][:i])

    def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
        """
        Given a list of users ids and relative cut times, extracts a sampled temporal neighborhood of each user in the list.

        Params
        ------
        source_nodes: List[int]
        timestamps: List[float]
        n_neighbors: int
        """
        assert (len(source_nodes) == len(timestamps))

        tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1
        # NB! All interactions described in these matrices are sorted in each row by time
        neighbors = np.zeros((len(source_nodes), tmp_n_neighbors)).astype( # shape [600, 10]
            np.int32)  # each entry in position (i,j) represent the id of the item targeted by user src_idx_l[i] with an interaction happening before cut_time_l[i]
        edge_times = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
            np.float32)  # each entry in position (i,j) represent the timestamp of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i]
        edge_idxs = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
            np.int32)  # each entry in position (i,j) represent the interaction index of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i]

        for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
            source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node,
                                                                                     timestamp)  # extracts all neighbors, interactions indexes and timestamps of all interactions of user source_node happening before cut_time

            if len(source_neighbors) > 0 and n_neighbors > 0:
                if self.uniform:  # uniform sampling: draw n_neighbors indices at random (with replacement) from all earlier interactions
                    sampled_idx = np.random.randint(0, len(source_neighbors), n_neighbors)

                    neighbors[i, :] = source_neighbors[sampled_idx]
                    edge_times[i, :] = source_edge_times[sampled_idx]
                    edge_idxs[i, :] = source_edge_idxs[sampled_idx]

                    # re-sort based on time
                    pos = edge_times[i, :].argsort()
                    neighbors[i, :] = neighbors[i, :][pos]
                    edge_times[i, :] = edge_times[i, :][pos]
                    edge_idxs[i, :] = edge_idxs[i, :][pos]
                else:
                    # Take most recent interactions
                    source_edge_times = source_edge_times[-n_neighbors:]
                    source_neighbors = source_neighbors[-n_neighbors:]
                    source_edge_idxs = source_edge_idxs[-n_neighbors:]

                    assert (len(source_neighbors) <= n_neighbors)
                    assert (len(source_edge_times) <= n_neighbors)
                    assert (len(source_edge_idxs) <= n_neighbors)

                    neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors
                    edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times
                    edge_idxs[i, n_neighbors - len(source_edge_idxs):] = source_edge_idxs

        return neighbors, edge_idxs, edge_times
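The non-uniform branch of `get_temporal_neighbor` keeps the `n_neighbors` most recent interactions before `cut_time`. It writes them right-aligned into a zero-initialized row.

The sketch below reproduces one such row in plain Python. `most_recent_neighbors` is a hypothetical helper. `bisect.bisect_left` stands in for `np.searchsorted` with its default `side='left'`, so an interaction at exactly `cut_time` is excluded.

```python
from bisect import bisect_left
from typing import List, Tuple


def most_recent_neighbors(neighbors: List[int], times: List[float], cut_time: float,
                          n_neighbors: int) -> Tuple[List[int], List[float]]:
    """One row of get_temporal_neighbor's non-uniform branch, for a single node.

    `neighbors` and `times` must already be sorted by time, as NeighborFinder.__init__ does.
    """
    assert n_neighbors > 0
    # bisect_left behaves like np.searchsorted(..., side="left"): only interactions
    # strictly earlier than cut_time are kept.
    i = bisect_left(times, cut_time)
    recent_nbrs = neighbors[:i][-n_neighbors:]
    recent_times = times[:i][-n_neighbors:]
    pad = n_neighbors - len(recent_nbrs)
    # Real neighbors are right-aligned; the left is padded with id 0 and time 0.0,
    # matching the zero-initialized arrays in get_temporal_neighbor.
    return [0] * pad + recent_nbrs, [0.0] * pad + recent_times
```

Padding uses node id 0, so a real node with id 0 would be indistinguishable from an empty slot. As far as I can tell, the downstream padding mask in the TGN code is built as `neighbors == 0`. That relies on real node ids starting at 1.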
TemporalAttentionLayer
class TemporalAttentionLayer(torch.nn.Module):
  """
  Temporal attention layer. Return the temporal embedding of a node given the node itself,
   its neighbors and the edge timestamps.
  """

  def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
               output_dimension, n_head=2,
               dropout=0.1):
    super(TemporalAttentionLayer, self).__init__()

    self.n_head = n_head

    self.feat_dim = n_node_features
    self.time_dim = time_dim

    self.query_dim = n_node_features + time_dim
    self.key_dim = n_neighbors_features + time_dim + n_edge_features

    self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)

    self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                   kdim=self.key_dim,
                                                   vdim=self.key_dim,
                                                   num_heads=n_head,
                                                   dropout=dropout)

  def forward(self, src_node_features, src_time_features, neighbors_features,
              neighbors_time_features, edge_features, neighbors_padding_mask):
    """
    Temporal attention model.
    :param src_node_features: float Tensor of shape [batch_size, n_node_features]
    :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
    :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features]
    :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors,
    time_dim]
    :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
    :param neighbors_padding_mask: bool Tensor of shape [batch_size, n_neighbors], True marks a padded neighbor slot
    :return:
    attn_output: float Tensor of shape [batch_size, output_dimension]
    attn_output_weights: float Tensor of shape [batch_size, n_neighbors]
    """
	
    # src_node_features_unrolled shape is [600, 1, 172]
    src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)
	
    # Concatenate the node features with the time features
    # query shape is [600, 1, 172*2]
    query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
    # Concatenate neighbor features, edge features and time-difference features; key shape = [600, 10, 516]
    key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)

    # query shape is [1, 600, 344]
    query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
    # key shape is [10, 600, 516]
    key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]
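    # True for rows whose neighbor slots are ALL padding (all True along dim=1), i.e. nodes with no valid neighbor; False otherwise
    invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
    # For such nodes, unmask the first slot so the attention softmax does not produce NaN;
    # their attention output is zeroed out again below with masked_fill
    neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False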

    # print(query.shape, key.shape)

    attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                              key_padding_mask=neighbors_padding_mask)

    # mask = torch.unsqueeze(neighbors_padding_mask, dim=2)  # mask [B, N, 1]
    # mask = mask.permute([0, 2, 1])
    # attn_output, attn_output_weights = self.multi_head_target(q=query, k=key, v=key,
    #                                                           mask=mask)
	
    # attn_output shape = [600, 344]
    # attn_output_weights = [600, 10]
    attn_output = attn_output.squeeze()
    attn_output_weights = attn_output_weights.squeeze()

    # Source nodes with no neighbors have an all zero attention output. The attention output is
    # then added or concatenated to the original source node features and then fed into an MLP.
    # This means that an all zero vector is not used.
    attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
    attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)

    # Skip connection with temporal attention over neighborhood and the features of the node itself
    # attn_output = [600, 172]
    attn_output = self.merger(attn_output, src_node_features)

    return attn_output, attn_output_weights
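In the author's run all features are 172-dimensional. That gives `query_dim` = 172 + 172 = 344 and `key_dim` = 172 + 172 + 172 = 516, matching the shape comments above.

The part that is easiest to misread is the mask handling:
1. A node with no earlier interactions has every key slot masked. A softmax over only masked slots is undefined, so `nn.MultiheadAttention` can return NaN for that node.
2. The layer therefore unmasks slot 0 before running attention.
3. It then overwrites that node's output with zeros.

The NumPy sketch below shows the same steps for one node. `masked_attention` is hypothetical. It is a simplified single-head scaled dot-product attention that drops the learned projections and multiple heads of `nn.MultiheadAttention`.

```python
import numpy as np


def masked_attention(query: np.ndarray, keys: np.ndarray, padding_mask: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product attention for ONE source node.

    query: [d]; keys: [n_neighbors, d], also used as values (value=key, as in the layer above);
    padding_mask: [n_neighbors] bool, True marks a padded slot.
    Simplification: no learned Q/K/V projections and a single head.
    """
    mask = padding_mask.copy()
    no_valid_neighbor = bool(mask.all())
    if no_valid_neighbor:
        mask[0] = False  # same trick as neighbors_padding_mask[invalid, 0] = False
    scores = keys @ query / np.sqrt(query.shape[0])
    scores = np.where(mask, -np.inf, scores)  # padded slots get zero weight after softmax
    weights = np.exp(scores - scores.max())   # finite: at least one slot is unmasked
    weights = weights / weights.sum()
    out = weights @ keys
    # Equivalent of masked_fill(invalid_neighborhood_mask, 0): no neighbors -> zero vector.
    return np.zeros_like(out) if no_valid_neighbor else out
```

Without step 2, `scores.max()` would be `-inf` for an all-padding row, and `-inf - (-inf)` is NaN. That is the failure the original code avoids.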
MergeLayer
class MergeLayer(torch.nn.Module):
    def __init__(self, dim1, dim2, dim3, dim4):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
        self.fc2 = torch.nn.Linear(dim3, dim4)
        self.act = torch.nn.ReLU()

        torch.nn.init.xavier_normal_(self.fc1.weight)
        torch.nn.init.xavier_normal_(self.fc2.weight)

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], dim=1)
        h = self.act(self.fc1(x))
        return self.fc2(h)
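`MergeLayer` is the skip connection at the end of `TemporalAttentionLayer`. It concatenates two inputs and passes them through `fc1` + ReLU, then `fc2`:
- the attention output (`dim1` = `query_dim` = 344);
- the node's own features (`dim2` = 172).

The NumPy sketch below mirrors `forward` to make the shapes explicit. `merge_layer` is a hypothetical helper, and random weights replace the Xavier-initialized ones. `dim3` and `dim4` = 172 follow the call `MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)`, with `output_dimension` = 172 as in the author's run.

```python
import numpy as np


def merge_layer(x1: np.ndarray, x2: np.ndarray, w1: np.ndarray, b1: np.ndarray,
                w2: np.ndarray, b2: np.ndarray) -> np.ndarray:
    """NumPy mirror of MergeLayer.forward.

    Weights are stored as [in_features, out_features], i.e. the transpose of
    torch.nn.Linear.weight, so x @ w1 + b1 computes the same thing as fc1(x).
    """
    x = np.concatenate([x1, x2], axis=1)  # [batch, dim1 + dim2]
    h = np.maximum(x @ w1 + b1, 0.0)      # fc1 followed by ReLU -> [batch, dim3]
    return h @ w2 + b2                    # fc2 -> [batch, dim4]


# Shapes from the author's run: attention output 344 wide, node features 172 wide.
rng = np.random.default_rng(0)
batch, dim1, dim2, dim3, dim4 = 600, 344, 172, 172, 172
out = merge_layer(rng.normal(size=(batch, dim1)), rng.normal(size=(batch, dim2)),
                  rng.normal(size=(dim1 + dim2, dim3)), np.zeros(dim3),
                  rng.normal(size=(dim3, dim4)), np.zeros(dim4))
print(out.shape)  # (600, 172)
```

The output width is `dim4`, not `dim1`. This is why the author's comment notes that `attn_output` shrinks from 344 to 172 columns after `self.merger`.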
