import torch
import pandas as pd
import numpy as np
# Toy dataset: one row per user. `applist` holds that user's app ids
# (already integer-encoded) and `label` is the binary target Y.
_applists = [
    [1, 4, 3, 7, 8],
    [5, 7, 32, 4, 68, 23, 4, 78, 23, 1],
    [21, 23, 45],
    [3, 2, 8, 6, 4, 7],
    [6, 7, 8, 23, 56, 76],
    [2, 3],
    [1, 6, 3, 5, 2],
    [9, 4, 12, 66, 90],
    [34, 36, 37],
    [4, 5, 67, 8, 2, 1],
    [5, 2, 3, 6, 8, 1, 8],
    [1, 2, 3],
    [89],
    [4, 3, 44],
    [33, 45, 12, 34],
    [55, 8],
    [77, 54, 20],
    [21, 45, 89],
    [17, 13, 14, 10],
    [2, 6],
]
_labels = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0]
df = pd.DataFrame({
    'id': list(range(1, len(_applists) + 1)),
    'applist': _applists,
    'label': _labels,
})

# Name of the feature column (X). Its values are already encoded: codes
# usually run from 1 to the number of distinct values, and 0 is reserved for
# padding. E.g. with the alphabet encoded as 1-26, ['a', 'c'] becomes [1, 3].
var_list = 'applist'

# Target variable.
Y = df['label'].values

# Raw, variable-length feature sequences. An embedding layer needs fixed-length
# input, so these are zero-padded further down (0 is the padding code).
X = df[var_list].values
# Helper that right-pads (and, if necessary, truncates) code sequences with 0.
def cust2mpad(x, maxlen=10):
    """Return every sequence in *x* as a list of exactly *maxlen* codes.

    Shorter sequences are right-padded with the integer ``0`` (the code
    reserved for padding). Longer sequences are truncated to their first
    *maxlen* items. The input sequences are not modified.

    Args:
        x: Iterable of integer sequences (lists, tuples, arrays, ...).
        maxlen: Target length of every output row.

    Returns:
        A new list of lists, one per input sequence.
    """
    padded = []
    for seq in x:
        # Truncate first, so an overlong row cannot produce a negative pad
        # count and end up longer than maxlen.
        row = list(seq)[:maxlen]
        # Pad with int 0 (not float 0.0) so rows stay integer-typed.
        row.extend([0] * (maxlen - len(row)))
        padded.append(row)
    return padded
# Pad every sequence to the length of the longest one so the batch forms a
# rectangular index matrix. The length is derived from the data rather than
# hard-coded; for this dataset it is 10.
X = cust2mpad(X, maxlen=max(map(len, X), default=0))
# Model definition: embedding lookup -> mean over non-padding positions -> 2-layer MLP.
class model(torch.nn.Module):
    """Small bag-of-embeddings binary classifier.

    Each input row is a sequence of integer codes. The codes are looked up in
    an embedding table, averaged over the non-padding positions, and passed
    through Linear -> ReLU -> Linear. The result is one raw logit per row,
    intended for ``BCEWithLogitsLoss``.

    Args:
        num_embeddings: Size of the code table; valid codes are
            ``0 .. num_embeddings - 1``.
        embedding_dim: Length of each embedding vector.
        hidden_dim: Width of the hidden layer.
        padding_idx: Code reserved for padding. Its embedding row is kept at
            zero (and receives no gradient), and its positions are left out of
            the mean. Pass ``None`` to average over every position.
    """

    def __init__(self, num_embeddings=100, embedding_dim=5, hidden_dim=2,
                 padding_idx=0):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings, embedding_dim,
                                      padding_idx=padding_idx)
        self.line1 = torch.nn.Linear(embedding_dim, hidden_dim)
        self.relu = torch.nn.ReLU()
        self.line2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, input):
        """Map a (batch, seq_len) integer index tensor to (batch, 1) logits."""
        x1 = self.emb(input)  # (batch, seq_len, embedding_dim)
        pad = self.emb.padding_idx
        if pad is None:
            x2 = x1.mean(dim=1)
        else:
            # Average only the real codes. A plain mean would let padding
            # dilute short sequences. clamp(min=1) keeps all-padding rows finite.
            mask = (input != pad).unsqueeze(-1).to(x1.dtype)
            x2 = (x1 * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        x3 = self.line1(x2)
        x4 = self.relu(x3)
        x5 = self.line2(x4)
        return x5
# Instantiate the model, the optimizer and the loss. The model emits raw
# logits, so the logits variant of binary cross-entropy is used.
mymodel = model()
optimizer = torch.optim.Adam(mymodel.parameters(), lr=0.001)
lossf = torch.nn.BCEWithLogitsLoss()

# Dataset. Embedding lookups need integer indices, so the feature tensor is
# built as int64 directly rather than round-tripping through float32.
train_da = torch.tensor(X, dtype=torch.long)
train_lab = torch.tensor(Y, dtype=torch.float32)
train_dataset = torch.utils.data.TensorDataset(train_da, train_lab)
# Group the rows into batches of 5.
traind = torch.utils.data.DataLoader(train_dataset, batch_size=5, shuffle=False)

# Training loop: 5 epochs; i is the 1-based batch number within an epoch.
diverged = False
for jj in range(5):
    for i, (x, y0) in enumerate(traind, start=1):
        y = mymodel(x)
        loss = lossf(y, y0.unsqueeze(1))
        print(jj, i, loss.item())
        # Check before the update so a NaN loss never reaches the weights,
        # and stop all epochs, not just the current one.
        if torch.isnan(loss).item():
            diverged = True
            break
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if diverged:
        break

# Inference: eval mode plus no_grad, so no autograd graph is built.
# yy holds logits; torch.sigmoid(yy) gives the probability of label 1.
mymodel.eval()
with torch.no_grad():
    yy = mymodel(train_da)

# Learned embedding matrix as a NumPy array (row k is the vector for code k).
emb = mymodel.emb.weight.detach().cpu().numpy()
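# Example: using an embedding layer in PyTorch (pytorch中使用embedding层的示例).
# Source-page residue: "最新推荐文章于 2024-04-02 22:32:16 发布" (latest recommended article published 2024-04-02 22:32:16).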