问题描述
在使用GRU训练序列数据时报错
报错部位代码为:
self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers,bidirectional=self.n_directions)
解决方案:
解决方案来源:https://github.com/Sundrops/video-caption.pytorch/issues/4
在给bidirectional
传参时,将整型转为bool型。
self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers,bidirectional=bool(self.n_directions))
运行后成功解决
原因分析:
事后再来看这个报错,确实是需要传入一个bool值的,造成这样的错误可能是接口不知道什么时候又更新了?