SGKT Model Experiment Log

  1. import numpy as np
    import scipy.sparse as sp
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import itertools
    import math
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
    from torch.autograd import Variable
    from layers import MLP, EraseAddGate, MLPEncoder, MLPDecoder, ScaledDotProductAttention
    from utils import gumbel_softmax
    
    class GKT(nn.Module):
    
        def __init__(self, concept_num, hidden_dim, embedding_dim, edge_type_num, graph_type, graph=None, graph_model=None, dropout=0.5, bias=True, binary=False, has_cuda=False):
            super(GKT, self).__init__()
            self.concept_num = concept_num
            self.hidden_dim = hidden_dim
            self.embedding_dim = embedding_dim
            self.edge_type_num = edge_type_num
            self.res_len = 2 if binary else 12
            self.has_cuda = has_cuda
            self.graph_type = graph_type
            self.graph = nn.Parameter(graph)  # [concept_num, concept_num]
            self.graph.requires_grad = False  # fix parameter
            self.graph_model = graph_model
    
            # one-hot feature and question
            one_hot_feat = torch.eye(self.res_len * self.concept_num)
            self.one_hot_feat = one_hot_feat.cuda() if self.has_cuda else one_hot_feat
            self.one_hot_q = torch.eye(self.concept_num, device=self.one_hot_feat.device)
            zero_padding = torch.zeros(1, self.concept_num, device=self.one_hot_feat.device)
            self.one_hot_q = torch.cat((self.one_hot_q, zero_padding), dim=0)
            # concept and concept & response embeddings
            self.emb_x = nn.Embedding(self.res_len * concept_num, embedding_dim)
            # last embedding is used for padding, so dim + 1
            self.emb_c = nn.Embedding(concept_num + 1, embedding_dim, padding_idx=-1)
    
            # f_self function and f_neighbor functions
            mlp_input_dim = hidden_dim + embedding_dim
            self.f_self = MLP(mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias)
            self.f_neighbor_list = nn.ModuleList()
            # f_in and f_out functions
            self.f_neighbor_list.append(MLP(2 * mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias))
            self.f_neighbor_list.append(MLP(2 * mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias))
            # Erase & Add Gate
            self.erase_add_gate = EraseAddGate(hidden_dim, concept_num)
            # Gate Recurrent Unit
            self.gru = nn.GRUCell(hidden_dim, hidden_dim, bias=bias)
            # prediction layer
            self.predict = nn.Linear(hidden_dim, 1, bias=bias)
    
        # Aggregate step, as shown in Section 3.2.1 of the paper
        def _aggregate(self, xt, qt, ht, batch_size):
            r"""
            Parameters:
                xt: input one-hot question answering features at the current timestamp
                qt: question indices for all students in a batch at the current timestamp
                ht: hidden representations of all concepts at the current timestamp
                batch_size: the size of a student batch
            Shape:
                xt: [batch_size]
                qt: [batch_size]
                ht: [batch_size, concept_num, hidden_dim]
                tmp_ht: [batch_size, concept_num, hidden_dim + embedding_dim]
            Return:
                tmp_ht: aggregation results of concept hidden knowledge state and concept(& response) embedding
            """
            qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
            x_idx_mat = torch.arange(self.res_len * self.concept_num, device=xt.device)
            x_embedding = self.emb_x(x_idx_mat)  # [res_len * concept_num, embedding_dim]
            masked_feat = F.embedding(xt[qt_mask], self.one_hot_feat)  # [mask_num, res_len * concept_num]
            res_embedding = masked_feat.mm(x_embedding)  # [mask_num, embedding_dim]
            mask_num = res_embedding.shape[0]
    
            concept_idx_mat = self.concept_num * torch.ones((batch_size, self.concept_num), device=xt.device).long()
            concept_idx_mat[qt_mask, :] = torch.arange(self.concept_num, device=xt.device)
            concept_embedding = self.emb_c(concept_idx_mat)  # [batch_size, concept_num, embedding_dim]
    
            index_tuple = (torch.arange(mask_num, device=xt.device), qt[qt_mask].long())
            concept_embedding[qt_mask] = concept_embedding[qt_mask].index_put(index_tuple, res_embedding)
            tmp_ht = torch.cat((ht, concept_embedding), dim=-1)  # [batch_size, concept_num, hidden_dim + embedding_dim]
            return tmp_ht
    
        # GNN aggregation step, as shown in 3.3.2 Equation 1 of the paper
        def _agg_neighbors(self, tmp_ht, qt):
            r"""
            Parameters:
                tmp_ht: temporal hidden representations of all concepts after the aggregate step
                qt: question indices for all students in a batch at the current timestamp
            Shape:
                tmp_ht: [batch_size, concept_num, hidden_dim + embedding_dim]
                qt: [batch_size]
                m_next: [batch_size, concept_num, hidden_dim]
            Return:
                m_next: hidden representations of all concepts aggregating neighboring representations at the next timestamp
            """
            qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
            masked_qt = qt[qt_mask]  # [mask_num, ]
            masked_tmp_ht = tmp_ht[qt_mask]  # [mask_num, concept_num, hidden_dim + embedding_dim]
            mask_num = masked_tmp_ht.shape[0]
            self_index_tuple = (torch.arange(mask_num, device=qt.device), masked_qt.long())
            self_ht = masked_tmp_ht[self_index_tuple]  # [mask_num, hidden_dim + embedding_dim]
            self_features = self.f_self(self_ht)  # [mask_num, hidden_dim]
            expanded_self_ht = self_ht.unsqueeze(dim=1).repeat(1, self.concept_num, 1)  # [mask_num, concept_num, hidden_dim + embedding_dim]
            neigh_ht = torch.cat((expanded_self_ht, masked_tmp_ht), dim=-1)  # [mask_num, concept_num, 2 * (hidden_dim + embedding_dim)]
            adj = self.graph[masked_qt.long(), :].unsqueeze(dim=-1)  # [mask_num, concept_num, 1]
            reverse_adj = self.graph[:, masked_qt.long()].transpose(0, 1).unsqueeze(dim=-1)  # [mask_num, concept_num, 1]
            # self.f_neighbor_list[0](neigh_ht) shape: [mask_num, concept_num, hidden_dim]
            # Calculate Shapley values for each neighbor
            yt = self._predict(tmp_ht[:, :, :self.hidden_dim], qt)  # use only the hidden-state slice; _predict expects hidden_dim features
            shapley_values = self.calculate_shapley_values(adj, reverse_adj, yt)  # returns a tensor on adj.device
    
            # Sort neighbors based on Shapley values
            sorted_indices = torch.argsort(shapley_values, descending=True)
            top_k = int(0.5 * len(sorted_indices))  # Select top 50% of neighbors
    
            # Keep only the top-k neighbors' features (others zeroed) so the shape stays [mask_num, concept_num, hidden_dim]
            neigh_features = self.f_neighbor_list[0](neigh_ht)  # [mask_num, concept_num, hidden_dim]
            keep_mask = torch.zeros(self.concept_num, device=neigh_ht.device)
            keep_mask[sorted_indices[:top_k]] = 1.0
            selected_neigh_features = neigh_features * keep_mask.view(1, -1, 1)  # zero out low-Shapley neighbors

            # Update m_next with the selected neighbor features
            m_next = tmp_ht[:, :, :self.hidden_dim].clone()
            m_next[qt_mask] = selected_neigh_features
            m_next[qt_mask] = m_next[qt_mask].index_put(self_index_tuple, self_features)
            return m_next
            return m_next
        # Step 1: build the neighbor node set for each knowledge concept node
        def generate_neighbor_sets(self, adj, reverse_adj):
            neighbor_sets = []
            concept_num = adj.shape[1]
            index_neighbors_sets = {}
            for i in range(concept_num):
                neighbors = set()
                for j in range(concept_num):
                    if adj[i, j, 0] > 0 or reverse_adj[i, j, 0] > 0:
                        neighbors.add(j)
                neighbor_sets.append(neighbors)

            for i, neighbors in enumerate(neighbor_sets):
                index_neighbors_sets[i] = neighbors

            return index_neighbors_sets
        # Step 2: iteratively enumerate all possible permutations of each neighbor set,
        # compute the marginal contribution under each permutation, and obtain each
        # knowledge node's Shapley value
        def calculate_shapley_values(self, adj, reverse_adj, yt):
            concept_num = adj.shape[1]
            shapley_values = torch.zeros(concept_num, device=adj.device)
            neighbor_sets = self.generate_neighbor_sets(adj, reverse_adj)
            for i in range(concept_num):
                neighbors = neighbor_sets[i]
                shapley_value = 0.0
                for r in range(1, len(neighbors) + 1):
                    for perm in itertools.permutations(neighbors, r):
                        marginal_contribution = self.calculate_marginal_contribution(perm, yt)
                        shapley_value += marginal_contribution / (math.factorial(r) * math.factorial(len(neighbors) - r))
                shapley_values[i] = shapley_value
            return shapley_values
    
        def calculate_marginal_contribution(self, perm, yt):
            marginal_contribution = 0.0
            for neighbor in perm:
                # zero out this neighbor's predicted probability and measure the change
                y_without_neighbor = yt.clone()
                y_without_neighbor[:, neighbor] = 0.0
                marginal_contribution += (y_without_neighbor - yt).sum()  # reduce to a scalar over batch and concepts

            return marginal_contribution
    
        # Update step, as shown in Section 3.3.2 of the paper
        def _update(self, tmp_ht, ht, qt):
            r"""
            Parameters:
                tmp_ht: temporal hidden representations of all concepts after the aggregate step
                ht: hidden representations of all concepts at the current timestamp
                qt: question indices for all students in a batch at the current timestamp
            Shape:
                tmp_ht: [batch_size, concept_num, hidden_dim + embedding_dim]
                ht: [batch_size, concept_num, hidden_dim]
                qt: [batch_size]
                h_next: [batch_size, concept_num, hidden_dim]
            Return:
                h_next: hidden representations of all concepts at the next timestamp
            """
            qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
            mask_num = qt_mask.nonzero().shape[0]
            # GNN Aggregation
            m_next = self._agg_neighbors(tmp_ht, qt)  # [batch_size, concept_num, hidden_dim]
            # Erase & Add Gate
            m_next[qt_mask] = self.erase_add_gate(m_next[qt_mask])  # [mask_num, concept_num, hidden_dim]
            # GRU
            h_next = m_next
            res = self.gru(m_next[qt_mask].reshape(-1, self.hidden_dim), ht[qt_mask].reshape(-1, self.hidden_dim))  # [mask_num * concept_num, hidden_dim]
            index_tuple = (torch.arange(mask_num, device=qt_mask.device), )
            h_next[qt_mask] = h_next[qt_mask].index_put(index_tuple, res.reshape(-1, self.concept_num, self.hidden_dim))
            return h_next
    
        # Predict step, as shown in Section 3.3.3 of the paper
        def _predict(self, h_next, qt):
            r"""
            Parameters:
                h_next: hidden representations of all concepts at the next timestamp after the update step
                qt: question indices for all students in a batch at the current timestamp
            Shape:
                h_next: [batch_size, concept_num, hidden_dim]
                qt: [batch_size]
                y: [batch_size, concept_num]
            Return:
                y: predicted correct probability of all concepts at the next timestamp
            """
            qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
            y = self.predict(h_next).squeeze(dim=-1)  # [batch_size, concept_num]
            y[qt_mask] = torch.sigmoid(y[qt_mask])  # apply sigmoid only to valid (qt != -1) rows
            return y
    
        def _get_next_pred(self, yt, q_next):
            r"""
            Parameters:
                yt: predicted correct probability of all concepts at the next timestamp
                q_next: question index matrix at the next timestamp
            Shape:
                y: [batch_size, concept_num]
                questions: [batch_size, seq_len]
                pred: [batch_size, ]
            Return:
                pred: predicted correct probability of the question answered at the next timestamp
            """
            next_qt = q_next
            next_qt = torch.where(next_qt != -1, next_qt, self.concept_num * torch.ones_like(next_qt, device=yt.device))
            one_hot_qt = F.embedding(next_qt.long(), self.one_hot_q)  # [batch_size, concept_num]
            # dot product between yt and one_hot_qt
            pred = (yt * one_hot_qt).sum(dim=1)  # [batch_size, ]
            return pred
    
        def forward(self, features, questions):
            r"""
            Parameters:
                features: input one-hot matrix
                questions: question index matrix
            seq_len dimension needs padding, because different students may have learning sequences with different lengths.
            Shape:
                features: [batch_size, seq_len]
                questions: [batch_size, seq_len]
                pred_res: [batch_size, seq_len - 1]
            Return:
                pred_res: the correct probability of questions answered at the next timestamp
            """
            batch_size, seq_len = features.shape
            ht = Variable(torch.zeros((batch_size, self.concept_num, self.hidden_dim), device=features.device))
            pred_list = []
    
            for i in range(seq_len):
                xt = features[:, i]  # [batch_size]
                qt = questions[:, i]  # [batch_size]
                qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
                tmp_ht = self._aggregate(xt, qt, ht, batch_size)  # [batch_size, concept_num, hidden_dim + embedding_dim]
                h_next = self._update(tmp_ht, ht, qt)  # [batch_size, concept_num, hidden_dim]
                ht[qt_mask] = h_next[qt_mask]  # update new ht
                yt = self._predict(h_next, qt)  # [batch_size, concept_num]
                if i < seq_len - 1:  # the last timestamp has no next question to predict
                    pred_list.append(self._get_next_pred(yt, questions[:, i + 1]))
    
            pred_res = torch.stack(pred_list, dim=1)  # [batch_size, seq_len - 1]
            return pred_res

    The return value of calculate_shapley_values should be a torch tensor, not a dict. Dict operations in Python are relatively slow, and a dict cannot participate in PyTorch's autograd system, so the shapley_values dict needs to be converted into a tensor before being returned.
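
    A minimal sketch of that conversion, assuming the dict maps each concept index 0..concept_num-1 to a scalar value (the helper name shapley_dict_to_tensor is illustrative, not part of the model):

    def shapley_dict_to_tensor(shapley_dict, concept_num, device):
        # pack the per-concept scalars into one tensor so torch ops
        # such as argsort can consume them directly
        values = torch.zeros(concept_num, device=device)
        for idx, val in shapley_dict.items():
            values[idx] = float(val)
        return values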

  2. Move the _shapley_values and _aggregate functions inside the GKT class so that they can be called correctly by the class's other methods.

  3. In the _agg_neighbors function, shapley_values was returned as a dict, but the later line sorted_indices = torch.argsort(shapley_values, descending=True) treats it as a torch tensor. This raises an error, so shapley_values must be converted to a torch tensor first.

  4. In the _calculate_shapley_values function, each knowledge node's Shapley value is computed by enumerating all possible neighbor permutations and computing their marginal contributions. This may cause performance problems because the complexity is exponential, so a more efficient way to compute the Shapley values is probably needed. One direction worth considering: compute the Shapley values outside the training loop, as sketched below.
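
    A sketch of that direction, replacing exact enumeration with Monte Carlo permutation sampling (monte_carlo_shapley and its value_fn argument are illustrative assumptions, not part of the original code):

    import random

    def monte_carlo_shapley(value_fn, players, num_samples=50):
        # Estimate each player's Shapley value by sampling random permutations
        # instead of enumerating all of them: the estimate of phi_i is the
        # average marginal contribution of i over the sampled permutations.
        players = list(players)
        phi = {p: 0.0 for p in players}
        for _ in range(num_samples):
            perm = players[:]
            random.shuffle(perm)
            coalition = set()
            v_prev = value_fn(coalition)
            for p in perm:
                coalition.add(p)
                v_curr = value_fn(coalition)
                phi[p] += (v_curr - v_prev) / num_samples
                v_prev = v_curr
        return phi

    Here value_fn(coalition) would score the model's prediction with only the concepts in the coalition left unmasked. Because the estimate needs no gradients, it can indeed be computed once outside the training loop.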

  5. Pre-training sanity checks and printouts for each function
    import torch
    
    # Create sample data
    concept_num = 4
    hidden_dim = 10
    embedding_dim = 5
    edge_type_num = 2
    graph_type = 'Dense'
    features = torch.randn(2, 5)  # sample feature data, shape [batch_size, seq_len]
    questions = torch.randint(-1, 4, (2, 5))  # sample question data, shape [batch_size, seq_len]
    
    # Create a model instance
    model = GKT(concept_num, hidden_dim, embedding_dim, edge_type_num, graph_type)
    
    # Print the computation and values of the _aggregate function
    xt = features[:, 0]  # features at the first timestamp
    qt = questions[:, 0]  # question indices at the first timestamp
    ht = torch.zeros((2, concept_num, hidden_dim))  # initialize the hidden state
    tmp_ht = model._aggregate(xt, qt, ht, batch_size=2)
    print("_aggregate - tmp_ht:", tmp_ht)
    
    # Print the computation and values of the _agg_neighbors function
    m_next = model._agg_neighbors(tmp_ht, qt)
    print("_agg_neighbors - m_next:", m_next)
    
    # Print the computation and values of the _update function
    h_next = model._update(tmp_ht, ht, qt)
    print("_update - h_next:", h_next)
    
    # Print the computation and values of the _predict function
    y = model._predict(h_next, qt)
    print("_predict - y:", y)
    
    # Print the computation and values of the _get_next_pred function
    q_next = questions[:, 1]  # question indices at the next timestamp
    pred = model._get_next_pred(y, q_next)
    print("_get_next_pred - pred:", pred)
    

    _aggregate - tmp_ht: tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  1.6744, -0.2734,  1.4854, -0.7457,  0.0790],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -0.4134,  0.3520,  1.7249,  0.5040, -1.0225],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -0.1605,  2.4645,  0.6592,  0.4405,  0.2965],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  1.3218, -1.0412,  0.5474,  0.4427,  0.1326]],
            [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]])
    _aggregate - tmp_ht: tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.4616, -0.2127, -0.2274,  0.3314, -0.4540],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.1383,  0.4809, -1.5135,  0.6111,  0.1744],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  1.8623, -0.6276,  0.7693,  0.4770,  1.4637],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -1.2699, -1.6375, -1.5388, -0.1766,  0.9268]],
    
            [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.4616, -0.2127, -0.2274,  0.3314, -0.4540],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.1383,  0.4809, -1.5135,  0.6111,  0.1744],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -1.2699, -1.6375, -1.5388, -0.1766,  0.9268],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -3.0488, -0.8028,  0.2315,  0.4891,  1.4089]]])
    tmp_ht shape: torch.Size([2, 4, 13])
    qt shape: torch.Size([2])
    tmp_ht dtype: torch.float32
    qt dtype: torch.int64

    All the intermediate values in the aggregate function can be printed. Inspecting them shows that the failure happens when _agg_neighbors is called: the shapes of tmp_ht and qt do not match what it expects.
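
    A small shape guard like the one below (a sketch, not part of the original code) makes such a mismatch fail immediately with a readable message instead of a deep indexing error:

    def check_agg_inputs(tmp_ht, qt, concept_num, hidden_dim, embedding_dim):
        # verify the shapes _agg_neighbors expects before calling it
        batch_size = qt.shape[0]
        expected = (batch_size, concept_num, hidden_dim + embedding_dim)
        assert tuple(tmp_ht.shape) == expected, \
            f"tmp_ht shape {tuple(tmp_ht.shape)} != expected {expected}"
        assert qt.dim() == 1, f"qt must be 1-D, got shape {tuple(qt.shape)}"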

  6. The original model produces the same output, so the problem arises at the call site.

    _aggregate - tmp_ht: 
    tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -1.0762, -1.6635,  0.1739,  1.1163, -0.6199],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -0.1450, -2.3216,  0.8230,  0.5638, -0.5859],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.5896, -1.3795, -1.8392, -0.2225,  0.0325],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -0.4415,  1.6907, -1.4238, -1.1926, -0.3198]],
            [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000,  0.5896, -1.3795, -1.8392, -0.2225,  0.0325],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -0.1450, -2.3216,  0.8230,  0.5638, -0.5859],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -0.9347,  1.9458,  1.3726,  0.7222,  1.7240],
             [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
               0.0000, -0.4415,  1.6907, -1.4238, -1.1926, -0.3198]]])
    tmp_ht shape: torch.Size([2, 4, 13])
    qt shape: torch.Size([2])
    tmp_ht dtype: torch.float32
    qt dtype: torch.int64
    _agg_neighbors - m_next: 
    (tensor([[[-103.3577,  -65.4964,  133.9530, -342.6796,  -42.5984,  -83.8122,
              -266.8670, -249.4939],
             [ 138.8846,   11.7192,   67.9875,  -78.8893, -197.2776,   50.4039,
              -124.7583,   61.5432],
             [   0.9999,   -0.9997,    0.0000,   -0.9999,   -1.0000,   -1.0000,
                 1.0000,    0.0000],
             [ -34.7644,  104.6367, -283.5654,  108.0663,   88.5537, -296.1309,
                56.2663,   61.5432]],
            [[  -0.9999,    0.9997,    0.0000,    0.9999,    1.0000,    1.0000,
                -1.0000,    0.0000],
             [-107.6568, -212.8599,   67.9881, -117.1709,   44.4904,   50.4043,
              -197.5228,   61.5438],
             [  62.4717, -186.7843, -252.6718,   53.6039,  101.8955,  -17.9916,
               203.1964, -399.3390],
             [   6.9311,  -42.5710,   15.1827,   63.6281,   88.5530,   -6.2922,
                91.5724,   61.5428]]], 
    grad_fn=<AsStridedBackward0>), None, None, None, 
    tensor([[[-1.1346e+02],
             [ 9.2065e-43],
             [-1.1346e+02],
             [ 9.2065e-43]],
            [[-1.1346e+02],
             [ 9.2065e-43],
             [-1.1346e+02],
             [ 9.2065e-43]]]), 
    tensor([[[-113.4637],
             [-113.4637],
             [-113.4637],
             [-113.4637]],
            [[-113.4637],
             [-113.4648],
             [-113.4637],
             [-113.4629]]]))
    New debugging code:
    # Model initialization
    concept_num = 10
    hidden_dim = 32
    embedding_dim = 64
    edge_type_num = 2
    graph_type = "Dense"
    graph = torch.zeros((concept_num, concept_num))  # nn.Parameter needs a torch tensor, not a numpy array
    graph_model = None
    dropout = 0.5
    bias = True
    binary = False
    has_cuda = False
    model = GKT(concept_num, hidden_dim, embedding_dim, edge_type_num, graph_type, graph, graph_model, dropout, bias, binary, has_cuda)
    
    # Simulated input data
    batch_size = 5
    seq_len = 20
    features = torch.randint(0, 2, (batch_size, seq_len))  # binarized input
    questions = torch.randint(-1, concept_num, (batch_size, seq_len))  # random questions
    
    
    # Print the shapes and values of intermediate variables

    # Print the variables in the _aggregate function
    print("In _aggregate function:")
    tmp_ht = model._aggregate(features[:, 0], questions[:, 0], torch.zeros((batch_size, model.concept_num, model.hidden_dim)), batch_size)
    print("tmp_ht shape: ", tmp_ht.shape)
    print("tmp_ht value: ", tmp_ht)
    
    # Print the variables in the _agg_neighbors function
    print("In _agg_neighbors function:")
    m_next = model._agg_neighbors(tmp_ht, questions[:, 0])
    for i, item in enumerate(m_next):
        print(f"m_next[{i}] value: ", item)
    
    # Print the variables in the _update function
    print("In _update function:")
    h_next = model._update(tmp_ht, torch.zeros((batch_size, model.concept_num, model.hidden_dim)), questions[:, 0])
    print("h_next shape: ", h_next.shape)
    print("h_next value: ", h_next)
    
    # Print the variables in the _predict function
    print("In _predict function:")
    yt = model._predict(h_next, questions[:, 0])
    print("yt shape: ", yt.shape)
    print("yt value: ", yt)
    
    # Print the variables in the _get_next_pred function
    print("In _get_next_pred function:")
    pred = model._get_next_pred(yt, questions[:, 1])
    print("pred shape: ", pred.shape)
    print("pred value: ", pred)
    
    Output:
    In _aggregate function:
    tmp_ht shape:  torch.Size([5, 10, 96])
    tmp_ht value:  
    tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0941, -0.4533, -1.0247],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.4808,  1.2499,  0.3105],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.7841,  0.4550,  0.3469],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ...,  0.5375,  0.2919,  1.0231],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.2228,  1.8833,  1.1549],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.8485,  0.6319,  0.5914]],
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0941, -0.4533, -1.0247],
             [ 0.0000,  0.0000,  0.0000,  ..., -2.6132, -1.0803,  0.3215],
             [ 0.0000,  0.0000,  0.0000,  ...,  1.3176, -0.8746, -0.7052],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ...,  0.5375,  0.2919,  1.0231],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.2228,  1.8833,  1.1549],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.8485,  0.6319,  0.5914]],
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0941, -0.4533, -1.0247],
             [ 0.0000,  0.0000,  0.0000,  ..., -2.6132, -1.0803,  0.3215],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.7841,  0.4550,  0.3469],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ...,  1.3176, -0.8746, -0.7052],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.2228,  1.8833,  1.1549],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.8485,  0.6319,  0.5914]],
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0941, -0.4533, -1.0247],
             [ 0.0000,  0.0000,  0.0000,  ..., -2.6132, -1.0803,  0.3215],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.7841,  0.4550,  0.3469],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ...,  0.5375,  0.2919,  1.0231],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.2228,  1.8833,  1.1549],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.8485,  0.6319,  0.5914]],
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0941, -0.4533, -1.0247],
             [ 0.0000,  0.0000,  0.0000,  ..., -2.6132, -1.0803,  0.3215],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.7841,  0.4550,  0.3469],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ...,  0.5375,  0.2919,  1.0231],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.4808,  1.2499,  0.3105],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.8485,  0.6319,  0.5914]]],
           grad_fn=<CatBackward0>)
    In _agg_neighbors function:
    m_next[0] value:  
    tensor([[[ 1.2324e+32,  2.2032e+31, -3.2379e+31,  ...,  7.7117e+30,
               4.6696e+31,  2.1787e+30],
             [ 5.7271e-01,  2.7850e-01,  0.0000e+00,  ..., -7.0925e-01,
               1.0276e-01, -4.9551e-01],
             [ 9.5813e+30, -2.6708e+30, -1.8207e+30,  ..., -2.0225e+30,
              -3.1715e+30, -4.6428e+30],
             ...,
             [-6.0195e-12,  2.1256e-11,  7.5524e-12,  ..., -4.9179e-12,
              -5.0023e-12, -4.3686e-12],
             [ 1.2946e+29,  1.5078e+30,  4.0381e+29,  ...,  6.7391e+29,
              -7.6972e+29, -5.7277e+29],
             [-1.3544e+31, -1.1288e+31, -8.5551e+30,  ...,  6.5880e+30,
               1.5617e+31, -1.7933e+31]],
            [[-2.7344e+35, -4.0952e+35, -8.8878e+34,  ..., -1.7149e+35,
              -2.4635e+35, -1.5102e+35],
             [ 1.6327e+31, -4.4965e+30,  2.6258e+30,  ..., -2.6145e+30,
               1.1917e+31,  7.0686e+30],
             [ 3.9246e-01,  1.3375e+00,  0.0000e+00,  ..., -7.0925e-01,
              -1.5798e+00, -4.9551e-01],
             ...,
             [-5.3433e+37, -4.4531e+37, -3.3750e+37,  ..., -2.3277e+37,
              -4.3370e+37,  1.0412e+38],
             [ 4.0599e+35, -3.8431e+35, -4.4264e+35,  ..., -1.6095e+35,
               1.4819e+35, -1.4174e+35],
             [-9.9087e+30, -8.2615e+30, -6.2618e+30,  ..., -1.3278e+31,
              -8.6847e+30, -1.0046e+31]],
            [[-1.3277e+28, -1.1065e+28, -8.3865e+27,  ..., -2.1175e+28,
              -1.1633e+28, -1.7579e+28],
             [-6.5295e-12, -5.3621e-12,  1.8268e-11,  ...,  3.0528e-12,
              -5.7491e-12,  2.1346e-12],
             [ 4.8472e+37, -9.6165e+37, -6.2913e+37,  ...,  1.1887e+38,
              -5.7849e+37,  3.2657e+38],
             ...,
             [-8.5953e-01, -1.6182e+00,  0.0000e+00,  ...,  1.9529e+00,
               1.4658e+00, -4.9551e-01],
             [ 1.6460e+31,  2.4490e+31, -5.6706e+30,  ..., -7.9908e+30,
              -8.7314e+30, -1.3150e+31],
             [ 3.0621e+23, -5.1679e+24, -1.3154e+24,  ..., -2.5616e+24,
              -2.5818e+24, -2.1776e+24]],
            [[ 1.8190e+28,  2.3466e+28, -8.6000e+27,  ..., -2.1715e+28,
              -7.2281e+27,  1.5122e+28],
             [-4.7508e+31, -1.6147e+31, -2.8254e+31,  ..., -1.2579e+31,
              -9.1815e+28,  5.5071e+31],
             [-1.5965e+28, -1.6241e+27,  2.2543e+28,  ..., -1.0016e+28,
              -1.4384e+28, -8.8153e+27],
             ...,
             [-2.3707e+22, -9.3314e+21, -9.2592e+22,  ..., -3.3664e+22,
              -4.8358e+22,  9.8055e+22],
             [-1.8325e+32,  8.0848e+32,  9.6632e+32,  ..., -6.5985e+30,
               2.7459e+31,  6.1748e+30],
             [ 4.1620e+31, -1.3489e+31, -4.9357e+30,  ..., -1.0167e+31,
               4.8233e+30, -8.0429e+30]],
            [[-1.6744e-19, -1.0456e-20,  2.2607e-20,  ..., -1.0501e-19,
              -1.5085e-19, -9.2480e-20],
             [-8.6299e+29, -7.1921e+29, -5.4509e+29,  ...,  1.2944e+30,
              -7.5612e+29,  7.3287e+29],
             [-2.1354e+35, -1.7795e+35, -1.3487e+35,  ..., -3.4056e+35,
              -1.8710e+35, -2.8273e+35],
             ...,
             [ 2.1699e+30, -1.9961e+30,  1.2040e+31,  ..., -7.4784e+30,
               4.2361e+30, -6.6373e+30],
             [-1.4264e+00, -5.0625e-01,  0.0000e+00,  ..., -4.0362e-01,
               4.2081e-01, -4.9551e-01],
             [ 2.1760e+12, -2.3845e+12,  3.4466e+12,  ..., -8.9708e+11,
               2.3981e+12, -3.1746e+11]]], grad_fn=<AsStridedBackward0>)
    m_next[1] value:  None
    m_next[2] value:  None
    m_next[3] value:  None
    m_next[4] value:  
    tensor([[[7.1061e+31],
             [4.2964e+24],
             [4.7824e+30],
             [9.7374e+15],
             [4.7824e+30],
             [6.1126e-02],
             [4.7824e+30],
             [3.7386e-14],
             [1.6109e-19],
             [1.8888e+31]],
            [[6.6660e-33],
             [2.1470e+29],
             [6.7111e+22],
             [4.5450e+30],
             [1.8524e+28],
             [1.9519e-19],
             [7.2708e+31],
             [7.4513e+37],
             [1.8888e+31],
             [1.3821e+31]],
            [[1.8515e+28],
             [9.1041e-12],
             [6.2609e+22],
             [4.7428e+30],
             [6.2288e+22],
             [2.1974e+23],
             [8.6758e-04],
             [4.7429e+30],
             [1.3818e+31],
             [6.7724e+22]],
            [[1.8987e+28],
             [7.1061e+31],
             [4.2964e+24],
             [3.0607e+32],
             [6.6134e+19],
             [1.0724e+31],
             [4.4657e+30],
             [1.5648e+01],
             [3.0607e+32],
             [1.6533e+19]],
            [[1.0804e-32],
             [1.2034e+30],
             [2.9777e+35],
             [1.6408e+07],
             [3.0607e+32],
             [7.3904e+22],
             [1.7239e+25],
             [6.3828e+28],
             [1.4348e-19],
             [2.7530e+12]]])
    m_next[5] value:  
    tensor([[[1.0486e+31],
             [4.2964e+24],
             [2.1470e+29],
             [1.0644e+31],
             [7.1061e+31],
             [6.2609e+22],
             [4.7428e+30],
             [9.1041e-12],
             [1.2034e+30],
             [5.6542e+05]],
            [[3.1731e+35],
             [4.7824e+30],
             [6.7111e+22],
             [4.4657e+30],
             [4.2964e+24],
             [4.7428e+30],
             [3.0607e+32],
             [6.2609e+22],
             [2.9777e+35],
             [1.1614e+27]],
            [[6.6660e-33],
             [3.7386e-14],
             [7.4513e+37],
             [1.8888e+31],
             [1.5648e+01],
             [4.9651e+28],
             [3.0607e+32],
             [4.7429e+30],
             [6.3828e+28],
             [4.6165e+24]],
            [[3.7404e-14],
             [4.7824e+30],
             [1.8524e+28],
             [2.8298e+20],
             [6.6134e+19],
             [2.7259e+20],
             [1.0724e+31],
             [6.2288e+22],
             [3.0607e+32],
             [1.9205e+31]],
            [[1.9431e-19],
             [1.6109e-19],
             [1.8888e+31],
             [6.6660e-33],
             [3.0607e+32],
             [3.6002e+27],
             [7.3904e+22],
             [1.3818e+31],
             [1.4348e-19],
             [4.3701e+12]]])
    In _update function:
    # This shows that the problem occurs in the _update function
    Traceback (most recent call last):
      File "D:\N\Code\anaconda3\envs\pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 3460, in run_code
        exec(code_obj, self.user_global_ns, self.user_ns)
      File "<ipython-input-2-cc829fe010f7>", line 1, in <module>
        runfile('D:/N/Code/Repository/GKT/GKT-master/Prin.py', wdir='D:/N/Code/Repository/GKT/GKT-master')
      File "D:\N\Code\JetBrains\PyCharm 2022.2.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile
        pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
      File "D:\N\Code\JetBrains\PyCharm 2022.2.1\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "D:/N/Code/Repository/GKT/GKT-master/Prin.py", line 54, in <module>
        h_next = model._update(tmp_ht, torch.zeros((batch_size, model.concept_num, model.hidden_dim)), questions[:, 0])
      File "D:\N\Code\Repository\GKT\GKT-master\models.py", line 193, in _update
        m_next, concept_embedding, rec_embedding, z_prob = self._agg_neighbors(tmp_ht, qt)  # [batch_size, concept_num, hidden_dim]
    ValueError: too many values to unpack (expected 4)
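
    The cause matches the tuple printed above: _update in models.py still unpacks four return values, while the modified _agg_neighbors no longer returns that many. The fix, already reflected in the listing at the top of this log, is to keep the arity consistent on both sides:

    # in _agg_neighbors: return only the aggregated representation
    return m_next

    # in _update: unpack a single value to match
    m_next = self._agg_neighbors(tmp_ht, qt)  # [batch_size, concept_num, hidden_dim]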
    
    Normal output of the original model:
    
    In _aggregate function:
    tmp_ht shape:  torch.Size([5, 10, 96])
    tmp_ht value:  
    tensor([[[ 0.0000,  0.0000,  0.0000,  ..., -1.3486,  0.8760,  0.2503],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.5594, -0.1822, -0.8547],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.3173,  1.0626, -1.6663],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ..., -0.9511, -0.0559,  1.0936],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0359, -0.5671,  1.5460],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.8684, -2.4842, -0.1697]],
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
            [[ 0.0000,  0.0000,  0.0000,  ..., -0.2469, -1.3631, -1.2307],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.5594, -0.1822, -0.8547],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.3173,  1.0626, -1.6663],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ..., -0.9511, -0.0559,  1.0936],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0359, -0.5671,  1.5460],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.8684, -2.4842, -0.1697]],
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
            [[ 0.0000,  0.0000,  0.0000,  ..., -0.2469, -1.3631, -1.2307],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0381, -0.4959,  0.2106],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.3173,  1.0626, -1.6663],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ..., -0.9511, -0.0559,  1.0936],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0359, -0.5671,  1.5460],
             [ 0.0000,  0.0000,  0.0000,  ..., -0.8684, -2.4842, -0.1697]]],
           grad_fn=<CatBackward0>)
    In _agg_neighbors function:
    m_next[0] value:  
    tensor([[-7.0708e-01, -7.0699e-01,  0.0000e+00,  1.4142e+00,  1.4129e+00,
             -1.9984e-01,  1.4142e+00, -7.0710e-01,  1.4142e+00,  6.2329e-01,
             -9.9698e-01, -7.0703e-01,  1.1503e+00, -1.3507e+00, -1.0090e+00,
             -6.8545e-01,  0.0000e+00, -7.0710e-01,  6.7942e-01,  0.0000e+00,
             -7.0706e-01, -9.9652e-01, -8.1369e-01, -1.4119e+00, -7.0707e-01,
              1.4030e+00,  1.4140e+00, -7.0710e-01, -7.6625e-01,  1.3715e+00,
              0.0000e+00,  1.4132e+00],
            [-9.3377e+30, -4.4438e+30,  4.0269e+31,  5.2966e+31,  4.7307e+31,
             -1.7204e+31,  6.4155e+31, -2.0324e+31,  3.4213e+31,  9.8819e+30,
             -1.2913e+31,  1.0009e+30,  1.9279e+31, -3.6524e+30,  3.0597e+31,
             -1.1375e+31, -1.3153e+31, -1.1106e+31,  4.4666e+31, -1.4089e+31,
              2.0475e+31, -6.9232e+30, -8.8107e+30,  7.2887e+29, -6.8600e+30,
             -7.0865e+30, -1.5513e+31, -7.7240e+30, -1.1357e+31, -1.9777e+31,
             -2.3097e+31, -1.2965e+31],
            [ 6.6746e+19, -1.7142e+20, -1.3694e+20,  4.0207e+20, -8.6581e+19,
             -1.6775e+20,  3.6503e+20, -1.9816e+20, -7.6050e+19, -3.6486e+19,
              1.1670e+20,  1.7826e+20, -1.5121e+20, -3.5612e+19,  2.8137e+20,
             -2.2191e+20, -1.2824e+20,  3.1863e+20,  2.2312e+20,  9.4794e+19,
              4.0157e+19,  6.3880e+19,  4.0683e+19, -1.2084e+20, -1.2029e+20,
             -6.9094e+19,  1.0492e+19, -7.5310e+19, -5.4612e+18, -1.8008e+20,
             -1.6469e+20, -1.2642e+20],
            [-7.3726e+22, -8.1940e+22,  1.3152e+23, -6.4019e+22,  1.5749e+23,
             -2.7702e+22, -3.2767e+22,  2.7229e+22,  1.1503e+23, -1.6409e+22,
              2.9096e+22,  1.3772e+23,  1.7966e+22, -1.3552e+22, -5.2015e+22,
             -1.7332e+22, -4.8803e+22, -4.1206e+22,  1.4981e+23,  7.3645e+22,
             -7.3938e+22, -2.5688e+22, -3.2691e+22,  2.3442e+23, -4.5774e+22,
             -2.6294e+22, -5.7559e+22, -2.8659e+22, -4.2138e+22,  4.5219e+21,
             -5.0129e+22,  8.8478e+21],
            [-8.7262e+22,  1.0173e+22,  3.7276e+22,  2.9179e+23, -5.0749e+22,
              3.8675e+22,  5.4565e+22, -4.8569e+22, -2.1210e+22, -4.0344e+22,
              6.0943e+22, -3.7366e+22,  1.0264e+23, -5.1681e+22, -5.3926e+22,
             -5.2226e+22,  1.1072e+23, -4.4829e+22, -3.5410e+22,  1.1655e+21,
             -2.8266e+22, -1.0642e+23, -5.6012e+22, -3.2218e+22, -6.5300e+22,
             -2.0264e+22,  3.0176e+22, -1.3118e+22,  2.3413e+23, -6.9852e+22,
             -3.1340e+22, -4.0416e+22],
            [ 6.8033e+30, -1.6109e+31, -2.3774e+30, -1.0658e+31, -1.7291e+31,
             -9.6644e+30,  3.5310e+31, -1.2201e+31, -2.0484e+31,  2.6257e+31,
              2.2330e+31, -9.3779e+30, -1.4104e+31, -1.2991e+31,  2.8924e+31,
              1.8616e+31, -1.8885e+31,  2.9299e+31, -8.8921e+30, -1.0495e+31,
             -7.1990e+30, -9.6644e+30,  8.4747e+30, -8.0738e+30,  1.2322e+31,
             -5.0799e+30, -1.9082e+31, -3.2807e+30, -7.3867e+30,  3.3135e+31,
             -7.8981e+30,  5.3422e+30],
            [ 8.1123e+30, -6.2705e+31,  1.1285e+31, -4.3019e+31,  5.3195e+31,
             -7.1557e+31, -1.1452e+31, -4.9246e+31, -8.2678e+31, -4.0961e+31,
              1.0570e+28, -3.7851e+31,  8.4043e+31, -5.2434e+31,  5.9398e+31,
             -5.3062e+31, -7.6225e+31,  1.8598e+31, -3.5890e+31, -4.2359e+31,
             -2.9057e+31, -5.8872e+29, -5.6780e+31, -3.2588e+31, -6.6175e+31,
             -2.0504e+31, -2.1027e+31, -1.3242e+31, -2.9815e+31,  5.7247e+31,
             -3.1878e+31, -4.0905e+31],
            [ 1.4829e-19,  1.0967e-19,  7.7588e-19,  7.9435e-20, -2.2232e-21,
             -9.9126e-20,  4.5317e-19,  9.7788e-19,  1.6251e-19,  3.0213e-19,
              7.8386e-20, -9.8366e-20,  4.6539e-19, -1.8116e-19,  6.1422e-20,
              6.9460e-19, -1.0671e-19, -2.1518e-19, -1.2380e-19, -2.2600e-19,
             -2.2842e-19, -3.6877e-19, -1.7340e-19, -1.8498e-19,  1.8190e-19,
             -1.1130e-19,  5.2681e-21, -9.4388e-20,  7.1995e-19,  6.4983e-20,
             -5.4982e-20,  1.8741e-21],
            [-3.7241e+31, -3.3337e+32, -2.3286e+32, -2.8151e+32,  1.0224e+33,
              6.4429e+31, -1.3606e+32, -5.9563e+31,  5.3538e+31,  1.4440e+32,
             -1.9856e+32, -1.4239e+32, -2.4505e+32, -6.3223e+31, -2.3303e+32,
             -2.3442e+30, -2.1750e+32, -1.8454e+32, -1.9125e+32,  5.1497e+32,
             -1.0966e+32, -1.2042e+32, -1.4346e+32, -2.0468e+32, -1.4783e+32,
             -1.8087e+31, -1.2408e+32,  4.9439e+32, -1.8298e+32, -3.4235e+31,
             -1.4410e+32, -2.1459e+32],
            [-4.8799e+34, -7.5724e+35, -5.8648e+35, -5.0059e+35,  5.1855e+34,
              5.8840e+35, -2.6729e+35, -5.7328e+35, -5.1632e+32, -4.7714e+35,
              4.9774e+35,  1.4959e+36, -9.3555e+35,  1.1309e+36,  1.0021e+36,
             -6.1821e+35,  8.0353e+35,  1.6282e+35, -4.1733e+35,  1.0937e+36,
             -3.3849e+35,  8.5451e+35,  1.5014e+36, -3.7940e+35, -7.7074e+35,
             -2.3773e+35, -1.3762e+36, -1.5430e+35, -3.4728e+35,  1.4207e+36,
             -3.7155e+35, -4.7650e+35]], grad_fn=<UnbindBackward0>)
    m_next[1] value:  
    tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<UnbindBackward0>)
    m_next[2] value:  
    tensor([[ 3.2706e+22,  6.1654e+22, -1.0289e+22,  2.6162e+20,  9.3671e+22,
             -9.1952e+21,  4.4778e+22,  9.0740e+22,  3.7636e+22, -4.0345e+22,
             -6.6226e+22, -3.7282e+22,  2.4758e+22, -5.1646e+22, -3.9467e+21,
             -3.0683e+22, -3.1847e+22, -4.5224e+22, -3.5351e+22, -4.1722e+22,
             -2.8620e+22,  4.5490e+22, -2.5942e+22,  2.8234e+22,  1.0630e+23,
             -2.0195e+22,  3.7853e+22, -1.3042e+22, -2.9366e+22,  6.0719e+22,
             -3.1399e+22, -4.0290e+22],
            [-4.5834e+36, -2.6728e+36, -1.0301e+37, -1.2655e+37,  8.3586e+36,
              1.1143e+37, -6.4773e+36, -1.4907e+37, -9.6596e+36, -8.7622e+36,
             -9.4710e+36, -6.2777e+36, -1.1375e+37, -2.6789e+36,  3.5390e+37,
             -1.6662e+37,  8.2350e+36, -8.1455e+36, -8.4770e+36,  2.9643e+36,
              1.1818e+37, -5.0779e+36, -6.4623e+36,  1.6043e+37, -3.4187e+36,
             -4.9023e+36,  1.6553e+37, -5.4145e+36,  7.5373e+36, -1.3946e+37,
             -5.9986e+35,  9.9735e+36],
            [-5.9961e+29,  6.3656e+30, -2.5111e+30, -6.2818e+29, -2.2669e+30,
             -1.6721e+30,  8.9544e+30, -3.6975e+30, -3.3621e+30, -1.9707e+30,
              1.0745e+31, -1.1151e+30, -3.8168e+30, -9.3241e+29,  9.3388e+30,
             -1.5307e+30,  2.3235e+30, -2.8351e+30, -2.9505e+30, -3.3862e+30,
             -5.0872e+30,  1.3736e+30,  1.0323e+31, -3.1638e+30, -9.5350e+29,
             -1.8091e+30, -3.4885e+30, -1.9718e+30,  1.2260e+31, -5.0489e+30,
             -3.4989e+30,  4.1591e+29],
            [-7.0708e-01, -7.0699e-01,  0.0000e+00, -7.0710e-01, -7.0646e-01,
              1.3124e+00, -7.0709e-01,  1.4142e+00, -7.0708e-01, -1.4110e+00,
              1.3671e+00, -7.0703e-01,  1.3721e-01,  1.0381e+00,  1.3627e+00,
             -7.2843e-01,  0.0000e+00,  1.4142e+00, -1.4138e+00,  0.0000e+00,
             -7.0706e-01,  1.3673e+00, -5.9481e-01,  6.3807e-01, -7.0707e-01,
             -5.4807e-01, -7.0702e-01,  1.4142e+00,  1.4125e+00, -9.8435e-01,
              0.0000e+00, -7.0660e-01],
            [-4.9704e-20,  1.9283e-19, -1.7453e-19,  8.8613e-20, -7.9543e-20,
             -2.2866e-19, -1.3696e-19,  9.6960e-21, -2.1970e-20, -8.3595e-20,
             -7.8045e-20, -1.3407e-19, -2.1583e-19, -1.1921e-19, -2.2210e-19,
             -2.1614e-20,  4.8845e-21, -1.1502e-19, -8.6453e-20, -1.8227e-19,
             -1.8023e-19, -1.7000e-19,  1.5182e-19, -1.5134e-19, -5.4573e-20,
             -9.0137e-20, -1.6925e-19, -8.0708e-20, -1.3859e-19, -8.9996e-20,
              7.3867e-20,  9.5617e-20],
            [ 4.4468e+26, -4.9025e+26, -2.3940e+27,  4.2447e+27, -2.0231e+27,
              7.5912e+27, -2.0120e+27,  3.2536e+27, -3.0005e+27, -2.7218e+27,
             -2.9420e+27, -1.5690e+27,  1.0161e+27, -8.3214e+26, -3.1939e+27,
              4.0000e+27, -1.9315e+26, -2.5302e+27, -2.6332e+27, -3.2098e+27,
             -1.7471e+27, -1.5773e+27, -2.0074e+27, -2.8236e+27,  1.1694e+28,
             -1.6145e+27, -1.2657e+27, -1.7598e+27, -2.5874e+27,  3.8709e+27,
              5.5685e+27, -2.9539e+27],
            [-1.9350e+31, -9.7497e+30, -5.1880e+30,  1.4319e+30,  2.6939e+31,
             -8.2378e+30, -8.6002e+30,  6.0057e+30, -1.2825e+31, -1.1069e+31,
              2.9593e+30, -8.3351e+30, -1.5103e+31, -3.5569e+30,  2.2506e+31,
             -4.6303e+30,  2.5249e+31, -2.4485e+30, -1.1255e+31, -2.9802e+30,
             -7.2700e+30, -6.7421e+30, -8.5802e+30, -1.0575e+31,  2.9042e+31,
             -6.9011e+30,  5.8356e+30,  1.3281e+31,  3.3284e+30, -1.9260e+31,
              4.2243e+30,  6.4633e+30],
            [ 2.2739e+28,  1.2526e+28, -1.3374e+28, -1.6430e+28, -8.4561e+27,
              5.1880e+27, -8.4096e+27, -1.3175e+28, -4.6555e+27,  1.0422e+28,
             -1.2296e+28, -8.1504e+27,  1.3708e+28, -3.4781e+27, -1.9735e+27,
              1.1638e+28,  4.6767e+28, -1.0575e+28, -1.1006e+28, -1.3416e+28,
              3.8155e+27, -6.5927e+27, -8.3901e+27, -5.7155e+27,  1.8367e+28,
              6.5183e+28, -1.1263e+28, -7.3553e+27,  3.5261e+27, -5.0872e+27,
              6.7944e+27, -1.2347e+28],
            [-1.8750e+19, -2.0840e+19, -1.3253e+19,  8.0226e+18, -8.3797e+18,
             -4.4038e+18, -8.3336e+18,  1.3502e+18, -1.2428e+19, -9.1538e+18,
             -1.2185e+19, -8.0768e+18, -6.0951e+18,  1.4082e+17, -1.3229e+19,
             -1.6047e+19, -7.9953e+17, -1.0480e+19, -1.0906e+19, -1.3295e+19,
              1.6455e+19, -6.5332e+18, -8.3143e+18, -9.4446e+18,  2.4920e+18,
             -6.6872e+18, -1.1521e+19, -7.2887e+18, -1.0717e+19, -1.7708e+19,
             -1.1502e+19, -1.2235e+19],
            [ 6.7204e+22,  1.1744e+23,  4.1223e+22, -4.2369e+22,  9.4173e+22,
             -1.9247e+22,  4.0880e+22, -2.8197e+22,  8.4356e+22, -4.0342e+22,
             -4.7019e+22, -3.7279e+22,  2.9378e+22,  9.9465e+22,  1.5250e+22,
             -5.2259e+22,  5.1191e+22, -4.5220e+22, -3.5348e+22, -4.1718e+22,
             -2.8617e+22,  1.2036e+23,  1.8225e+22, -3.2095e+22,  1.2533e+23,
             -2.0194e+22,  1.1872e+23, -1.3041e+22, -2.9364e+22,  1.4555e+23,
             -3.1396e+22,  7.3005e+22]], grad_fn=<UnbindBackward0>)
    m_next[3] value:  
    tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<UnbindBackward0>)
    m_next[4] value:  tensor([[ 2.2705e+31,  1.4660e+31, -1.2587e+31, -9.2784e+30,  1.5091e+31,
              5.1478e+30, -1.5791e+31, -8.0918e+30,  2.5169e+31, -1.0873e+31,
             -1.1210e+31, -4.5892e+30,  3.8851e+30, -2.5928e+30, -1.6804e+31,
             -6.3257e+30,  1.0203e+31, -1.2188e+31, -9.5274e+30, -1.1244e+31,
             -7.7133e+30,  2.1364e+31,  2.1834e+30,  2.7694e+31, -1.7567e+31,
             -5.4429e+30,  1.4351e+31, -3.5151e+30, -7.9145e+30,  4.2026e+30,
             -8.4624e+30, -1.0859e+31],
            [ 1.4142e+00,  1.4140e+00,  0.0000e+00, -7.0710e-01, -7.0646e-01,
             -1.1125e+00, -7.0709e-01, -7.0710e-01, -7.0708e-01,  7.8772e-01,
             -3.7009e-01,  1.4141e+00, -1.2875e+00,  3.1256e-01, -3.5368e-01,
              1.4139e+00,  0.0000e+00, -7.0710e-01,  7.3442e-01,  0.0000e+00,
              1.4141e+00, -3.7075e-01,  1.4085e+00,  7.7387e-01,  1.4141e+00,
             -8.5493e-01, -7.0702e-01, -7.0710e-01, -6.4625e-01, -3.8719e-01,
              0.0000e+00, -7.0660e-01],
            [-8.7480e+22, -6.4217e+22,  4.4588e+22, -4.2440e+22, -6.4291e+22,
             -8.7220e+21, -5.6982e+22, -4.8785e+22, -8.1393e+22, -4.0508e+22,
             -5.8175e+22, -8.1076e+21, -2.7980e+22,  4.6328e+21, -6.2220e+22,
             -5.2296e+22, -7.5255e+22, -4.4817e+22,  1.2510e+23,  8.5972e+22,
              3.0243e+22, -4.6261e+22, -5.5230e+22, -3.2242e+22,  7.4183e+22,
             -2.0292e+22,  6.4002e+22, -1.3149e+22, -2.8942e+22, -4.9838e+22,
             -3.1719e+22, -4.0467e+22],
            [-1.7277e+37, -1.2660e+37,  1.8427e+36, -7.3565e+36,  1.9082e+36,
             -8.2479e+36, -6.5535e+36,  1.2388e+36, -1.6098e+37, -7.9752e+36,
             -6.8664e+36, -7.3698e+36, -1.5649e+37, -3.7562e+36, -1.2325e+37,
             -1.0331e+37, -7.9551e+36, -8.9396e+36, -6.9880e+36,  1.2858e+37,
             -5.6574e+36, -1.3873e+37, -3.6395e+35, -4.3316e+36,  1.2969e+37,
             -3.9921e+36, -3.9011e+36, -2.5782e+36, -5.1007e+36,  1.0325e+37,
             -6.2068e+36,  4.6018e+36],
            [-2.7712e+20, -3.0800e+20, -1.9526e+20,  1.4188e+20, -1.2385e+20,
             -2.3994e+20, -1.2317e+20,  3.3179e+19, -1.8368e+20, -1.2251e+19,
              9.6322e+19, -9.1869e+19, -6.4163e+19, -5.0939e+19, -1.6825e+20,
             -6.8739e+19, -1.8344e+20, -1.5489e+20, -1.6119e+20,  3.9078e+20,
             -2.7792e+20, -9.6556e+19, -1.2288e+20, -7.7369e+19,  1.9579e+20,
              8.1454e+19, -7.2322e+19,  3.3365e+20, -2.7567e+19, -1.5337e+20,
             -5.7041e+19,  3.5530e+20],
            [ 7.5070e-20,  2.1309e-20,  5.7505e-20, -1.3723e-20, -2.3371e-19,
              2.0968e-19, -1.3830e-19, -1.4768e-19, -2.4014e-19,  1.4322e-19,
             -2.4404e-19, -1.3306e-19,  1.8781e-21, -5.3809e-20, -2.2042e-19,
              9.1289e-20,  1.1045e-19,  1.5613e-20,  3.1851e-19, -1.0209e-19,
              1.9921e-19,  3.0032e-20, -1.2337e-19,  4.1138e-19,  9.0199e-20,
              8.9611e-20,  2.1762e-19, -8.0373e-20, -1.2087e-19, -8.6872e-20,
              2.1440e-19, -2.7865e-20],
            [ 1.3431e+27,  6.1703e+27, -3.1962e+27,  2.6143e+26, -2.1683e+27,
             -1.8776e+27, -2.1564e+27, -9.0900e+26, -3.2158e+27, -2.4771e+26,
              7.8160e+27,  7.5658e+27, -3.7869e+27, -8.9185e+26, -3.4231e+27,
             -5.5573e+27, -3.2117e+27, -2.7118e+27, -2.8221e+27, -2.1575e+27,
              1.7284e+26, -1.6905e+27,  1.2424e+28, -3.0262e+27, -2.7738e+27,
             -1.7304e+27, -3.7879e+27, -1.8860e+27, -2.7731e+27,  4.3988e+27,
             -1.3221e+26, -3.1659e+27],
            [ 4.3761e+28,  1.2717e+29, -4.1877e+28, -2.9479e+28, -3.5374e+28,
             -6.8535e+28, -3.5180e+28,  9.7419e+28, -5.2463e+28,  2.4630e+29,
             -5.1439e+28, -3.4095e+28, -4.8961e+28, -1.4550e+28, -5.5844e+28,
              1.3153e+29,  2.2546e+29, -4.4240e+28, -4.6040e+28, -5.6123e+28,
             -7.9381e+28, -2.7579e+28, -3.5098e+28, -4.9369e+28, -4.9144e+28,
              2.0273e+28, -6.1796e+28, -3.0769e+28,  3.9149e+28,  5.4236e+28,
              1.4430e+29, -5.1649e+28],
            [ 3.2065e+25,  2.0080e+25,  6.7749e+24, -1.0482e+25,  1.4335e+25,
              2.7997e+24, -7.9242e+24,  9.4052e+24,  1.7527e+25,  4.9195e+25,
             -3.4844e+24, -9.2233e+24,  3.5161e+25, -1.2777e+25,  2.5923e+25,
              1.9148e+25, -9.3156e+24,  4.4959e+25,  1.2142e+25, -1.0321e+25,
             -6.3326e+24,  1.6733e+25, -1.3836e+25, -7.9407e+24, -1.3955e+25,
             -4.9962e+24,  4.0984e+24, -3.2266e+24, -7.2649e+24, -1.7253e+25,
              2.3038e+25, -9.9671e+24],
            [-7.6128e+30,  5.4335e+31,  6.2959e+31, -8.0251e+30,  2.4663e+31,
             -1.5642e+31, -3.7694e+30,  2.1301e+31,  2.1301e+31, -2.9152e+31,
             -3.9283e+31, -2.6939e+31, -9.8384e+30,  3.8949e+31,  5.1939e+31,
             -2.0039e+31,  5.2509e+31,  1.6385e+31, -2.5543e+31, -3.0147e+31,
             -2.0680e+31,  8.6417e+31, -2.2322e+30, -2.3193e+31,  7.7564e+31,
             -1.4592e+31,  2.1713e+31, -9.4241e+30, -1.4558e+30,  6.8830e+31,
             -2.2688e+31,  2.3506e+31]], grad_fn=<UnbindBackward0>)
    In _update function:
    h_next shape:  torch.Size([5, 10, 32])
    h_next value:  
    tensor([[[ 0.1248,  0.1779,  0.1606,  ..., -0.0967,  0.0597,  0.0568],
             [ 0.0207, -0.0659,  0.0029,  ...,  0.0385, -0.0248, -0.0389],
             [ 0.0420, -0.0477,  0.0094,  ...,  0.0488, -0.0487, -0.0332],
             ...,
             [ 0.0400, -0.0495,  0.0088,  ...,  0.0478, -0.0464, -0.0337],
             [ 0.0231, -0.0639,  0.0036,  ...,  0.0396, -0.0274, -0.0382],
             [ 0.0341, -0.0546,  0.0070,  ...,  0.0449, -0.0398, -0.0353]],
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
            [[ 0.0258, -0.0616,  0.0045,  ...,  0.0409, -0.0305, -0.0375],
             [ 0.0207, -0.0659,  0.0029,  ...,  0.0385, -0.0248, -0.0389],
             [ 0.0420, -0.0477,  0.0094,  ...,  0.0488, -0.0487, -0.0332],
             ...,
             [ 0.0400, -0.0495,  0.0088,  ...,  0.0478, -0.0464, -0.0337],
             [ 0.0231, -0.0639,  0.0036,  ...,  0.0396, -0.0274, -0.0382],
             [ 0.0341, -0.0546,  0.0070,  ...,  0.0449, -0.0398, -0.0353]],
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             ...,
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
            [[ 0.0258, -0.0616,  0.0045,  ...,  0.0409, -0.0305, -0.0375],
             [-0.0317, -0.2758, -0.0693,  ..., -0.0358, -0.2878, -0.1085],
             [ 0.0420, -0.0477,  0.0094,  ...,  0.0488, -0.0487, -0.0332],
             ...,
             [ 0.0400, -0.0495,  0.0088,  ...,  0.0478, -0.0464, -0.0337],
             [ 0.0231, -0.0639,  0.0036,  ...,  0.0396, -0.0274, -0.0382],
             [ 0.0341, -0.0546,  0.0070,  ...,  0.0449, -0.0398, -0.0353]]],
           grad_fn=<AsStridedBackward0>)
    In _predict function:
    yt shape:  torch.Size([5, 10])
    yt value:  
    tensor([[0.4969, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068,
             0.5068],
            [0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228,
             0.0228],
            [0.5068, 0.5068, 0.5068, 0.5441, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068,
             0.5068],
            [0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228,
             0.0228],
            [0.5068, 0.4666, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068,
             0.5068]], grad_fn=<AsStridedBackward0>)
    In _get_next_pred function:
    pred shape:  torch.Size([5])
    pred value:  tensor([0.5068, 0.0228, 0.5068, 0.0228, 0.5068], grad_fn=<SumBackward1>)
    

    Error found:

        def generate_neighbor_sets(self, adj, reverse_adj):
            neighbor_sets = []
            concept_num = adj.shape[1]
            index_neighbors_sets = {}
            for i in range(concept_num):
                neighbors = set()
                for j in range(concept_num):
                    if adj[i, j, 0] > 0 or reverse_adj[i, j, 0] > 0:
                        neighbors.add(j)
                neighbor_sets.append(neighbors)

            for i, neighbors in enumerate(neighbor_sets):
                index_neighbors_sets[i] = neighbors

            return index_neighbors_sets
    # Error: IndexError: index 5 is out of bounds for dimension 0 with size 5
    # The cause: I had not noticed how adj is constructed:
    #     adj = self.graph[masked_qt.long(), :].unsqueeze(dim=-1)
    #     reverse_adj = self.graph[:, masked_qt.long()].transpose(0, 1).unsqueeze(dim=-1)
    # The shape has already changed to [mask_num, concept_num, 1],
    # so the search range is no longer 10 * 10.
    # Fixed version:
        def generate_neighbor_sets(self, adj, reverse_adj):
            neighbor_sets = []
            mask_num = adj.shape[0]  # rows now index the answered questions, not all concepts
            concept_num = adj.shape[1]
            index_neighbors_sets = {}
            for i in range(mask_num):
                neighbors = set()
                for j in range(concept_num):
                    if adj[i, j, 0] > 0 or reverse_adj[i, j, 0] > 0:
                        neighbors.add(j)
                neighbor_sets.append(neighbors)

            for i, neighbors in enumerate(neighbor_sets):
                index_neighbors_sets[i] = neighbors

            return index_neighbors_sets
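
    A quick sanity check of the corrected version on toy tensors shaped like those built in _agg_neighbors (assuming a model instance as in the debug script above):

    mask_num, concept_num = 2, 4
    adj = torch.zeros(mask_num, concept_num, 1)
    reverse_adj = torch.zeros(mask_num, concept_num, 1)
    adj[0, 1, 0] = 1.0          # question 0 points to concept 1
    reverse_adj[0, 3, 0] = 1.0  # concept 3 points back to question 0
    adj[1, 2, 0] = 1.0
    print(model.generate_neighbor_sets(adj, reverse_adj))
    # expected: {0: {1, 3}, 1: {2}}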
