This is a PyTorch-based implementation of an SRU (Simple Recurrent Unit) network:
import torch
import torch.nn as nn
class SRU(nn.Module):
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
super(SRU, self).__init__()
self.input_size = input_size
self.hidden_size = hidd