VQA中常用的方法是Dynamic网络,就是一个网络的输出作为另一个网络的filter权重。
fliter = nn.Conv2d(…)
# 当然这样写诗错的
filter.weight.data.fill_(network_output).
filter.forward(image)
正确写法:
import torch.nn.functional as F
p = F.conv2d(image, weight=network_output, ...)
另一种写法
fliter = nn.Conv2d(...)
filter.weight = network_output
p = filter(image)