import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
class customConv(nn.Module):
    """2-D convolution with fake (simulated) weight/bias quantization for QAT.

    Weights are scaled by a learnable per-output-channel scale, stretched by
    the signed-integer maximum, rounded with a straight-through estimator
    (STE), and clipped to [-MAX_value, MAX_value] before the convolution runs.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pad=1, nbits=8, isbias=True):
        super(customConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad, bias=isbias)
        # Learnable per-output-channel quantization scales for weight and bias.
        self.scale = nn.Parameter(torch.ones(out_channels))
        self.scale_b = nn.Parameter(torch.ones(out_channels))
        self.MAX_value = 2 ** (nbits - 1) - 1  # e.g. 127 for nbits=8
        self.out_channels = out_channels

    def forward(self, x):
        weight = self.conv.weight
        weight_scale = weight * self.scale.view(self.out_channels, 1, 1, 1) * self.MAX_value
        # STE round: forward value is round(), gradient flows through as identity.
        weight_rounded = weight_scale.round().detach() - weight_scale.detach() + weight_scale
        weight_clipped = torch.clip(weight_rounded, -self.MAX_value, self.MAX_value)
        bias = self.conv.bias
        if bias is None:
            # Bug fix: original crashed with isbias=False (bias is None here).
            bias_clipped = None
        else:
            bias_scale = bias * self.scale_b * self.MAX_value
            bias_rounded = bias_scale.round().detach() - bias_scale.detach() + bias_scale
            bias_clipped = torch.clip(bias_rounded, -self.MAX_value, self.MAX_value)
        out = F.conv2d(x, weight=weight_clipped, padding=self.conv.padding, bias=bias_clipped)
        return out
class customLinear(nn.Module):
    """Linear layer with fake (simulated) weight/bias quantization for QAT.

    Same scheme as customConv: learnable per-output scale, multiply by the
    signed-integer maximum, STE round, clip to [-MAX_value, MAX_value].
    """

    def __init__(self, in_features, out_features, nbits=8, isbias=True):
        super(customLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=isbias)
        # weight.shape = [out_features, in_features]
        self.scale = nn.Parameter(torch.ones(out_features))
        self.scale_b = nn.Parameter(torch.ones(out_features))
        self.MAX_value = 2 ** (nbits - 1) - 1  # e.g. 127 for nbits=8
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        weight = self.linear.weight
        weight_scale = weight * self.scale.view(self.out_features, 1) * self.MAX_value
        # STE round: forward value is round(), gradient flows through as identity.
        weight_rounded = weight_scale.round().detach() - weight_scale.detach() + weight_scale
        weight_clipped = torch.clip(weight_rounded, -self.MAX_value, self.MAX_value)
        bias = self.linear.bias
        if bias is None:
            # Bug fix: original crashed with isbias=False (bias is None here).
            bias_clipped = None
        else:
            bias_scale = bias * self.scale_b * self.MAX_value
            bias_rounded = bias_scale.round().detach() - bias_scale.detach() + bias_scale
            bias_clipped = torch.clip(bias_rounded, -self.MAX_value, self.MAX_value)
        out = F.linear(x, weight_clipped, bias=bias_clipped)
        return out
# class customQuan(nn.Module):
# def __init__(self,nbits = 8):
# super(customQuan, self).__init__()
# self.MAX_value = 2 ** (nbits - 1) - 1
# self.nbits = nbits
# self.shift = 2 ** (nbits - 1)
# def forward(self, x):
# '''
# x 右移n位
# '''
# max_abs_value = torch.max(torch.abs(x))
# if max_abs_value <= self.MAX_value:
# self.shift = 1
# else:
# self.shift = torch.log2(max_abs_value/self.MAX_value).ceil()
#
# shift = (2**self.shift).detach()
# x = (x/shift).floor().detach() - (x/shift).detach() + x/shift
# x = torch.clip(x, -self.MAX_value, self.MAX_value)
# return x
#
class customQuan(nn.Module):
    """Dynamic activation re-quantization back into [-MAX_value, MAX_value].

    The divisor (``shift``) is chosen per call from the tensor's max-abs
    value; floor() is applied with a straight-through estimator so gradients
    pass through the rescaling as identity.
    """

    def __init__(self, nbits=8):
        super(customQuan, self).__init__()
        self.MAX_value = 2 ** (nbits - 1) - 1
        self.nbits = nbits
        self.shift = 2 ** (nbits - 1)

    def forward(self, x):
        # Divide x so its max-abs value lands on MAX_value
        # (conceptually a "right shift" of x by n bits).
        max_abs_value = torch.max(torch.abs(x))
        self.shift = max_abs_value / self.MAX_value
        shift = self.shift.detach()
        if shift == 0:
            # Bug fix: an all-zero input made shift 0 and produced NaNs (0/0).
            return x
        # STE floor: forward value is floor(x/shift), gradient is identity.
        x = (x / shift).floor().detach() - (x / shift).detach() + x / shift
        x = torch.clip(x, -self.MAX_value, self.MAX_value)
        return x
class customfirst(nn.Module):
    """Input quantizer: map values in (0, 1) to integers in [-MAX_value, MAX_value]."""

    def __init__(self, nbits=8):
        super(customfirst, self).__init__()
        self.MAX_value = 2 ** (nbits - 1) - 1

    def forward(self, x):
        # Center to (-0.5, 0.5) then stretch to (-MAX_value, MAX_value).
        # Bug fix: the range was hard-coded to 127, silently wrong for nbits != 8.
        x = (x - 0.5) * 2 * self.MAX_value
        # STE round: forward value is round(), gradient is identity.
        x = x.round().detach() - x.detach() + x
        x = torch.clip(x, -self.MAX_value, self.MAX_value)
        return x
class SimpleCNN(nn.Module):
    """Small quantization-aware CNN for 1x28x28 inputs (MNIST-style).

    Pipeline: input quantizer -> two quantized conv/pool stages ->
    two quantized fully-connected layers -> softmax over 10 classes.
    """

    def __init__(self, nbits=8):
        super(SimpleCNN, self).__init__()
        self.firstQ = customfirst(nbits)
        self.QQ = customQuan(nbits)
        self.relu = nn.ReLU()
        self.conv1 = customConv(1, 3, 3, 1, nbits=nbits)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = customConv(3, 6, 3, 1, nbits=nbits)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = customLinear(6 * 7 * 7, 64, nbits=nbits)
        self.fc2 = customLinear(64, 10, nbits=nbits)

    def forward(self, x):
        # Each stage: layer -> re-quantize -> nonlinearity (-> pool).
        out = self.pool1(self.relu(self.QQ(self.conv1(self.firstQ(x)))))
        out = self.pool2(self.relu(self.QQ(self.conv2(out))))
        # Flatten the 6x7x7 feature maps for the classifier head.
        out = out.view(-1, 6 * 7 * 7)
        out = self.relu(self.QQ(self.fc1(out)))
        out = self.QQ(self.fc2(out))
        return F.softmax(out, dim=1)
if __name__ == '__main__':
    # Smoke test: one random MNIST-shaped batch through the QAT model.
    batch = torch.rand(32, 1, 28, 28)
    net = SimpleCNN()
    summary(net, input_data=batch)
    probs = net(batch)
    print(probs)
    # Backprop a dummy gradient to confirm the STE path reaches the weights.
    probs.backward(torch.ones_like(probs))
    print(net.conv1.conv.weight.grad)
# QAT demo code
# (Scraped page footer; latest recommended article published 2024-08-09 15:02:55)