# 我自己在研究BNN,苦于找不到代码(没有一个人写一个吗???)
# 自己写了一个,也不知道是不是标准的BNN,反正就是把所有参数都二值化了,用的MNIST,效果一般。如果只二值权值的话就需要分别对每一层进行二值,这个代码量挺大的而且没有意义。
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
class Net(nn.Module):
    """LeNet-style CNN for MNIST (1x28x28 input, 10 classes).

    Uses hardtanh activations (the usual choice in binarized-network
    setups, since its gradient is the straight-through estimator of
    sign) and returns per-class log-probabilities via log_softmax.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # 28x28 -> conv(5) -> 24x24 -> maxpool(2) -> 12x12
        x = F.max_pool2d(F.hardtanh(self.conv1(x)), kernel_size=2, stride=2)
        # 12x12 -> conv(5) -> 8x8 -> maxpool(2) -> 4x4
        x = F.max_pool2d(F.hardtanh(self.conv2(x)), 2, 2)
        # Flatten the 50 channels of 4x4 feature maps.
        x = x.view(-1, 4 * 4 * 50)
        x = F.hardtanh(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
def paraMod(wb, model):
    """Load the flat parameter vector ``wb`` into ``model`` in place.

    ``wb`` must contain exactly as many elements as the model has
    parameters, laid out in ``model.named_parameters()`` order (the same
    layout ``getParameter`` produces).

    Fixes vs. original: moves data to each parameter's own device
    instead of a hard-coded ``.cuda()`` (so CPU-only machines work),
    uses ``numel()`` instead of ``np.prod``, and validates the size.
    """
    start = 0
    for name, para in model.named_parameters():
        end = start + para.numel()
        # Overwrite the parameter storage directly; `para` is a live view
        # into the model, so this mutates the model in place.
        para.data = wb[start:end].reshape(para.shape).to(para.device)
        start = end
    if start != wb.numel():
        raise ValueError(
            "parameter vector has {} elements, model needs {}".format(wb.numel(), start)
        )
def train(model, data, target, optimizer, p1, epoch):
    """Run one BNN training step on a single batch.

    On entry the model is expected to hold binarized weights: the
    forward/backward pass computes gradients w.r.t. those binary
    weights.  The real-valued parameters ``p1`` are then restored into
    the model *before* the optimizer step, so the update is applied to
    the real weights using the binary-weight gradients.

    Returns the updated flat real-valued parameter vector.
    """
    model.train()
    optimizer.zero_grad()
    loss = F.nll_loss(model(data), target)
    loss.backward()
    # Swap the real-valued weights back in before stepping (see docstring).
    paraMod(p1, model)
    optimizer.step()
    print('epoch {} Train Loss: {:.6f}'.format(epoch, loss.item()))
    return getParameter(model)
def getParameter(model):
    """Return all model parameters flattened into one CPU tensor.

    Parameters are concatenated in ``model.parameters()`` order, which
    matches the layout ``paraMod`` expects when writing them back.

    Fix vs. original: concatenate the tensors directly with
    ``torch.cat`` instead of round-tripping through a Python list of
    numpy scalars, which was slow and dtype-fragile.
    """
    return torch.cat([p.detach().cpu().reshape(-1) for p in model.parameters()])
if __name__ == '__main__':
    batch_size = 64
    # Fix: fall back to CPU when CUDA is unavailable instead of crashing
    # on `torch.device("cuda")`.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # pin_memory only helps (and only makes sense) when copying to a GPU.
    kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./Dataset', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True, **kwargs)

    # Take a single fixed batch of 64 images so successive epochs train on
    # the same data and their losses are directly comparable.  To train on
    # the full dataset, move the epoch loop over the loader instead.
    data, target = next(iter(train_loader))
    data, target = data.to(device), target.to(device)
    data = torch.sign(data)  # binarize the normalized inputs to {-1, 0, +1}

    model = Net().to(device)
    p1 = getParameter(model)  # real-valued parameters, flattened
    wb = torch.sign(p1)       # binarized copy of the parameters
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.1)
    for epoch in range(1, 101):
        paraMod(wb, model)          # forward/backward runs on binary weights
        p1 = train(model, data, target, optimizer, p1, epoch)
        wb = torch.sign(p1)         # re-binarize the updated real weights