"""Sanity-check a single ``nn.Linear`` layer held in an ``nn.ModuleList``.

Builds a ``ModuleList`` containing one ``nn.Linear(in_dims, n_hid)``, prints
the layer, then prints the layer's output for a small two-row input batch.
"""
import torch
import torch.nn as nn

in_dims = 3  # input features per row of ``x``
n_hid = 2  # output features produced by the layer

# Container for the layer(s); currently holds a single Linear layer.
ws = nn.ModuleList([nn.Linear(in_dims, n_hid)])
print(ws[0])

# Two rows with ``in_dims`` features each. Float literals make ``torch.tensor``
# infer the default floating dtype, matching the layer's parameters; integer
# literals would infer int64 and the layer's matmul would raise a dtype error.
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
print(ws[0](x))