算法的核心代码,学习记录,仅供个人使用
RNNLM
# init_state for decoder
# NOTE(review): fragment — `get_default_float_type`, `hidden_size`, `state_name`,
# `my_full_connected`, `num_layers`, `training_helper` and the `self.*` attributes
# are defined in surrounding (unseen) model code.
init_state = tf.placeholder(get_default_float_type(), [None, hidden_size], name=state_name)
# Project the fed-in state through num_layers*2 independent tanh FC layers and
# stack along axis 1 -> [batch, num_layers*2, hidden_size]. The factor of 2 is
# presumably one c and one h per LSTM layer — confirm how this tensor is
# unpacked into the cell's state structure before use.
init_state = tf.stack([my_full_connected(init_state, hidden_size, act=tf.nn.tanh) for i in range(num_layers * 2)], axis=1)
# rnn cell
cell = self.build_cell(self.hidden_size, self.keep_prob, self.num_layers, True)
decoder = texar.modules.BasicRNNDecoder(cell=cell, vocab_size=self.nb_words)
# build decoder
# Teacher-forced decode: "train_greedy" feeds ground-truth tokens via the helper;
# impute_finished=False skips zeroing outputs past each sequence's end.
outputs, final_state, _ = decoder(
    decoding_strategy="train_greedy",
    impute_finished=False,
    helper=training_helper,
    initial_state=init_state
)
transformerLM
# Project `x` into state_length memory slots (tanh FC each), stacked along
# axis 1, for the decoder's cross-attention to attend over.
# NOTE(review): fragment — `x`, `hidden_size`, `state_length`, `nb_words`,
# `input_tensor`, `pos_embeeding` and `self.*` come from unseen code.
init_state = tf.stack([my_full_connected(x, hidden_size, act=tf.nn.tanh) for i in range(state_length)], axis=1)
# transformer decoder
decoder = texar.modules.TransformerDecoder(vocab_size=nb_words, hparams=self.config_decoder)
# build decoder
# NOTE(review): texar documents memory_sequence_length as a [batch]-shaped
# tensor of per-example lengths; a scalar `state_length` may rely on
# broadcasting — confirm against the texar version in use.
outputs = decoder(inputs=self._conv_fn(input_tensor) + pos_embeeding,
                  memory=init_state,
                  memory_sequence_length=state_length,
                  mode='train',
                  decoding_strategy='train_greedy')
GPT 的 backbone 是不带 cross-attention 的 Transformer Decoder(结构上相当于加了 causal mask 的 Encoder 层);这里 transformerLM 的 backbone 是完整的 Transformer Decoder,比 GPT 多了一个 attend 到 state 的 cross-attention 层。
graph
# main
# Node features held in an Embedding table; features(ids) -> (batch, 1433).
# NOTE(review): 2708 nodes x 1433 features and 7 classes match the Cora
# citation dataset — confirm. `adj_lists` and `use_cuda` are defined outside
# this snippet.
features = nn.Embedding(2708, 1433)
# Hop 1: mean-aggregate raw 1433-d neighbor features, encode to 128-d.
agg1 = MeanAggregator(features, cuda=use_cuda)
enc1 = Encoder(features, 1433, 128, adj_lists, agg1, gcn=True, cuda=use_cuda)
# Hop 2 stacks on hop 1; enc1(nodes) returns (embed_dim, batch), hence the .t()
# to restore the (batch, embed_dim) layout the aggregator expects.
agg2 = MeanAggregator(lambda nodes: enc1(nodes).t(), cuda=use_cuda)
enc2 = Encoder(lambda nodes: enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2,
               base_model=enc1, gcn=True, cuda=use_cuda)
# 7-way node classifier on top of the two-hop encoder.
graphsage = SupervisedGraphSage(7, enc2)
# GraphSAGE
class SupervisedGraphSage(nn.Module):
    """Supervised node classifier: a linear layer over GraphSAGE embeddings.

    `enc` is a callable producing an (embed_dim, batch) embedding matrix and
    exposing an `embed_dim` attribute.
    """
    def __init__(self, num_classes, enc):
        super(SupervisedGraphSage, self).__init__()
        self.enc = enc
        # Classifier weight, (num_classes, embed_dim).
        # NOTE(review): torch.FloatTensor allocates uninitialized memory; the
        # elided code presumably initializes it — confirm.
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
        ...
    def forward(self, nodes):
        # enc(nodes) -> (embed_dim, batch); weight.mm -> (num_classes, batch);
        # transpose so callers get (batch, num_classes) scores.
        return self.weight.mm(self.enc(nodes)).t()
    ...
# Module
class MeanAggregator(nn.Module):
    """Aggregates each node's representation as the mean of its (sampled)
    neighbors' features.

    `features` is a callable mapping a LongTensor of node ids to a feature
    matrix. NOTE(review): `_set`/`_sample` (presumably `set`/`random.sample`
    bound as locals) and the mask-building code are elided from this snippet.
    """
    def __init__(self, features, cuda=False, gcn=False):
        ...
    def forward(self, nodes, to_neighs, num_sample=10):
        """Return mean-pooled neighbor features, one row per node in `nodes`.

        to_neighs -- list of neighbor-id sets, aligned with `nodes`
        num_sample -- cap on the number of neighbors sampled per node
        """
        # Downsample each neighborhood to at most num_sample neighbors.
        samp_neighs = [_set(_sample(to_neigh, num_sample)) if len(to_neigh) >= num_sample
                       else to_neigh for to_neigh in to_neighs]
        # GCN variant: include the node itself in its own neighborhood.
        # Fixed: `samp_neigh + set(...)` raises TypeError (sets do not support +);
        # set union (`|`) is the intended operation.
        samp_neighs = [samp_neigh | set([nodes[i]]) for i, samp_neigh in enumerate(samp_neighs)]  # gcn
        ...
        # example: samp_neighs = [{1, 2}, {0}, {2, 10}]
        # unique_nodes_list = [0, 1, 2, 10]
        # mask = [[0. 1/2 1/2 0.], [1. 0. 0. 0.], [0. 0. 1/2 1/2]]
        # Row-normalized mask times the feature matrix = mean over each
        # node's sampled neighbors.
        embed_matrix = self.features(torch.LongTensor(unique_nodes_list))
        return mask.mm(embed_matrix)
class Encoder(nn.Module):
    """Encodes nodes by concatenating their own features with aggregated
    neighbor features and applying a learned linear map + ReLU (gcn variant).
    """
    def __init__(self, features, feature_dim,
                 embed_dim, adj_lists, aggregator,
                 num_sample=10,
                 base_model=None, gcn=False, cuda=False,
                 feature_transform=False):
        ...
    def forward(self, nodes):
        """Return a (embed_dim, batch) embedding matrix for `nodes`."""
        # Look up each node's neighbor set and mean-aggregate their features.
        neighbor_lists = [self.adj_lists[int(node)] for node in nodes]
        neigh_feats = self.aggregator.forward(nodes, neighbor_lists, self.num_sample)
        # neigh_feats is (batch, feature_dim), e.g. (batch, 1433) for hop 1.
        self_feats = self.features(torch.LongTensor(nodes))
        # gcn-style concat of self and neighborhood features -> (batch, 2*feature_dim).
        stacked = torch.cat([self_feats, neigh_feats], dim=1)
        # weight is (embed_dim, 2*feature_dim); result is (embed_dim, batch).
        return F.relu(self.weight.mm(stacked.t()))
VAE
class VAE(nn.Module):
"""Implementation of VAE(Variational Auto-Encoder)"""
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 200)
self.fc2_mu = nn.Linear(200, 10)
self.fc2_log_std = nn.Linear(200, 10)
self.fc3 = nn.Linear(10, 200)
self.fc4 = nn.Linear(200, 784)
def encode(self, x):
h1 = F.relu(self.fc1(x))
mu = self.fc2_mu(h1)
log_std = self.fc2_log_std(h1)
return mu, log_std
def decode(self, z):
h3 = F.relu(self.fc3(z))
recon = torch.sigmoid(self.fc4(h3))
return recon
def reparametrize(self, mu, log_std):
std = torch.exp(log_std)
eps = torch.randn_like(std) # simple from standard normal distribution
z = mu + eps * std
return z
def forward(self, x):
mu, log_std = self.encode(x)
z = self.reparametrize(mu, log_std)
recon = self.decode(z)
return recon, mu, log_std
def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
recon_loss = F.mse_loss(recon, x, reduction="sum") #"mean" may have a bad effect on gradients
kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
kl_loss = torch.sum(kl_loss)
loss = recon_loss + kl_loss
return loss