QAT demo code
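
A minimal PyTorch quantization-aware-training (QAT) demo: the conv and linear layers fake-quantize their weights and biases to the signed n-bit integer range with learned scales, using a straight-through estimator (STE) so gradients still flow through the rounding.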

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

class customConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, pad=1, nbits=8, isbias=True):
        super(customConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad, bias=isbias)
        self.scale = nn.Parameter(torch.ones(out_channels))
        self.scale_b = nn.Parameter(torch.ones(out_channels))
        self.MAX_value = 2 ** (nbits - 1) - 1
        self.out_channels = out_channels

    def forward(self, x):
        weight = self.conv.weight
        weight_scale = weight * self.scale.view(self.out_channels, 1, 1, 1) * self.MAX_value
        # Straight-through estimator (STE): forward pass rounds, backward pass is identity.
        weight_rounded = weight_scale.round().detach() - weight_scale.detach() + weight_scale
        weight_clipped = torch.clip(weight_rounded, -self.MAX_value, self.MAX_value)

        bias_clipped = None
        if self.conv.bias is not None:  # guard so isbias=False does not crash
            bias = self.conv.bias
            bias_scale = bias * self.scale_b * self.MAX_value
            bias_rounded = bias_scale.round().detach() - bias_scale.detach() + bias_scale
            bias_clipped = torch.clip(bias_rounded, -self.MAX_value, self.MAX_value)

        out = F.conv2d(x, weight=weight_clipped, padding=self.conv.padding, bias=bias_clipped)
        return out
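
# Illustrative aside (not from the original post): a quick check of the STE
# pattern above -- forward() sees round(x), backward() sees the identity, so
# the learned scales still receive gradients.
def _ste_demo():
    x = torch.tensor([0.3, 1.7], requires_grad=True)
    y = x.round().detach() - x.detach() + x  # forward values: tensor([0., 2.])
    y.sum().backward()
    print(x.grad)  # tensor([1., 1.]) -- gradient passes straight through round()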

class customLinear(nn.Module):
    def __init__(self, in_features, out_features, nbits=8, isbias=True):
        super(customLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=isbias)
        # weight.shape = [out_features, in_features]
        self.scale = nn.Parameter(torch.ones(out_features))
        self.scale_b = nn.Parameter(torch.ones(out_features))
        self.MAX_value = 2 ** (nbits - 1) - 1
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        weight = self.linear.weight
        weight_scale = weight * self.scale.view(self.out_features, 1) * self.MAX_value
        # Same STE trick as in customConv: round in the forward pass only.
        weight_rounded = weight_scale.round().detach() - weight_scale.detach() + weight_scale
        weight_clipped = torch.clip(weight_rounded, -self.MAX_value, self.MAX_value)

        bias_clipped = None
        if self.linear.bias is not None:  # guard so isbias=False does not crash
            bias = self.linear.bias
            bias_scale = bias * self.scale_b * self.MAX_value
            bias_rounded = bias_scale.round().detach() - bias_scale.detach() + bias_scale
            bias_clipped = torch.clip(bias_rounded, -self.MAX_value, self.MAX_value)

        out = F.linear(x, weight_clipped, bias=bias_clipped)
        return out
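
# Hypothetical post-training helper (not in the original post): recover the
# integer weight tensor that the fake-quantized forward pass effectively uses,
# e.g. for export to an integer-only runtime.
def _export_int_weight(layer: customLinear) -> torch.Tensor:
    with torch.no_grad():
        w = layer.linear.weight * layer.scale.view(layer.out_features, 1) * layer.MAX_value
        return w.round().clamp(-layer.MAX_value, layer.MAX_value).to(torch.int8)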



# Earlier power-of-two variant, kept for reference: it right-shifts x by n bits
# instead of dividing by a float scale.
# class customQuan(nn.Module):
#     def __init__(self, nbits=8):
#         super(customQuan, self).__init__()
#         self.MAX_value = 2 ** (nbits - 1) - 1
#         self.nbits = nbits
#         self.shift = 2 ** (nbits - 1)
#     def forward(self, x):
#         '''
#         Right-shift x by n bits so it fits in the signed nbits range.
#         '''
#         max_abs_value = torch.max(torch.abs(x))
#         if max_abs_value <= self.MAX_value:
#             self.shift = torch.zeros(())  # 2**0 == 1, i.e. no shift needed
#         else:
#             self.shift = torch.log2(max_abs_value / self.MAX_value).ceil()
#
#         shift = (2 ** self.shift).detach()
#         x = (x / shift).floor().detach() - (x / shift).detach() + x / shift
#         x = torch.clip(x, -self.MAX_value, self.MAX_value)
#         return x
#
class customQuan(nn.Module):
    def __init__(self, nbits=8):
        super(customQuan, self).__init__()
        self.MAX_value = 2 ** (nbits - 1) - 1
        self.nbits = nbits
        self.shift = 2 ** (nbits - 1)

    def forward(self, x):
        '''
        Rescale x so its largest magnitude maps to MAX_value
        (a float-scale analogue of right-shifting by n bits).
        '''
        max_abs_value = torch.max(torch.abs(x))
        self.shift = (max_abs_value / self.MAX_value).clamp(min=1e-12)  # avoid divide-by-zero
        shift = self.shift.detach()
        # STE again: floor in the forward pass, identity in the backward pass.
        x = (x / shift).floor().detach() - (x / shift).detach() + x / shift
        x = torch.clip(x, -self.MAX_value, self.MAX_value)
        return x
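
# Illustrative check (not from the original post): customQuan rescales
# activations so the largest magnitude maps to MAX_value, keeping everything
# inside the signed 8-bit range regardless of the input's dynamic range.
def _quan_demo():
    q = customQuan(nbits=8)
    x = torch.randn(4, 8) * 1000  # activations far outside the int8 range
    y = q(x)
    print(y.abs().max())  # <= 127 after rescaling, flooring, and clipping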


class customfirst(nn.Module):
    def __init__(self, nbits=8):
        super(customfirst, self).__init__()
        self.MAX_value = 2 ** (nbits - 1) - 1

    def forward(self, x):
        # map inputs from (0, 1) to (-MAX_value, MAX_value), e.g. (-127, 127) for 8 bits
        x = x - 0.5
        x = x * 2 * self.MAX_value  # was hard-coded 127; use MAX_value so other nbits work too
        x = x.round().detach() - x.detach() + x
        x = torch.clip(x, -self.MAX_value, self.MAX_value)
        return x
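
# Illustrative check (not from the original post): the input quantizer maps
# pixel values in (0, 1) onto the symmetric integer range.
def _first_demo():
    f = customfirst(nbits=8)
    print(f(torch.tensor([0.0, 0.5, 1.0])))  # tensor([-127., 0., 127.])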


class SimpleCNN(nn.Module):
    def __init__(self, nbits=8):
        super(SimpleCNN, self).__init__()
        self.firstQ = customfirst(nbits)
        self.QQ = customQuan(nbits)
        self.relu = nn.ReLU()

        self.conv1 = customConv(1, 3, 3, 1, nbits=nbits)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = customConv(3, 6, 3, 1, nbits=nbits)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = customLinear(6 * 7 * 7, 64, nbits=nbits)
        self.fc2 = customLinear(64, 10, nbits=nbits)

    def forward(self, x):
        x = self.firstQ(x)
        x = self.conv1(x)
        x = self.QQ(x)
        x = self.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.QQ(x)
        x = self.relu(x)
        x = self.pool2(x)
        x = x.view(-1, 6 * 7 * 7)  # two 2x2 pools: 28x28 -> 14x14 -> 7x7
        x = self.fc1(x)
        x = self.QQ(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.QQ(x)
        x = F.softmax(x, dim=1)
        return x
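
# A minimal training-step sketch (hypothetical; the original post only runs a
# single forward/backward sanity check). Because the model already applies
# softmax, the loss is NLL on log-probabilities rather than CrossEntropyLoss.
def _train_step(model, data, target, optimizer):
    optimizer.zero_grad()
    probs = model(data)
    loss = F.nll_loss(torch.log(probs + 1e-8), target)
    loss.backward()
    optimizer.step()
    return loss.item()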



if __name__ == '__main__':
    inputdata = torch.rand(32, 1, 28, 28)  # a batch of fake 28x28 grayscale images
    model = SimpleCNN()
    summary(model, input_data=inputdata)
    y = model(inputdata)
    print(y)
    y.backward(torch.ones_like(y))

    print(model.conv1.conv.weight.grad)
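    # With the STE, this gradient is non-zero; a plain round() would give a
    # zero gradient almost everywhere and the scales could not be learned.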