class Params:
    """Hyper-parameter container consumed by the model constructor.

    The total embedding size ``emb_dim`` is split into a "static" part of
    ``int(se_prop * emb_dim)`` dimensions (``s_emb_dim``) and a "temporal"
    remainder (``t_emb_dim``), so the two always add up to ``emb_dim``.
    Attention/feed-forward sizes are then derived from ``t_emb_dim``.

    Args:
        neg_ratio: Stored unchanged as ``self.neg_ratio`` (presumably the
            number of negative samples per positive -- confirm in trainer).
        dropout: Stored unchanged as ``self.dropout``.
        se_prop: Fraction of ``emb_dim`` given to the static embedding.
        ne: Stored unchanged as ``self.ne`` (presumably number of epochs).
        lr: Stored unchanged as ``self.lr``.
        reg_lambda: Stored unchanged as ``self.reg_lambda``.
        emb_dim: Total embedding dimension.
        save_each: Stored unchanged as ``self.save_each``.
        s_emb_dim: Accepted for backward compatibility but NOT used; the
            static size is always ``int(se_prop * emb_dim)``.  A
            ``UserWarning`` is emitted when the supplied value disagrees.
        e_epoch: Stored unchanged as ``self.e_epoch``.
        alp: Stored unchanged as ``self.alp``.
        n_head: Number of attention heads (default 4); the per-head sizes
            ``d_k``/``d_v`` are ``t_emb_dim // n_head``.
    """

    def __init__(self, neg_ratio, dropout, se_prop, ne, lr, reg_lambda,
                 emb_dim, save_each, s_emb_dim, e_epoch, alp, n_head=4):
        # Plain pass-through settings.
        self.neg_ratio = neg_ratio
        self.dropout = dropout
        self.se_prop = se_prop
        self.ne = ne
        self.lr = lr
        self.reg_lambda = reg_lambda
        self.save_each = save_each
        self.e_epoch = e_epoch
        self.alp = alp

        # Split the embedding between static and temporal parts.  The
        # caller-supplied s_emb_dim has always been overridden here, so keep
        # that behaviour but make a disagreeing value visible.
        self.emb_dim = emb_dim
        self.s_emb_dim = int(se_prop * emb_dim)
        self.t_emb_dim = emb_dim - self.s_emb_dim
        if s_emb_dim is not None and s_emb_dim != self.s_emb_dim:
            import warnings
            warnings.warn(
                f"s_emb_dim={s_emb_dim} is ignored; using "
                f"int(se_prop * emb_dim)={self.s_emb_dim}",
                UserWarning,
                stacklevel=2,
            )

        # Sizes derived from the temporal dimension.
        self.dim_ff = self.t_emb_dim * 2
        self.n_head = n_head
        # Integer division: if t_emb_dim is not a multiple of n_head the
        # per-head size is truncated.
        self.d_k = self.t_emb_dim // self.n_head
        self.d_v = self.t_emb_dim // self.n_head
# Experiment configuration.  Params derives the static embedding size as
# int(se_prop * emb_dim) and ignores the s_emb_dim argument, so pass the value
# that actually takes effect (int(0.36 * 100) == 36) instead of a misleading one.
params = Params(
    neg_ratio=5,
    lr=0.001,
    reg_lambda=0.0,
    emb_dim=100,
    dropout=0.4,
    e_epoch=10,
    save_each=10,
    se_prop=0.36,
    alp=0.5,
    ne=500,
    s_emb_dim=36,
)
bsize = 512
# Keep the dataset name separate from the loaded Dataset object instead of
# rebinding one variable to two different types.
dataset_name = 'icews14'
dataset = Dataset(dataset_name, bsize)
model = RoAN_DES(dataset, params)
print(model)
# Output of print(model) for the configuration above:
# RoAN_DES(
#   (ent_embs_h): Embedding(7128, 36)
#   (ent_embs_t): Embedding(7128, 36)
#   (rel_embs_f): Embedding(230, 100)
#   (rel_embs_i): Embedding(230, 100)
#   (m_freq_h): Embedding(7128, 64)
#   (m_freq_t): Embedding(7128, 64)
#   (d_freq_h): Embedding(7128, 64)
#   (d_freq_t): Embedding(7128, 64)
#   (y_freq_h): Embedding(7128, 64)
#   (y_freq_t): Embedding(7128, 64)
#   (m_phi_h): Embedding(7128, 64)
#   (m_phi_t): Embedding(7128, 64)
#   (d_phi_h): Embedding(7128, 64)
#   (d_phi_t): Embedding(7128, 64)
#   (y_phi_h): Embedding(7128, 64)
#   (y_phi_t): Embedding(7128, 64)
#   (m_amps_h): Embedding(7128, 64)
#   (m_amps_t): Embedding(7128, 64)
#   (d_amps_h): Embedding(7128, 64)
#   (d_amps_t): Embedding(7128, 64)
#   (y_amps_h): Embedding(7128, 64)
#   (y_amps_t): Embedding(7128, 64)
#   (Rel_emb): Rel_time_emb(
#     (h_map_emb): Embedding(7128, 100)
#     (t_map_emb): Embedding(7128, 100)
#     (rel_emb_h): Embedding(231, 100)
#     (rel_emb_t): Embedding(231, 100)
#     (rel_emb_q): Embedding(230, 100)
#     (year_emb): Embedding(1, 100)
#     (month_emb): Embedding(12, 100)
#     (day_emb): Embedding(31, 100)
#     (his_encoder): Encoder(
#       (multi): MultiHead(
#         (Q): Linear(in_features=100, out_features=64, bias=True)
#         (K): Linear(in_features=100, out_features=64, bias=True)
#         (V): Linear(in_features=100, out_features=64, bias=True)
#         (fc): Linear(in_features=64, out_features=100, bias=False)
#         (layernorm): LayerNorm((100,), eps=1e-06, elementwise_affine=True)
#         (dropout): Dropout(p=0.4, inplace=False)
#         (rel_attn): ScaledDot(
#           (dropout): Dropout(p=0.4, inplace=False)
#         )
#       )
#       (ffn): FeedForward(
#         (fc1): Linear(in_features=100, out_features=128, bias=True)
#         (fc2): Linear(in_features=128, out_features=100, bias=True)
#         (layernorm): LayerNorm((100,), eps=1e-06, elementwise_affine=True)
#         (dropout): Dropout(p=0.4, inplace=False)
#       )
#     )
#   )
# )