-
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import math
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.autograd import Variable
from layers import MLP, EraseAddGate, MLPEncoder, MLPDecoder, ScaledDotProductAttention
from utils import gumbel_softmax


class GKT(nn.Module):
    r"""
    Knowledge-tracing network that keeps one hidden state per concept and
    updates it along a fixed concept graph.

    At every timestamp the model aggregates the current hidden states with
    concept / response embeddings (``_aggregate``), refreshes the states of a
    score-ranked half of the concepts (``_agg_neighbors``), passes the result
    through an erase-add gate and a GRU cell (``_update``) and predicts a
    per-concept probability (``_predict``).

    NOTE(review): the "Shapley" score keeps the weighting formula of the
    original implementation; it is not the textbook Shapley value. Confirm the
    intended value function before relying on the ranking.
    """

    def __init__(self, concept_num, hidden_dim, embedding_dim, edge_type_num, graph_type, graph=None, graph_model=None, dropout=0.5, bias=True, binary=False, has_cuda=False):
        r"""
        Parameters:
            concept_num: number of concepts (graph nodes)
            hidden_dim: size of the hidden state kept for every concept
            embedding_dim: size of the concept / response embeddings
            edge_type_num: stored on the instance; not used by this class
            graph_type: stored on the instance; not used by this class
            graph: adjacency matrix as a tensor or array-like
            graph_model: stored on the instance; not used by this class
            dropout: dropout rate forwarded to the MLPs
            bias: whether the MLPs, the GRU cell and the prediction layer use a bias
            binary: if True a response has 2 possible values, otherwise 12
            has_cuda: if True the one-hot lookup tables are created on the GPU
        Shape:
            graph: [concept_num, concept_num]
        """
        super(GKT, self).__init__()
        self.concept_num = concept_num
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.edge_type_num = edge_type_num
        self.res_len = 2 if binary else 12
        self.has_cuda = has_cuda
        self.graph_type = graph_type

        # Accept array-likes (e.g. numpy arrays) in addition to tensors;
        # nn.Parameter itself only takes tensors.
        if graph is not None and not torch.is_tensor(graph):
            graph = torch.as_tensor(graph, dtype=torch.float32)
        # The graph is a fixed (non-trainable) parameter so that it follows the
        # module across devices and is stored in the state dict.
        # NOTE(review): graph=None yields an empty parameter; _agg_neighbors
        # indexes self.graph, so a real adjacency is presumably required.
        self.graph = nn.Parameter(graph, requires_grad=False)  # [concept_num, concept_num]
        self.graph_model = graph_model

        # one-hot feature and question
        one_hot_feat = torch.eye(self.res_len * self.concept_num)
        self.one_hot_feat = one_hot_feat.cuda() if self.has_cuda else one_hot_feat
        self.one_hot_q = torch.eye(self.concept_num, device=self.one_hot_feat.device)
        # extra all-zero row: looked up for padded questions (index concept_num)
        zero_padding = torch.zeros(1, self.concept_num, device=self.one_hot_feat.device)
        self.one_hot_q = torch.cat((self.one_hot_q, zero_padding), dim=0)

        # concept and concept & response embeddings
        self.emb_x = nn.Embedding(self.res_len * concept_num, embedding_dim)
        # last embedding is used for padding, so dim + 1
        self.emb_c = nn.Embedding(concept_num + 1, embedding_dim, padding_idx=-1)

        # f_self function and f_neighbor functions
        mlp_input_dim = hidden_dim + embedding_dim
        self.f_self = MLP(mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias)
        self.f_neighbor_list = nn.ModuleList()
        # f_in and f_out functions
        self.f_neighbor_list.append(MLP(2 * mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias))
        self.f_neighbor_list.append(MLP(2 * mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias))

        # Erase & Add Gate
        self.erase_add_gate = EraseAddGate(hidden_dim, concept_num)
        # Gate Recurrent Unit
        self.gru = nn.GRUCell(hidden_dim, hidden_dim, bias=bias)
        # prediction layer
        self.predict = nn.Linear(hidden_dim, 1, bias=bias)

    # Aggregate step, as shown in Section 3.2.1 of the paper
    def _aggregate(self, xt, qt, ht, batch_size):
        r"""
        Parameters:
            xt: input one-hot question answering features at the current timestamp
            qt: question indices for all students in a batch at the current timestamp
            ht: hidden representations of all concepts at the current timestamp
            batch_size: the size of a student batch
        Shape:
            xt: [batch_size]
            qt: [batch_size]
            ht: [batch_size, concept_num, hidden_dim]
            tmp_ht: [batch_size, concept_num, hidden_dim + embedding_dim]
        Return:
            tmp_ht: aggregation results of concept hidden knowledge state and concept(& response) embedding
        """
        qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
        x_idx_mat = torch.arange(self.res_len * self.concept_num, device=xt.device)
        x_embedding = self.emb_x(x_idx_mat)  # [res_len * concept_num, embedding_dim]
        masked_feat = F.embedding(xt[qt_mask], self.one_hot_feat)  # [mask_num, res_len * concept_num]
        res_embedding = masked_feat.mm(x_embedding)  # [mask_num, embedding_dim]
        mask_num = res_embedding.shape[0]

        # Padded students look up index concept_num, i.e. the padding row of emb_c.
        concept_idx_mat = self.concept_num * torch.ones((batch_size, self.concept_num), device=xt.device).long()
        concept_idx_mat[qt_mask, :] = torch.arange(self.concept_num, device=xt.device)
        concept_embedding = self.emb_c(concept_idx_mat)  # [batch_size, concept_num, embedding_dim]

        # For the concept that was answered, replace the concept embedding by
        # the concept & response embedding.
        index_tuple = (torch.arange(mask_num, device=xt.device), qt[qt_mask].long())
        concept_embedding[qt_mask] = concept_embedding[qt_mask].index_put(index_tuple, res_embedding)
        tmp_ht = torch.cat((ht, concept_embedding), dim=-1)  # [batch_size, concept_num, hidden_dim + embedding_dim]
        return tmp_ht

    # GNN aggregation step, as shown in 3.3.2 Equation 1 of the paper
    def _agg_neighbors(self, tmp_ht, qt):
        r"""
        Parameters:
            tmp_ht: temporal hidden representations of all concepts after the aggregate step
            qt: question indices for all students in a batch at the current timestamp
        Shape:
            tmp_ht: [batch_size, concept_num, hidden_dim + embedding_dim]
            qt: [batch_size]
            m_next: [batch_size, concept_num, hidden_dim]
        Return:
            m_next: hidden representations of all concepts at the next timestamp.
                For non-padded students the top-scored half of the concepts
                receives neighbor features, the answered concept receives the
                f_self features, and every other entry keeps the hidden part of
                tmp_ht.
        """
        qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
        masked_qt = qt[qt_mask].long()  # [mask_num, ]
        masked_tmp_ht = tmp_ht[qt_mask]  # [mask_num, concept_num, hidden_dim + embedding_dim]
        mask_num = masked_tmp_ht.shape[0]
        self_index_tuple = (torch.arange(mask_num, device=qt.device), masked_qt)
        self_ht = masked_tmp_ht[self_index_tuple]  # [mask_num, hidden_dim + embedding_dim]
        self_features = self.f_self(self_ht)  # [mask_num, hidden_dim]
        expanded_self_ht = self_ht.unsqueeze(dim=1).repeat(1, self.concept_num, 1)  # [mask_num, concept_num, hidden_dim + embedding_dim]
        neigh_ht = torch.cat((expanded_self_ht, masked_tmp_ht), dim=-1)  # [mask_num, concept_num, 2 * (hidden_dim + embedding_dim)]

        # The scoring helpers index the first adjacency axis by concept, so they
        # need the whole graph rather than the rows of the answered concepts.
        adj = self.graph.unsqueeze(dim=-1)  # [concept_num, concept_num, 1]
        reverse_adj = self.graph.transpose(0, 1).unsqueeze(dim=-1)  # [concept_num, concept_num, 1]

        # The prediction layer takes hidden_dim features, so only the hidden
        # part of tmp_ht is scored (the embedding part is dropped).
        yt = self._predict(tmp_ht[:, :, :self.hidden_dim], qt)  # [batch_size, concept_num]
        shapley_values = self.calculate_shapley_values(adj, reverse_adj, yt)  # [concept_num]

        # Keep the 50% of the concepts with the highest score
        sorted_indices = torch.argsort(shapley_values, descending=True)
        top_k = int(0.5 * len(sorted_indices))
        top_indices = sorted_indices[:top_k]

        m_next = tmp_ht[:, :, :self.hidden_dim].clone()
        masked_m_next = m_next[qt_mask]  # [mask_num, concept_num, hidden_dim] (a copy)
        if top_k > 0:
            # The neighbor MLP was built for the full concatenated input, so the
            # feature axis must not be truncated here.
            # NOTE(review): assumes MLP accepts a 3-D input and maps its last
            # axis to hidden_dim - confirm against layers.MLP.
            selected_neigh_features = self.f_neighbor_list[0](neigh_ht[:, top_indices, :])  # [mask_num, top_k, hidden_dim]
            masked_m_next[:, top_indices] = selected_neigh_features
        # The answered concept itself always gets the f_self features
        m_next[qt_mask] = masked_m_next.index_put(self_index_tuple, self_features)
        return m_next

    # Step 1: build the neighbor set of every knowledge node
    def generate_neighbor_sets(self, adj, reverse_adj):
        r"""
        Parameters:
            adj: outgoing edge weights
            reverse_adj: incoming edge weights
        Shape:
            adj: [row_num, concept_num, 1]
            reverse_adj: [row_num, concept_num, 1]
        Return:
            index_neighbors_sets: dict mapping every row index i to the set of
                column indices j with adj[i, j, 0] > 0 or reverse_adj[i, j, 0] > 0
        """
        # One vectorised comparison instead of a Python-level test per edge
        linked = (adj[..., 0] > 0) | (reverse_adj[..., 0] > 0)  # [row_num, concept_num]
        return {
            i: set(torch.nonzero(row).flatten().tolist())
            for i, row in enumerate(linked)
        }

    # Step 2: score every knowledge node from the marginal contributions of
    # all permutations of its neighbor set
    def calculate_shapley_values(self, adj, reverse_adj, yt):
        r"""
        Parameters:
            adj: outgoing edge weights of the whole graph
            reverse_adj: incoming edge weights of the whole graph
            yt: predicted values of all concepts
        Shape:
            adj: [concept_num, concept_num, 1]
            reverse_adj: [concept_num, concept_num, 1]
            yt: [batch_size, concept_num]
            shapley_values: [concept_num]
        Return:
            shapley_values: one score per concept; 0 for concepts without neighbors

        The score is defined as the sum, over every r-permutation of the
        concept's n neighbors (r = 1..n), of
        calculate_marginal_contribution(perm, yt) / (r! * (n - r)!).
        That contribution is additive (-sum of yt[:, j] for j in perm) and a
        given neighbor occurs in r * (n - 1)! / (n - r)! of the r-permutations,
        so with k = n - r the whole enumeration equals
            -(sum of yt[:, j] over the neighbors) * sum_{k=0}^{n-1} C(n - 1, k) / k!
        which is evaluated directly instead of enumerating a factorial number
        of permutations. The per-student values are summed over the batch so
        that every concept gets a single scalar.
        """
        concept_num = adj.shape[1]
        shapley_values = torch.zeros(concept_num, device=adj.device)
        neighbor_sets = self.generate_neighbor_sets(adj, reverse_adj)
        concept_totals = yt.sum(dim=0)  # [concept_num], summed over the batch
        weight_by_size = {}  # neighbor count -> permutation weight
        for i in range(concept_num):
            neighbors = sorted(neighbor_sets[i])
            n = len(neighbors)
            if n == 0:
                continue  # no neighbors: the score stays 0
            if n not in weight_by_size:
                # Integer true division stays accurate for large factorials
                weight_by_size[n] = sum(
                    math.factorial(n - 1) / (math.factorial(n - 1 - k) * math.factorial(k) ** 2)
                    for k in range(n)
                )
            neighbor_total = concept_totals[neighbors].sum()
            shapley_values[i] = (-weight_by_size[n] * neighbor_total).to(shapley_values.device)
        return shapley_values

    def calculate_marginal_contribution(self, perm, yt):
        r"""
        Parameters:
            perm: sequence of concept indices
            yt: predicted values of all concepts
        Shape:
            yt: [batch_size, concept_num]
            marginal_contribution: [batch_size]
        Return:
            marginal_contribution: for every student, the summed change of the
                row total of yt when the column of each concept in perm is
                zeroed individually, i.e. -sum of yt[:, j] for j in perm.
                The float 0.0 is returned for an empty perm.
        """
        marginal_contribution = 0.0
        for neighbor in perm:
            # Zeroing one column changes the row total by exactly -yt[:, neighbor]
            marginal_contribution = marginal_contribution - yt[:, neighbor]
        return marginal_contribution

    # Update step, as shown in Section 3.3.2 of the paper
    def _update(self, tmp_ht, ht, qt):
        r"""
        Parameters:
            tmp_ht: temporal hidden representations of all concepts after the aggregate step
            ht: hidden representations of all concepts at the current timestamp
            qt: question indices for all students in a batch at the current timestamp
        Shape:
            tmp_ht: [batch_size, concept_num, hidden_dim + embedding_dim]
            ht: [batch_size, concept_num, hidden_dim]
            qt: [batch_size]
            h_next: [batch_size, concept_num, hidden_dim]
        Return:
            h_next: hidden representations of all concepts at the next timestamp
        """
        qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
        mask_num = qt_mask.nonzero().shape[0]
        # GNN Aggregation
        m_next = self._agg_neighbors(tmp_ht, qt)  # [batch_size, concept_num, hidden_dim]
        # Erase & Add Gate
        m_next[qt_mask] = self.erase_add_gate(m_next[qt_mask])  # [mask_num, concept_num, hidden_dim]
        # GRU
        h_next = m_next  # same tensor: rows of padded students keep their m_next values
        res = self.gru(m_next[qt_mask].reshape(-1, self.hidden_dim), ht[qt_mask].reshape(-1, self.hidden_dim))  # [mask_num * concept_num, hidden_num]
        index_tuple = (torch.arange(mask_num, device=qt_mask.device), )
        h_next[qt_mask] = h_next[qt_mask].index_put(index_tuple, res.reshape(-1, self.concept_num, self.hidden_dim))
        return h_next

    # Predict step, as shown in Section 3.3.3 of the paper
    def _predict(self, h_next, qt):
        r"""
        Parameters:
            h_next: hidden representations of all concepts at the next timestamp after the update step
            qt: question indices for all students in a batch at the current timestamp
        Shape:
            h_next: [batch_size, concept_num, hidden_dim]
            qt: [batch_size]
            y: [batch_size, concept_num]
        Return:
            y: predicted correct probability of all concepts at the next timestamp;
                rows of padded students (qt == -1) hold the raw output of the
                prediction layer without the sigmoid
        """
        qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
        # Run the layer on the whole batch: y must keep batch_size rows so that
        # the batch-sized mask below selects the right rows.
        y = self.predict(h_next).squeeze(dim=-1)  # [batch_size, concept_num]
        y[qt_mask] = torch.sigmoid(y[qt_mask])  # [batch_size, concept_num]
        return y

    def _get_next_pred(self, yt, q_next):
        r"""
        Parameters:
            yt: predicted correct probability of all concepts at the next timestamp
            q_next: question indices at the next timestamp
        Shape:
            yt: [batch_size, concept_num]
            q_next: [batch_size]
            pred: [batch_size, ]
        Return:
            pred: predicted correct probability of the question answered at the next timestamp
        """
        next_qt = q_next
        # Padded entries (-1) are redirected to the all-zero row of one_hot_q,
        # which makes their prediction 0.
        next_qt = torch.where(next_qt != -1, next_qt, self.concept_num * torch.ones_like(next_qt, device=yt.device))
        one_hot_qt = F.embedding(next_qt.long(), self.one_hot_q)  # [batch_size, concept_num]
        # dot product between yt and one_hot_qt
        pred = (yt * one_hot_qt).sum(dim=1)  # [batch_size, ]
        return pred

    def forward(self, features, questions):
        r"""
        Parameters:
            features: input one-hot matrix
            questions: question index matrix
        seq_len dimension needs padding, because different students may have learning sequences with different lengths.
        Shape:
            features: [batch_size, seq_len]
            questions: [batch_size, seq_len]
            pred_res: [batch_size, seq_len - 1]
        Return:
            pred_res: the correct probability of questions answered at the next timestamp
        """
        batch_size, seq_len = features.shape
        ht = torch.zeros((batch_size, self.concept_num, self.hidden_dim), device=features.device)
        pred_list = []
        for i in range(seq_len):
            xt = features[:, i]  # [batch_size]
            qt = questions[:, i]  # [batch_size]
            qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
            tmp_ht = self._aggregate(xt, qt, ht, batch_size)  # [batch_size, concept_num, hidden_dim + embedding_dim]
            h_next = self._update(tmp_ht, ht, qt)  # [batch_size, concept_num, hidden_dim]
            ht[qt_mask] = h_next[qt_mask]  # update new ht
            yt = self._predict(h_next, qt)  # [batch_size, concept_num]
            # The last timestamp has no following question to predict
            if i < seq_len - 1:
                pred_list.append(self._get_next_pred(yt, questions[:, i + 1]))
        if not pred_list:
            # Sequences with fewer than two steps yield no prediction
            return torch.zeros((batch_size, 0), device=features.device)
        pred_res = torch.stack(pred_list, dim=1)  # [batch_size, seq_len - 1]
        return pred_res
calculate_shapley_values函数的返回应该是torch tensor,而不是字典。在Python中,字典的使用会相对较慢,且不能在PyTorch的autograd系统中使用。需要将shapley_values字典转换为tensor并返回。
-
将_shapley_values和_aggregate函数移动到GKT类内部,以保证它们能被类中的其他函数正确调用。
-
在_agg_neighbors函数中,shapley_values是通过字典的方式返回的,但在下文中sorted_indices = torch.argsort(shapley_values, descending=True)这一行中,试图将其视为torch tensor。这将引发一个错误,需要将shapley_values转换为torch tensor。
-
在 calculate_shapley_values 函数中,每个知识点的 Shapley 值是通过枚举其邻居集合的所有排列(itertools.permutations)并累加每个排列的边际贡献来计算的。排列的数量随邻居个数按阶乘增长(比指数级更快),因此会有严重的性能问题,需要寻找一种更高效的方法来计算 Shapley 值。可以考虑的方向:在训练循环之外预先计算 Shapley 值。
-
# Pre-training sanity check: call every step of GKT once and print its output.
import torch

# Example configuration
concept_num = 4
hidden_dim = 10
embedding_dim = 5
edge_type_num = 2
graph_type = 'Dense'
batch_size = 2
seq_len = 5

# _agg_neighbors indexes model.graph, so an adjacency matrix has to be passed.
# Here: uniform weights between distinct concepts, no self loops.
graph = (torch.ones(concept_num, concept_num) - torch.eye(concept_num)) / (concept_num - 1)

# Create the model instance
model = GKT(concept_num, hidden_dim, embedding_dim, edge_type_num, graph_type, graph=graph)

# _aggregate uses the features as embedding indices, so they must be integers
# in [0, res_len * concept_num); random floats cannot be looked up.
# NOTE(review): real features presumably encode (concept, response) pairs that
# agree with `questions`; arbitrary in-range indices only exercise the shapes.
features = torch.randint(0, model.res_len * concept_num, (batch_size, seq_len))  # [batch_size, seq_len]
questions = torch.randint(-1, concept_num, (batch_size, seq_len))  # [batch_size, seq_len], -1 marks padding

# Print the result of _aggregate
xt = features[:, 0]  # features at the first timestamp
qt = questions[:, 0]  # question indices at the first timestamp
ht = torch.zeros((batch_size, concept_num, hidden_dim))  # initial hidden state
tmp_ht = model._aggregate(xt, qt, ht, batch_size=batch_size)
print("_aggregate - tmp_ht:", tmp_ht)

# Print the result of _agg_neighbors
m_next = model._agg_neighbors(tmp_ht, qt)
print("_agg_neighbors - m_next:", m_next)

# Print the result of _update
h_next = model._update(tmp_ht, ht, qt)
print("_update - h_next:", h_next)

# Print the result of _predict
y = model._predict(h_next, qt)
print("_predict - y:", y)

# Print the result of _get_next_pred
q_next = questions[:, 1]  # question indices at the next timestamp
pred = model._get_next_pred(y, q_next)
print("_get_next_pred - pred:", pred)
_aggregate - tmp_ht: tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.6744, -0.2734, 1.4854, -0.7457, 0.0790], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.4134, 0.3520, 1.7249, 0.5040, -1.0225], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.1605, 2.4645, 0.6592, 0.4405, 0.2965], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.3218, -1.0412, 0.5474, 0.4427, 0.1326]], [[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
_aggregate - tmp_ht: tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4616, -0.2127, -0.2274, 0.3314, -0.4540], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1383, 0.4809, -1.5135, 0.6111, 0.1744], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.8623, -0.6276, 0.7693, 0.4770, 1.4637], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.2699, -1.6375, -1.5388, -0.1766, 0.9268]], [[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4616, -0.2127, -0.2274, 0.3314, -0.4540], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1383, 0.4809, -1.5135, 0.6111, 0.1744], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.2699, -1.6375, -1.5388, -0.1766, 0.9268], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -3.0488, -0.8028, 0.2315, 0.4891, 1.4089]]]) tmp_ht shape: torch.Size([2, 4, 13]) qt shape: torch.Size([2]) tmp_ht dtype: torch.float32 qt dtype: torch.int64
_aggregate 函数的计算结果都可以正常打印出来;打印之后发现,问题出在调用 _agg_neighbors 的时候:tmp_ht 和 qt 的形状不匹配。
-
原模型输出是一样的,所以是调用的时候出现问题
_aggregate - tmp_ht: tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.0762, -1.6635, 0.1739, 1.1163, -0.6199], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.1450, -2.3216, 0.8230, 0.5638, -0.5859], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5896, -1.3795, -1.8392, -0.2225, 0.0325], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.4415, 1.6907, -1.4238, -1.1926, -0.3198]], [[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5896, -1.3795, -1.8392, -0.2225, 0.0325], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.1450, -2.3216, 0.8230, 0.5638, -0.5859], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.9347, 1.9458, 1.3726, 0.7222, 1.7240], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.4415, 1.6907, -1.4238, -1.1926, -0.3198]]]) tmp_ht shape: torch.Size([2, 4, 13]) qt shape: torch.Size([2]) tmp_ht dtype: torch.float32 qt dtype: torch.int64 _agg_neighbors - m_next: (tensor([[[-103.3577, -65.4964, 133.9530, -342.6796, -42.5984, -83.8122, -266.8670, -249.4939], [ 138.8846, 11.7192, 67.9875, -78.8893, -197.2776, 50.4039, -124.7583, 61.5432], [ 0.9999, -0.9997, 0.0000, -0.9999, -1.0000, -1.0000, 1.0000, 0.0000], [ -34.7644, 104.6367, -283.5654, 108.0663, 88.5537, -296.1309, 56.2663, 61.5432]], [[ -0.9999, 0.9997, 0.0000, 0.9999, 1.0000, 1.0000, -1.0000, 0.0000], [-107.6568, -212.8599, 67.9881, -117.1709, 44.4904, 50.4043, -197.5228, 61.5438], [ 62.4717, -186.7843, -252.6718, 53.6039, 101.8955, -17.9916, 203.1964, -399.3390], [ 6.9311, -42.5710, 15.1827, 63.6281, 88.5530, -6.2922, 91.5724, 61.5428]]], grad_fn=<AsStridedBackward0>), None, None, None, tensor([[[-1.1346e+02], [ 9.2065e-43], [-1.1346e+02], [ 9.2065e-43]], [[-1.1346e+02], [ 9.2065e-43], [-1.1346e+02], [ 9.2065e-43]]]), tensor([[[-113.4637], [-113.4637], [-113.4637], [-113.4637]], [[-113.4637], [-113.4648], [-113.4637], [-113.4629]]]))
新的调试内容: # 模型初始化 concept_num = 10 hidden_dim = 32 embedding_dim = 64 edge_type_num = 2 graph_type = "Dense" graph = np.zeros((concept_num, concept_num)) graph_model = None dropout = 0.5 bias = True binary = False has_cuda = False model = GKT(concept_num, hidden_dim, embedding_dim, edge_type_num, graph_type, graph, graph_model, dropout, bias, binary, has_cuda) # 模拟输入数据 batch_size = 5 seq_len = 20 features = torch.randint(0, 2, (batch_size, seq_len)) # 二值化输入 questions = torch.randint(-1, concept_num, (batch_size, seq_len)) # 随机问题 # 打印中间变量的形状和值 # 打印 _aggregate 函数中的变量 print("In _aggregate function:") tmp_ht = model._aggregate(features[:, 0], questions[:, 0], torch.zeros((batch_size, model.concept_num, model.hidden_dim)), batch_size) print("tmp_ht shape: ", tmp_ht.shape) print("tmp_ht value: ", tmp_ht) # 打印 _agg_neighbors 函数中的变量 print("In _agg_neighbors function:") m_next = model._agg_neighbors(tmp_ht, questions[:, 0]) for i, item in enumerate(m_next): print(f"m_next[{i}] value: ", item) # 打印 _update 函数中的变量 print("In _update function:") h_next = model._update(tmp_ht, torch.zeros((batch_size, model.concept_num, model.hidden_dim)), questions[:, 0]) print("h_next shape: ", h_next.shape) print("h_next value: ", h_next) # 打印 _predict 函数中的变量 print("In _predict function:") yt = model._predict(h_next, questions[:, 0]) print("yt shape: ", yt.shape) print("yt value: ", yt) # 打印 _get_next_pred 函数中的变量 print("In _get_next_pred function:") pred = model._get_next_pred(yt, questions[:, 1]) print("pred shape: ", pred.shape) print("pred value: ", pred) 结果输出: In _aggregate function: tmp_ht shape: torch.Size([5, 10, 96]) tmp_ht value: tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0941, -0.4533, -1.0247], [ 0.0000, 0.0000, 0.0000, ..., -0.4808, 1.2499, 0.3105], [ 0.0000, 0.0000, 0.0000, ..., 0.7841, 0.4550, 0.3469], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.5375, 0.2919, 1.0231], [ 0.0000, 0.0000, 0.0000, ..., -0.2228, 1.8833, 1.1549], [ 0.0000, 0.0000, 0.0000, ..., -0.8485, 0.6319, 0.5914]], [[ 
0.0000, 0.0000, 0.0000, ..., 0.0941, -0.4533, -1.0247], [ 0.0000, 0.0000, 0.0000, ..., -2.6132, -1.0803, 0.3215], [ 0.0000, 0.0000, 0.0000, ..., 1.3176, -0.8746, -0.7052], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.5375, 0.2919, 1.0231], [ 0.0000, 0.0000, 0.0000, ..., -0.2228, 1.8833, 1.1549], [ 0.0000, 0.0000, 0.0000, ..., -0.8485, 0.6319, 0.5914]], [[ 0.0000, 0.0000, 0.0000, ..., 0.0941, -0.4533, -1.0247], [ 0.0000, 0.0000, 0.0000, ..., -2.6132, -1.0803, 0.3215], [ 0.0000, 0.0000, 0.0000, ..., 0.7841, 0.4550, 0.3469], ..., [ 0.0000, 0.0000, 0.0000, ..., 1.3176, -0.8746, -0.7052], [ 0.0000, 0.0000, 0.0000, ..., -0.2228, 1.8833, 1.1549], [ 0.0000, 0.0000, 0.0000, ..., -0.8485, 0.6319, 0.5914]], [[ 0.0000, 0.0000, 0.0000, ..., 0.0941, -0.4533, -1.0247], [ 0.0000, 0.0000, 0.0000, ..., -2.6132, -1.0803, 0.3215], [ 0.0000, 0.0000, 0.0000, ..., 0.7841, 0.4550, 0.3469], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.5375, 0.2919, 1.0231], [ 0.0000, 0.0000, 0.0000, ..., -0.2228, 1.8833, 1.1549], [ 0.0000, 0.0000, 0.0000, ..., -0.8485, 0.6319, 0.5914]], [[ 0.0000, 0.0000, 0.0000, ..., 0.0941, -0.4533, -1.0247], [ 0.0000, 0.0000, 0.0000, ..., -2.6132, -1.0803, 0.3215], [ 0.0000, 0.0000, 0.0000, ..., 0.7841, 0.4550, 0.3469], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.5375, 0.2919, 1.0231], [ 0.0000, 0.0000, 0.0000, ..., -0.4808, 1.2499, 0.3105], [ 0.0000, 0.0000, 0.0000, ..., -0.8485, 0.6319, 0.5914]]], grad_fn=<CatBackward0>) In _agg_neighbors function: m_next[0] value: tensor([[[ 1.2324e+32, 2.2032e+31, -3.2379e+31, ..., 7.7117e+30, 4.6696e+31, 2.1787e+30], [ 5.7271e-01, 2.7850e-01, 0.0000e+00, ..., -7.0925e-01, 1.0276e-01, -4.9551e-01], [ 9.5813e+30, -2.6708e+30, -1.8207e+30, ..., -2.0225e+30, -3.1715e+30, -4.6428e+30], ..., [-6.0195e-12, 2.1256e-11, 7.5524e-12, ..., -4.9179e-12, -5.0023e-12, -4.3686e-12], [ 1.2946e+29, 1.5078e+30, 4.0381e+29, ..., 6.7391e+29, -7.6972e+29, -5.7277e+29], [-1.3544e+31, -1.1288e+31, -8.5551e+30, ..., 6.5880e+30, 1.5617e+31, -1.7933e+31]], [[-2.7344e+35, 
-4.0952e+35, -8.8878e+34, ..., -1.7149e+35, -2.4635e+35, -1.5102e+35], [ 1.6327e+31, -4.4965e+30, 2.6258e+30, ..., -2.6145e+30, 1.1917e+31, 7.0686e+30], [ 3.9246e-01, 1.3375e+00, 0.0000e+00, ..., -7.0925e-01, -1.5798e+00, -4.9551e-01], ..., [-5.3433e+37, -4.4531e+37, -3.3750e+37, ..., -2.3277e+37, -4.3370e+37, 1.0412e+38], [ 4.0599e+35, -3.8431e+35, -4.4264e+35, ..., -1.6095e+35, 1.4819e+35, -1.4174e+35], [-9.9087e+30, -8.2615e+30, -6.2618e+30, ..., -1.3278e+31, -8.6847e+30, -1.0046e+31]], [[-1.3277e+28, -1.1065e+28, -8.3865e+27, ..., -2.1175e+28, -1.1633e+28, -1.7579e+28], [-6.5295e-12, -5.3621e-12, 1.8268e-11, ..., 3.0528e-12, -5.7491e-12, 2.1346e-12], [ 4.8472e+37, -9.6165e+37, -6.2913e+37, ..., 1.1887e+38, -5.7849e+37, 3.2657e+38], ..., [-8.5953e-01, -1.6182e+00, 0.0000e+00, ..., 1.9529e+00, 1.4658e+00, -4.9551e-01], [ 1.6460e+31, 2.4490e+31, -5.6706e+30, ..., -7.9908e+30, -8.7314e+30, -1.3150e+31], [ 3.0621e+23, -5.1679e+24, -1.3154e+24, ..., -2.5616e+24, -2.5818e+24, -2.1776e+24]], [[ 1.8190e+28, 2.3466e+28, -8.6000e+27, ..., -2.1715e+28, -7.2281e+27, 1.5122e+28], [-4.7508e+31, -1.6147e+31, -2.8254e+31, ..., -1.2579e+31, -9.1815e+28, 5.5071e+31], [-1.5965e+28, -1.6241e+27, 2.2543e+28, ..., -1.0016e+28, -1.4384e+28, -8.8153e+27], ..., [-2.3707e+22, -9.3314e+21, -9.2592e+22, ..., -3.3664e+22, -4.8358e+22, 9.8055e+22], [-1.8325e+32, 8.0848e+32, 9.6632e+32, ..., -6.5985e+30, 2.7459e+31, 6.1748e+30], [ 4.1620e+31, -1.3489e+31, -4.9357e+30, ..., -1.0167e+31, 4.8233e+30, -8.0429e+30]], [[-1.6744e-19, -1.0456e-20, 2.2607e-20, ..., -1.0501e-19, -1.5085e-19, -9.2480e-20], [-8.6299e+29, -7.1921e+29, -5.4509e+29, ..., 1.2944e+30, -7.5612e+29, 7.3287e+29], [-2.1354e+35, -1.7795e+35, -1.3487e+35, ..., -3.4056e+35, -1.8710e+35, -2.8273e+35], ..., [ 2.1699e+30, -1.9961e+30, 1.2040e+31, ..., -7.4784e+30, 4.2361e+30, -6.6373e+30], [-1.4264e+00, -5.0625e-01, 0.0000e+00, ..., -4.0362e-01, 4.2081e-01, -4.9551e-01], [ 2.1760e+12, -2.3845e+12, 3.4466e+12, ..., -8.9708e+11, 
2.3981e+12, -3.1746e+11]]], grad_fn=<AsStridedBackward0>) m_next[1] value: None m_next[2] value: None m_next[3] value: None m_next[4] value: tensor([[[7.1061e+31], [4.2964e+24], [4.7824e+30], [9.7374e+15], [4.7824e+30], [6.1126e-02], [4.7824e+30], [3.7386e-14], [1.6109e-19], [1.8888e+31]], [[6.6660e-33], [2.1470e+29], [6.7111e+22], [4.5450e+30], [1.8524e+28], [1.9519e-19], [7.2708e+31], [7.4513e+37], [1.8888e+31], [1.3821e+31]], [[1.8515e+28], [9.1041e-12], [6.2609e+22], [4.7428e+30], [6.2288e+22], [2.1974e+23], [8.6758e-04], [4.7429e+30], [1.3818e+31], [6.7724e+22]], [[1.8987e+28], [7.1061e+31], [4.2964e+24], [3.0607e+32], [6.6134e+19], [1.0724e+31], [4.4657e+30], [1.5648e+01], [3.0607e+32], [1.6533e+19]], [[1.0804e-32], [1.2034e+30], [2.9777e+35], [1.6408e+07], [3.0607e+32], [7.3904e+22], [1.7239e+25], [6.3828e+28], [1.4348e-19], [2.7530e+12]]]) m_next[5] value: tensor([[[1.0486e+31], [4.2964e+24], [2.1470e+29], [1.0644e+31], [7.1061e+31], [6.2609e+22], [4.7428e+30], [9.1041e-12], [1.2034e+30], [5.6542e+05]], [[3.1731e+35], [4.7824e+30], [6.7111e+22], [4.4657e+30], [4.2964e+24], [4.7428e+30], [3.0607e+32], [6.2609e+22], [2.9777e+35], [1.1614e+27]], [[6.6660e-33], [3.7386e-14], [7.4513e+37], [1.8888e+31], [1.5648e+01], [4.9651e+28], [3.0607e+32], [4.7429e+30], [6.3828e+28], [4.6165e+24]], [[3.7404e-14], [4.7824e+30], [1.8524e+28], [2.8298e+20], [6.6134e+19], [2.7259e+20], [1.0724e+31], [6.2288e+22], [3.0607e+32], [1.9205e+31]], [[1.9431e-19], [1.6109e-19], [1.8888e+31], [6.6660e-33], [3.0607e+32], [3.6002e+27], [7.3904e+22], [1.3818e+31], [1.4348e-19], [4.3701e+12]]]) In _update function: # 说明在update函数这里出现了问题 Traceback (most recent call last): File "D:\N\Code\anaconda3\envs\pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 3460, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "<ipython-input-2-cc829fe010f7>", line 1, in <module> runfile('D:/N/Code/Repository/GKT/GKT-master/Prin.py', wdir='D:/N/Code/Repository/GKT/GKT-master') 
File "D:\N\Code\JetBrains\PyCharm 2022.2.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile pydev_imports.execfile(filename, global_vars, local_vars) # execute the script File "D:\N\Code\JetBrains\PyCharm 2022.2.1\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile exec(compile(contents+"\n", file, 'exec'), glob, loc) File "D:/N/Code/Repository/GKT/GKT-master/Prin.py", line 54, in <module> h_next = model._update(tmp_ht, torch.zeros((batch_size, model.concept_num, model.hidden_dim)), questions[:, 0]) File "D:\N\Code\Repository\GKT\GKT-master\models.py", line 193, in _update m_next, concept_embedding, rec_embedding, z_prob = self._agg_neighbors(tmp_ht, qt) # [batch_size, concept_num, hidden_dim] ValueError: too many values to unpack (expected 4)
原模型的正常输出: In _aggregate function: tmp_ht shape: torch.Size([5, 10, 96]) tmp_ht value: tensor([[[ 0.0000, 0.0000, 0.0000, ..., -1.3486, 0.8760, 0.2503], [ 0.0000, 0.0000, 0.0000, ..., -0.5594, -0.1822, -0.8547], [ 0.0000, 0.0000, 0.0000, ..., 0.3173, 1.0626, -1.6663], ..., [ 0.0000, 0.0000, 0.0000, ..., -0.9511, -0.0559, 1.0936], [ 0.0000, 0.0000, 0.0000, ..., 0.0359, -0.5671, 1.5460], [ 0.0000, 0.0000, 0.0000, ..., -0.8684, -2.4842, -0.1697]], [[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[ 0.0000, 0.0000, 0.0000, ..., -0.2469, -1.3631, -1.2307], [ 0.0000, 0.0000, 0.0000, ..., -0.5594, -0.1822, -0.8547], [ 0.0000, 0.0000, 0.0000, ..., 0.3173, 1.0626, -1.6663], ..., [ 0.0000, 0.0000, 0.0000, ..., -0.9511, -0.0559, 1.0936], [ 0.0000, 0.0000, 0.0000, ..., 0.0359, -0.5671, 1.5460], [ 0.0000, 0.0000, 0.0000, ..., -0.8684, -2.4842, -0.1697]], [[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[ 0.0000, 0.0000, 0.0000, ..., -0.2469, -1.3631, -1.2307], [ 0.0000, 0.0000, 0.0000, ..., 0.0381, -0.4959, 0.2106], [ 0.0000, 0.0000, 0.0000, ..., 0.3173, 1.0626, -1.6663], ..., [ 0.0000, 0.0000, 0.0000, ..., -0.9511, -0.0559, 1.0936], [ 0.0000, 0.0000, 0.0000, ..., 0.0359, -0.5671, 1.5460], [ 0.0000, 0.0000, 0.0000, ..., -0.8684, -2.4842, -0.1697]]], grad_fn=<CatBackward0>) In _agg_neighbors function: m_next[0] value: tensor([[-7.0708e-01, -7.0699e-01, 0.0000e+00, 1.4142e+00, 1.4129e+00, -1.9984e-01, 
1.4142e+00, -7.0710e-01, 1.4142e+00, 6.2329e-01, -9.9698e-01, -7.0703e-01, 1.1503e+00, -1.3507e+00, -1.0090e+00, -6.8545e-01, 0.0000e+00, -7.0710e-01, 6.7942e-01, 0.0000e+00, -7.0706e-01, -9.9652e-01, -8.1369e-01, -1.4119e+00, -7.0707e-01, 1.4030e+00, 1.4140e+00, -7.0710e-01, -7.6625e-01, 1.3715e+00, 0.0000e+00, 1.4132e+00], [-9.3377e+30, -4.4438e+30, 4.0269e+31, 5.2966e+31, 4.7307e+31, -1.7204e+31, 6.4155e+31, -2.0324e+31, 3.4213e+31, 9.8819e+30, -1.2913e+31, 1.0009e+30, 1.9279e+31, -3.6524e+30, 3.0597e+31, -1.1375e+31, -1.3153e+31, -1.1106e+31, 4.4666e+31, -1.4089e+31, 2.0475e+31, -6.9232e+30, -8.8107e+30, 7.2887e+29, -6.8600e+30, -7.0865e+30, -1.5513e+31, -7.7240e+30, -1.1357e+31, -1.9777e+31, -2.3097e+31, -1.2965e+31], [ 6.6746e+19, -1.7142e+20, -1.3694e+20, 4.0207e+20, -8.6581e+19, -1.6775e+20, 3.6503e+20, -1.9816e+20, -7.6050e+19, -3.6486e+19, 1.1670e+20, 1.7826e+20, -1.5121e+20, -3.5612e+19, 2.8137e+20, -2.2191e+20, -1.2824e+20, 3.1863e+20, 2.2312e+20, 9.4794e+19, 4.0157e+19, 6.3880e+19, 4.0683e+19, -1.2084e+20, -1.2029e+20, -6.9094e+19, 1.0492e+19, -7.5310e+19, -5.4612e+18, -1.8008e+20, -1.6469e+20, -1.2642e+20], [-7.3726e+22, -8.1940e+22, 1.3152e+23, -6.4019e+22, 1.5749e+23, -2.7702e+22, -3.2767e+22, 2.7229e+22, 1.1503e+23, -1.6409e+22, 2.9096e+22, 1.3772e+23, 1.7966e+22, -1.3552e+22, -5.2015e+22, -1.7332e+22, -4.8803e+22, -4.1206e+22, 1.4981e+23, 7.3645e+22, -7.3938e+22, -2.5688e+22, -3.2691e+22, 2.3442e+23, -4.5774e+22, -2.6294e+22, -5.7559e+22, -2.8659e+22, -4.2138e+22, 4.5219e+21, -5.0129e+22, 8.8478e+21], [-8.7262e+22, 1.0173e+22, 3.7276e+22, 2.9179e+23, -5.0749e+22, 3.8675e+22, 5.4565e+22, -4.8569e+22, -2.1210e+22, -4.0344e+22, 6.0943e+22, -3.7366e+22, 1.0264e+23, -5.1681e+22, -5.3926e+22, -5.2226e+22, 1.1072e+23, -4.4829e+22, -3.5410e+22, 1.1655e+21, -2.8266e+22, -1.0642e+23, -5.6012e+22, -3.2218e+22, -6.5300e+22, -2.0264e+22, 3.0176e+22, -1.3118e+22, 2.3413e+23, -6.9852e+22, -3.1340e+22, -4.0416e+22], [ 6.8033e+30, -1.6109e+31, -2.3774e+30, 
-1.0658e+31, -1.7291e+31, -9.6644e+30, 3.5310e+31, -1.2201e+31, -2.0484e+31, 2.6257e+31, 2.2330e+31, -9.3779e+30, -1.4104e+31, -1.2991e+31, 2.8924e+31, 1.8616e+31, -1.8885e+31, 2.9299e+31, -8.8921e+30, -1.0495e+31, -7.1990e+30, -9.6644e+30, 8.4747e+30, -8.0738e+30, 1.2322e+31, -5.0799e+30, -1.9082e+31, -3.2807e+30, -7.3867e+30, 3.3135e+31, -7.8981e+30, 5.3422e+30], [ 8.1123e+30, -6.2705e+31, 1.1285e+31, -4.3019e+31, 5.3195e+31, -7.1557e+31, -1.1452e+31, -4.9246e+31, -8.2678e+31, -4.0961e+31, 1.0570e+28, -3.7851e+31, 8.4043e+31, -5.2434e+31, 5.9398e+31, -5.3062e+31, -7.6225e+31, 1.8598e+31, -3.5890e+31, -4.2359e+31, -2.9057e+31, -5.8872e+29, -5.6780e+31, -3.2588e+31, -6.6175e+31, -2.0504e+31, -2.1027e+31, -1.3242e+31, -2.9815e+31, 5.7247e+31, -3.1878e+31, -4.0905e+31], [ 1.4829e-19, 1.0967e-19, 7.7588e-19, 7.9435e-20, -2.2232e-21, -9.9126e-20, 4.5317e-19, 9.7788e-19, 1.6251e-19, 3.0213e-19, 7.8386e-20, -9.8366e-20, 4.6539e-19, -1.8116e-19, 6.1422e-20, 6.9460e-19, -1.0671e-19, -2.1518e-19, -1.2380e-19, -2.2600e-19, -2.2842e-19, -3.6877e-19, -1.7340e-19, -1.8498e-19, 1.8190e-19, -1.1130e-19, 5.2681e-21, -9.4388e-20, 7.1995e-19, 6.4983e-20, -5.4982e-20, 1.8741e-21], [-3.7241e+31, -3.3337e+32, -2.3286e+32, -2.8151e+32, 1.0224e+33, 6.4429e+31, -1.3606e+32, -5.9563e+31, 5.3538e+31, 1.4440e+32, -1.9856e+32, -1.4239e+32, -2.4505e+32, -6.3223e+31, -2.3303e+32, -2.3442e+30, -2.1750e+32, -1.8454e+32, -1.9125e+32, 5.1497e+32, -1.0966e+32, -1.2042e+32, -1.4346e+32, -2.0468e+32, -1.4783e+32, -1.8087e+31, -1.2408e+32, 4.9439e+32, -1.8298e+32, -3.4235e+31, -1.4410e+32, -2.1459e+32], [-4.8799e+34, -7.5724e+35, -5.8648e+35, -5.0059e+35, 5.1855e+34, 5.8840e+35, -2.6729e+35, -5.7328e+35, -5.1632e+32, -4.7714e+35, 4.9774e+35, 1.4959e+36, -9.3555e+35, 1.1309e+36, 1.0021e+36, -6.1821e+35, 8.0353e+35, 1.6282e+35, -4.1733e+35, 1.0937e+36, -3.3849e+35, 8.5451e+35, 1.5014e+36, -3.7940e+35, -7.7074e+35, -2.3773e+35, -1.3762e+36, -1.5430e+35, -3.4728e+35, 1.4207e+36, -3.7155e+35, -4.7650e+35]], 
grad_fn=<UnbindBackward0>) m_next[1] value: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<UnbindBackward0>) m_next[2] value: tensor([[ 3.2706e+22, 6.1654e+22, -1.0289e+22, 2.6162e+20, 9.3671e+22, -9.1952e+21, 4.4778e+22, 9.0740e+22, 3.7636e+22, -4.0345e+22, -6.6226e+22, -3.7282e+22, 2.4758e+22, -5.1646e+22, -3.9467e+21, -3.0683e+22, -3.1847e+22, -4.5224e+22, -3.5351e+22, -4.1722e+22, -2.8620e+22, 4.5490e+22, -2.5942e+22, 2.8234e+22, 1.0630e+23, -2.0195e+22, 3.7853e+22, -1.3042e+22, -2.9366e+22, 6.0719e+22, -3.1399e+22, -4.0290e+22], [-4.5834e+36, -2.6728e+36, -1.0301e+37, -1.2655e+37, 8.3586e+36, 1.1143e+37, -6.4773e+36, -1.4907e+37, -9.6596e+36, -8.7622e+36, -9.4710e+36, -6.2777e+36, -1.1375e+37, -2.6789e+36, 
3.5390e+37, -1.6662e+37, 8.2350e+36, -8.1455e+36, -8.4770e+36, 2.9643e+36, 1.1818e+37, -5.0779e+36, -6.4623e+36, 1.6043e+37, -3.4187e+36, -4.9023e+36, 1.6553e+37, -5.4145e+36, 7.5373e+36, -1.3946e+37, -5.9986e+35, 9.9735e+36], [-5.9961e+29, 6.3656e+30, -2.5111e+30, -6.2818e+29, -2.2669e+30, -1.6721e+30, 8.9544e+30, -3.6975e+30, -3.3621e+30, -1.9707e+30, 1.0745e+31, -1.1151e+30, -3.8168e+30, -9.3241e+29, 9.3388e+30, -1.5307e+30, 2.3235e+30, -2.8351e+30, -2.9505e+30, -3.3862e+30, -5.0872e+30, 1.3736e+30, 1.0323e+31, -3.1638e+30, -9.5350e+29, -1.8091e+30, -3.4885e+30, -1.9718e+30, 1.2260e+31, -5.0489e+30, -3.4989e+30, 4.1591e+29], [-7.0708e-01, -7.0699e-01, 0.0000e+00, -7.0710e-01, -7.0646e-01, 1.3124e+00, -7.0709e-01, 1.4142e+00, -7.0708e-01, -1.4110e+00, 1.3671e+00, -7.0703e-01, 1.3721e-01, 1.0381e+00, 1.3627e+00, -7.2843e-01, 0.0000e+00, 1.4142e+00, -1.4138e+00, 0.0000e+00, -7.0706e-01, 1.3673e+00, -5.9481e-01, 6.3807e-01, -7.0707e-01, -5.4807e-01, -7.0702e-01, 1.4142e+00, 1.4125e+00, -9.8435e-01, 0.0000e+00, -7.0660e-01], [-4.9704e-20, 1.9283e-19, -1.7453e-19, 8.8613e-20, -7.9543e-20, -2.2866e-19, -1.3696e-19, 9.6960e-21, -2.1970e-20, -8.3595e-20, -7.8045e-20, -1.3407e-19, -2.1583e-19, -1.1921e-19, -2.2210e-19, -2.1614e-20, 4.8845e-21, -1.1502e-19, -8.6453e-20, -1.8227e-19, -1.8023e-19, -1.7000e-19, 1.5182e-19, -1.5134e-19, -5.4573e-20, -9.0137e-20, -1.6925e-19, -8.0708e-20, -1.3859e-19, -8.9996e-20, 7.3867e-20, 9.5617e-20], [ 4.4468e+26, -4.9025e+26, -2.3940e+27, 4.2447e+27, -2.0231e+27, 7.5912e+27, -2.0120e+27, 3.2536e+27, -3.0005e+27, -2.7218e+27, -2.9420e+27, -1.5690e+27, 1.0161e+27, -8.3214e+26, -3.1939e+27, 4.0000e+27, -1.9315e+26, -2.5302e+27, -2.6332e+27, -3.2098e+27, -1.7471e+27, -1.5773e+27, -2.0074e+27, -2.8236e+27, 1.1694e+28, -1.6145e+27, -1.2657e+27, -1.7598e+27, -2.5874e+27, 3.8709e+27, 5.5685e+27, -2.9539e+27], [-1.9350e+31, -9.7497e+30, -5.1880e+30, 1.4319e+30, 2.6939e+31, -8.2378e+30, -8.6002e+30, 6.0057e+30, -1.2825e+31, -1.1069e+31, 2.9593e+30, 
-8.3351e+30, -1.5103e+31, -3.5569e+30, 2.2506e+31, -4.6303e+30, 2.5249e+31, -2.4485e+30, -1.1255e+31, -2.9802e+30, -7.2700e+30, -6.7421e+30, -8.5802e+30, -1.0575e+31, 2.9042e+31, -6.9011e+30, 5.8356e+30, 1.3281e+31, 3.3284e+30, -1.9260e+31, 4.2243e+30, 6.4633e+30], [ 2.2739e+28, 1.2526e+28, -1.3374e+28, -1.6430e+28, -8.4561e+27, 5.1880e+27, -8.4096e+27, -1.3175e+28, -4.6555e+27, 1.0422e+28, -1.2296e+28, -8.1504e+27, 1.3708e+28, -3.4781e+27, -1.9735e+27, 1.1638e+28, 4.6767e+28, -1.0575e+28, -1.1006e+28, -1.3416e+28, 3.8155e+27, -6.5927e+27, -8.3901e+27, -5.7155e+27, 1.8367e+28, 6.5183e+28, -1.1263e+28, -7.3553e+27, 3.5261e+27, -5.0872e+27, 6.7944e+27, -1.2347e+28], [-1.8750e+19, -2.0840e+19, -1.3253e+19, 8.0226e+18, -8.3797e+18, -4.4038e+18, -8.3336e+18, 1.3502e+18, -1.2428e+19, -9.1538e+18, -1.2185e+19, -8.0768e+18, -6.0951e+18, 1.4082e+17, -1.3229e+19, -1.6047e+19, -7.9953e+17, -1.0480e+19, -1.0906e+19, -1.3295e+19, 1.6455e+19, -6.5332e+18, -8.3143e+18, -9.4446e+18, 2.4920e+18, -6.6872e+18, -1.1521e+19, -7.2887e+18, -1.0717e+19, -1.7708e+19, -1.1502e+19, -1.2235e+19], [ 6.7204e+22, 1.1744e+23, 4.1223e+22, -4.2369e+22, 9.4173e+22, -1.9247e+22, 4.0880e+22, -2.8197e+22, 8.4356e+22, -4.0342e+22, -4.7019e+22, -3.7279e+22, 2.9378e+22, 9.9465e+22, 1.5250e+22, -5.2259e+22, 5.1191e+22, -4.5220e+22, -3.5348e+22, -4.1718e+22, -2.8617e+22, 1.2036e+23, 1.8225e+22, -3.2095e+22, 1.2533e+23, -2.0194e+22, 1.1872e+23, -1.3041e+22, -2.9364e+22, 1.4555e+23, -3.1396e+22, 7.3005e+22]], grad_fn=<UnbindBackward0>) m_next[3] value: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<UnbindBackward0>) m_next[4] value: tensor([[ 2.2705e+31, 1.4660e+31, -1.2587e+31, -9.2784e+30, 1.5091e+31, 5.1478e+30, -1.5791e+31, -8.0918e+30, 2.5169e+31, -1.0873e+31, -1.1210e+31, -4.5892e+30, 3.8851e+30, -2.5928e+30, -1.6804e+31, -6.3257e+30, 1.0203e+31, -1.2188e+31, -9.5274e+30, -1.1244e+31, -7.7133e+30, 2.1364e+31, 2.1834e+30, 2.7694e+31, -1.7567e+31, -5.4429e+30, 1.4351e+31, -3.5151e+30, -7.9145e+30, 4.2026e+30, -8.4624e+30, -1.0859e+31], [ 1.4142e+00, 1.4140e+00, 0.0000e+00, -7.0710e-01, -7.0646e-01, -1.1125e+00, -7.0709e-01, -7.0710e-01, -7.0708e-01, 7.8772e-01, -3.7009e-01, 1.4141e+00, -1.2875e+00, 3.1256e-01, -3.5368e-01, 1.4139e+00, 0.0000e+00, -7.0710e-01, 7.3442e-01, 0.0000e+00, 1.4141e+00, -3.7075e-01, 1.4085e+00, 7.7387e-01, 1.4141e+00, -8.5493e-01, -7.0702e-01, -7.0710e-01, -6.4625e-01, -3.8719e-01, 0.0000e+00, -7.0660e-01], [-8.7480e+22, -6.4217e+22, 4.4588e+22, -4.2440e+22, -6.4291e+22, -8.7220e+21, -5.6982e+22, -4.8785e+22, -8.1393e+22, -4.0508e+22, -5.8175e+22, -8.1076e+21, -2.7980e+22, 4.6328e+21, -6.2220e+22, -5.2296e+22, -7.5255e+22, -4.4817e+22, 1.2510e+23, 8.5972e+22, 3.0243e+22, -4.6261e+22, 
-5.5230e+22, -3.2242e+22, 7.4183e+22, -2.0292e+22, 6.4002e+22, -1.3149e+22, -2.8942e+22, -4.9838e+22, -3.1719e+22, -4.0467e+22], [-1.7277e+37, -1.2660e+37, 1.8427e+36, -7.3565e+36, 1.9082e+36, -8.2479e+36, -6.5535e+36, 1.2388e+36, -1.6098e+37, -7.9752e+36, -6.8664e+36, -7.3698e+36, -1.5649e+37, -3.7562e+36, -1.2325e+37, -1.0331e+37, -7.9551e+36, -8.9396e+36, -6.9880e+36, 1.2858e+37, -5.6574e+36, -1.3873e+37, -3.6395e+35, -4.3316e+36, 1.2969e+37, -3.9921e+36, -3.9011e+36, -2.5782e+36, -5.1007e+36, 1.0325e+37, -6.2068e+36, 4.6018e+36], [-2.7712e+20, -3.0800e+20, -1.9526e+20, 1.4188e+20, -1.2385e+20, -2.3994e+20, -1.2317e+20, 3.3179e+19, -1.8368e+20, -1.2251e+19, 9.6322e+19, -9.1869e+19, -6.4163e+19, -5.0939e+19, -1.6825e+20, -6.8739e+19, -1.8344e+20, -1.5489e+20, -1.6119e+20, 3.9078e+20, -2.7792e+20, -9.6556e+19, -1.2288e+20, -7.7369e+19, 1.9579e+20, 8.1454e+19, -7.2322e+19, 3.3365e+20, -2.7567e+19, -1.5337e+20, -5.7041e+19, 3.5530e+20], [ 7.5070e-20, 2.1309e-20, 5.7505e-20, -1.3723e-20, -2.3371e-19, 2.0968e-19, -1.3830e-19, -1.4768e-19, -2.4014e-19, 1.4322e-19, -2.4404e-19, -1.3306e-19, 1.8781e-21, -5.3809e-20, -2.2042e-19, 9.1289e-20, 1.1045e-19, 1.5613e-20, 3.1851e-19, -1.0209e-19, 1.9921e-19, 3.0032e-20, -1.2337e-19, 4.1138e-19, 9.0199e-20, 8.9611e-20, 2.1762e-19, -8.0373e-20, -1.2087e-19, -8.6872e-20, 2.1440e-19, -2.7865e-20], [ 1.3431e+27, 6.1703e+27, -3.1962e+27, 2.6143e+26, -2.1683e+27, -1.8776e+27, -2.1564e+27, -9.0900e+26, -3.2158e+27, -2.4771e+26, 7.8160e+27, 7.5658e+27, -3.7869e+27, -8.9185e+26, -3.4231e+27, -5.5573e+27, -3.2117e+27, -2.7118e+27, -2.8221e+27, -2.1575e+27, 1.7284e+26, -1.6905e+27, 1.2424e+28, -3.0262e+27, -2.7738e+27, -1.7304e+27, -3.7879e+27, -1.8860e+27, -2.7731e+27, 4.3988e+27, -1.3221e+26, -3.1659e+27], [ 4.3761e+28, 1.2717e+29, -4.1877e+28, -2.9479e+28, -3.5374e+28, -6.8535e+28, -3.5180e+28, 9.7419e+28, -5.2463e+28, 2.4630e+29, -5.1439e+28, -3.4095e+28, -4.8961e+28, -1.4550e+28, -5.5844e+28, 1.3153e+29, 2.2546e+29, -4.4240e+28, 
-4.6040e+28, -5.6123e+28, -7.9381e+28, -2.7579e+28, -3.5098e+28, -4.9369e+28, -4.9144e+28, 2.0273e+28, -6.1796e+28, -3.0769e+28, 3.9149e+28, 5.4236e+28, 1.4430e+29, -5.1649e+28], [ 3.2065e+25, 2.0080e+25, 6.7749e+24, -1.0482e+25, 1.4335e+25, 2.7997e+24, -7.9242e+24, 9.4052e+24, 1.7527e+25, 4.9195e+25, -3.4844e+24, -9.2233e+24, 3.5161e+25, -1.2777e+25, 2.5923e+25, 1.9148e+25, -9.3156e+24, 4.4959e+25, 1.2142e+25, -1.0321e+25, -6.3326e+24, 1.6733e+25, -1.3836e+25, -7.9407e+24, -1.3955e+25, -4.9962e+24, 4.0984e+24, -3.2266e+24, -7.2649e+24, -1.7253e+25, 2.3038e+25, -9.9671e+24], [-7.6128e+30, 5.4335e+31, 6.2959e+31, -8.0251e+30, 2.4663e+31, -1.5642e+31, -3.7694e+30, 2.1301e+31, 2.1301e+31, -2.9152e+31, -3.9283e+31, -2.6939e+31, -9.8384e+30, 3.8949e+31, 5.1939e+31, -2.0039e+31, 5.2509e+31, 1.6385e+31, -2.5543e+31, -3.0147e+31, -2.0680e+31, 8.6417e+31, -2.2322e+30, -2.3193e+31, 7.7564e+31, -1.4592e+31, 2.1713e+31, -9.4241e+30, -1.4558e+30, 6.8830e+31, -2.2688e+31, 2.3506e+31]], grad_fn=<UnbindBackward0>) In _update function: h_next shape: torch.Size([5, 10, 32]) h_next value: tensor([[[ 0.1248, 0.1779, 0.1606, ..., -0.0967, 0.0597, 0.0568], [ 0.0207, -0.0659, 0.0029, ..., 0.0385, -0.0248, -0.0389], [ 0.0420, -0.0477, 0.0094, ..., 0.0488, -0.0487, -0.0332], ..., [ 0.0400, -0.0495, 0.0088, ..., 0.0478, -0.0464, -0.0337], [ 0.0231, -0.0639, 0.0036, ..., 0.0396, -0.0274, -0.0382], [ 0.0341, -0.0546, 0.0070, ..., 0.0449, -0.0398, -0.0353]], [[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[ 0.0258, -0.0616, 0.0045, ..., 0.0409, -0.0305, -0.0375], [ 0.0207, -0.0659, 0.0029, ..., 0.0385, -0.0248, -0.0389], [ 0.0420, -0.0477, 0.0094, ..., 0.0488, -0.0487, -0.0332], ..., [ 0.0400, -0.0495, 
0.0088, ..., 0.0478, -0.0464, -0.0337], [ 0.0231, -0.0639, 0.0036, ..., 0.0396, -0.0274, -0.0382], [ 0.0341, -0.0546, 0.0070, ..., 0.0449, -0.0398, -0.0353]], [[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[ 0.0258, -0.0616, 0.0045, ..., 0.0409, -0.0305, -0.0375], [-0.0317, -0.2758, -0.0693, ..., -0.0358, -0.2878, -0.1085], [ 0.0420, -0.0477, 0.0094, ..., 0.0488, -0.0487, -0.0332], ..., [ 0.0400, -0.0495, 0.0088, ..., 0.0478, -0.0464, -0.0337], [ 0.0231, -0.0639, 0.0036, ..., 0.0396, -0.0274, -0.0382], [ 0.0341, -0.0546, 0.0070, ..., 0.0449, -0.0398, -0.0353]]], grad_fn=<AsStridedBackward0>) In _predict function: yt shape: torch.Size([5, 10]) yt value: tensor([[0.4969, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068], [0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228], [0.5068, 0.5068, 0.5068, 0.5441, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068], [0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228, 0.0228], [0.5068, 0.4666, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068, 0.5068]], grad_fn=<AsStridedBackward0>) In _get_next_pred function: pred shape: torch.Size([5]) pred value: tensor([0.5068, 0.0228, 0.5068, 0.0228, 0.5068], grad_fn=<SumBackward1>)
发现错误:
# Debugging note
# --------------
# The first version of this function ran BOTH loops over
# ``concept_num = adj.shape[1]`` and crashed with:
#     index 5 is out of bound for dimension 0 with size 5
# What was overlooked is how the caller builds ``adj`` / ``reverse_adj``:
#     adj = self.graph[masked_qt.long(), :].unsqueeze(dim=-1)
#     reverse_adj = self.graph[:, masked_qt.long()].transpose(0, 1).unsqueeze(dim=-1)
# After that indexing the shape is [mask_num, concept_num, 1], so the search
# range is no longer 10 * 10 (concept_num x concept_num).
# Fix: the outer loop must run over ``mask_num = adj.shape[0]``.
def generate_neighbor_sets(self, adj, reverse_adj):
    """Collect, for every masked row, the concepts connected to it.

    Parameters:
        adj: outgoing adjacency rows selected by the answered questions.
        reverse_adj: incoming adjacency rows selected the same way.
    Shape:
        adj: [mask_num, concept_num, 1]
        reverse_adj: [mask_num, concept_num, 1]
    Return:
        dict mapping each row index ``i`` in ``range(mask_num)`` to the set of
        concept indices ``j`` for which ``adj[i, j, 0] > 0`` or
        ``reverse_adj[i, j, 0] > 0``. An input with ``mask_num == 0`` gives
        an empty dict.
    """
    # Dimension 0 is mask_num (number of selected rows), NOT concept_num.
    mask_num, concept_num = adj.shape[0], adj.shape[1]
    return {
        i: {
            j
            for j in range(concept_num)
            # A concept is a neighbor if an edge exists in either direction.
            if adj[i, j, 0] > 0 or reverse_adj[i, j, 0] > 0
        }
        for i in range(mask_num)
    }
sgkt模型实验日志
于 2023-05-31 15:14:51 首次发布