粗糙的解释维度
import torch
import torch.nn as nn
y=torch.randint(0,3,(10,3))
print(y)
定义一个10*3的矩阵,里面的值随机取0,1,2
m=nn.Embedding(3,5)
print(m.weight)
展示该Embedding里面值
print(m(y))
y中的值相当于索引,在Embedding中取对应的行
print(m(y).shape)
import torch
import torch.nn as nn
y=torch.randint(0,3,(10,3))
print(y)
m=nn.Embedding(3,5)
print(m.weight)
print(m(y))
print(m(y).shape)