Using CNN models for text processing.
For details, see this blog post:
https://blog.csdn.net/sunny_xsc1994/article/details/82969867
One filter: its max-pooled output feeds straight into a linear layer for classification (binary here, via sigmoid). The sketch below assumes model_name names a pretrained BERT checkpoint (e.g. 'bert-base-chinese', not specified in the original) and a fixed input length of 128 tokens.
import torch
import torch.nn as nn
from transformers import BertModel

model_name = 'bert-base-chinese'  # assumed checkpoint (not specified in the original); substitute your own

class E2EModel(nn.Module):
    def __init__(self):
        super(E2EModel, self).__init__()
        self.encode = BertModel.from_pretrained(model_name)
        self.conv = nn.Conv1d(
            in_channels=768,   # BERT hidden size
            out_channels=64,   # number of filters
            kernel_size=64,    # window of 64 tokens
            stride=1,          # slide one step at a time (Conv1d moves along the sequence only)
            bias=True,
            padding_mode='zeros'
        )
        self.pool = nn.MaxPool1d(128 - 64 + 1)  # max over all 65 window positions (seq len 128)
        self.drop = nn.Dropout(0.5)
        self.linear = nn.Linear(64, 1)

    def forward(self, inputs_id, att_mask):
        x = self.encode(inputs_id, att_mask)[0]  # B x L x 768
        x = x.permute(0, 2, 1)                   # B x 768 x L
        out = self.conv(x)                       # B x 64 x (L - 64 + 1)
        out = self.pool(out)                     # B x 64 x 1
        out = out.view(-1, out.size(1))          # B x 64
        out = self.drop(out)
        out = self.linear(out)                   # B x 1
        out = torch.sigmoid(out)
        return out

model = E2EModel()
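A minimal inference sketch (assumptions: the matching transformers tokenizer and two example sentences; max_length=128 matches the pooling window above):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(model_name)
batch = tokenizer(
    ["an example sentence", "another example"],
    padding='max_length', truncation=True, max_length=128,
    return_tensors='pt'
)
model.eval()  # disable dropout for inference
with torch.no_grad():
    probs = model(batch['input_ids'], batch['attention_mask'])
print(probs.shape)  # torch.Size([2, 1]): one probability per sentence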
Multiple filters with different window sizes: each filter's pooled output is computed separately, then the results are concatenated (torch.cat) and fed to the final linear layer.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel

window_sizes = [3, 5]

class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.dropout_rate = 0.5
        self.embedding = BertModel.from_pretrained(model_name)
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(in_channels=768,   # BERT hidden size
                          out_channels=50,   # filters per window size
                          kernel_size=h),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=128 - h + 1)  # max over all window positions (seq len 128)
            )
            for h in window_sizes
        ])
        self.fc = nn.Linear(in_features=50 * len(window_sizes),
                            out_features=1)

    def forward(self, inputs_id, att_mask):
        embed_x = self.embedding(inputs_id, att_mask)[0]  # B x L x 768
        # batch_size x text_len x embedding_size -> batch_size x embedding_size x text_len
        embed_x = embed_x.permute(0, 2, 1)
        out = [conv(embed_x) for conv in self.convs]  # each: B x 50 x 1
        # concatenate along the channel dimension, e.g. B x 50 x 1 and B x 50 x 1 -> B x 100 x 1
        out = torch.cat(out, dim=1)
        out = out.view(-1, out.size(1))  # B x 100
        out = F.dropout(out, p=self.dropout_rate, training=self.training)  # only active in train mode
        out = self.fc(out)               # B x 1
        out = torch.sigmoid(out)
        return out

model = TextCNN()
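A minimal training-step sketch (assumptions: binary 0/1 labels and the batch from the tokenizer sketch above; BCELoss matches the sigmoid output, and AdamW with lr=2e-5 is an illustrative choice):

criterion = torch.nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

model.train()  # enable dropout
probs = model(batch['input_ids'], batch['attention_mask'])  # B x 1 probabilities
labels = torch.tensor([[1.], [0.]])                         # B x 1 float labels (example values)
loss = criterion(probs, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()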