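# 预测时需要gap — translation: "gap" is required at prediction (inference) time.
# NOTE(review): "gap" presumably refers to the global-average-pooled input (see `inp_gap` below) — confirm.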
import time
import torch
import torch.nn.functional as torch_F
import torch.nn as nn
class WeightNet(nn.Module):
r"""Applies WeightNet to a standard convolution.
The grouped fc layer directly generates the convolutional kernel,
this layer has M*inp inputs, G*oup groups and oup*inp*ksize*ksize outputs.
M/G control the amount of parameters.
"""
def __init__(self, inp, oup, ksize, stride):
super().__init__()
self.M = 2
self.G = 2
self.pad = ksize // 2
inp_gap = max(16, inp//16)
self.inp = inp
self.oup = oup
self.ksize = ksize
self.stride = stride
self.wn_fc1 = nn.Conv2d(inp