import torch
from torch import nn
class Model(nn.Module):
def __init__(self,
dict_len=5000,
embedding_dim=512):
"""
假定:
- 字典长度:5000
- 特征长度:512
"""
super(Model, self).__init__()
self.embed = nn.Embedding(num_embeddings=dict_len,
embedding_dim=embedding_dim)
self.gru = nn.GRU(input_size=embedding_dim,
hidden_size=embedding_dim,
num_layers=1,
batch_first=False)
self.linear = nn.Linear(in_features=512, out_features=2)
代码练习系列(三)——搭建一个可用于垃圾短信分类的循环神经网络模型
于 2024-07-15 18:36:04 首次发布