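Code:

```python
batch_size, n_features = w.shape
# Flatten the (batch_size, n_features) index matrix and pick the matching rows of the embedding matrix
emb_out = self.embeddings.index_select(0, w.flatten())  # (batch_size * n_features, embed_size)
# Concatenate each example's feature embeddings into one row vector
x = emb_out.view(batch_size, -1)  # -1 means: n_features * embed_size
```

That is, this does the same lookup as an embedding layer, `x = self.embeddings(w)`, followed by a reshape. Because `index_select` is called on `self.embeddings`, it is a weight tensor here, not an `nn.Embedding` module, so the one-line equivalent is `x = self.embeddings[w].view(batch_size, -1)`.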
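The following is a minimal standalone sketch that checks the equivalence. The sizes `n_words`, `embed_size`, `batch_size`, and `n_features` are hypothetical values chosen for illustration. An `nn.Embedding` module is used only so that the same weight matrix can be queried three ways: with `index_select`, by calling the module, and by plain tensor indexing.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only
n_words, embed_size = 10, 4
batch_size, n_features = 3, 5

torch.manual_seed(0)
embeddings = nn.Embedding(n_words, embed_size)            # weight: (n_words, embed_size)
w = torch.randint(0, n_words, (batch_size, n_features))   # word indices, dtype int64

# Manual lookup: flatten the indices, select rows, then reshape per example
emb_out = embeddings.weight.index_select(0, w.flatten())  # (batch_size * n_features, embed_size)
x_manual = emb_out.view(batch_size, -1)                   # (batch_size, n_features * embed_size)

# Module lookup returns (batch_size, n_features, embed_size), so it still needs the reshape
x_module = embeddings(w).view(batch_size, -1)

# Advanced indexing on the weight tensor gives the same rows
x_index = embeddings.weight[w].view(batch_size, -1)

print(x_manual.shape)  # torch.Size([3, 20])
print(torch.equal(x_manual, x_module), torch.equal(x_manual, x_index))  # True True
```

All three produce identical values, because each one simply copies rows of the same weight matrix. Only the output shape differs before the final `view`.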