本来有个 idea:想用注意力机制去对抗后门攻击(像素点触发器那种)。没想到写完通道注意力之后,防御效果并没有提升,所以基于空间注意力的版本也不打算测试了,就把通道注意力的代码贴在这里。因为网上没有类似代码,这份代码全是我自己写的——一个中秋假期的成果就这样泡汤了。
第一份是不带注意力机制的版本:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torchvision import datasets, transforms,utils
# Convert MNIST images to tensors and normalize them into [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Download (if needed) and wrap the MNIST train/test splits.
train_data = datasets.MNIST(root="./data/", train=True,
                            transform=transform, download=True)
test_data = datasets.MNIST(root="./data/", train=False,
                           transform=transform)

# Mini-batch iterators: batch size 64, shuffled, single worker process.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=64, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=64, shuffle=True, num_workers=0)
class CNN(nn.Module):
    """Two-conv-layer CNN for 28x28 single-channel MNIST digits.

    Architecture: conv3x3(1->32) -> ReLU -> maxpool2 ->
    conv3x3(32->64) -> ReLU -> maxpool2 -> flatten (64*7*7) ->
    fc 1024 -> fc 512 -> fc 10.  Returns raw logits; the softmax is
    applied by CrossEntropyLoss at the call site.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # halves spatial dims; reused after each conv
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) image tensor to (batch, 10) logits."""
        x = self.pool(F.relu(self.conv1(x)))  # -> (B, 32, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # -> (B, 64, 7, 7)
        # Flatten keeping the batch dim explicit: unlike view(-1, 64*7*7),
        # this raises immediately on an unexpected input size instead of
        # silently regrouping samples across the batch dimension.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Model, loss, and optimizer. CrossEntropyLoss expects raw logits.
net = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Metric histories collected during training / evaluation.
train_accs = []
train_loss = []
test_accs = []

# Prefer the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = net.to(device)
# Train for 5 epochs. Every 100 mini-batches, print the averaged running
# loss and snapshot the current batch's loss/accuracy into the histories.
# NOTE(review): the paste lost indentation; the metric snapshot is placed
# inside the i % 100 == 99 branch, matching the variable reuse — confirm.
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d,%5d] loss :%.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
            train_loss.append(loss.item())
            correct = 0
            total = 0
            _, predicted = torch.max(outputs.data, 1)
            total = labels.size(0)  # number of labels in this batch
            correct = (predicted == labels).sum().item()  # correct predictions
            train_accs.append(100 * correct / total)
print('Finished Training')

# Persist the trained weights so they can be reloaded for evaluation.
PATH = './mnist_net.pth'
torch.save(net.state_dict(), PATH)
# Reload the saved weights into a fresh (CPU) model and measure overall
# accuracy on the test set.
test_net = CNN()
test_net.load_state_dict(torch.load(PATH))

correct = 0
total = 0
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for data in test_loader:
        images, labels = data
        outputs = test_net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()  # correct predictions

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
# Per-class accuracy over the test set.
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = test_net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)
        # BUG FIX: the original looped `for i in range(10)`, counting only
        # the first 10 samples of each 64-sample batch (and it would
        # IndexError on a final batch smaller than 10). Count every sample.
        for i in range(labels.size(0)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %d : %2d %%' % (
        i, 100 * class_correct[i] / class_total[i]))
print("all done")
第二份是带通道注意力机制(SE 模块)的版本:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torchvision import datasets, transforms,utils
# Tensor conversion plus normalization to the [-1, 1] range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST splits, downloaded to ./data/ on first run.
train_data = datasets.MNIST(root="./data/", train=True,
                            transform=transform, download=True)
test_data = datasets.MNIST(root="./data/", train=False,
                           transform=transform)

# Shuffled mini-batch loaders, batch size 64, no worker subprocesses.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=64, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=64, shuffle=True, num_workers=0)
class CNN(nn.Module):
    """MNIST CNN with a channel-attention (squeeze-and-excitation) block.

    Same backbone as the plain CNN, but after the second conv stage the
    64 channel maps are re-weighted: each channel is global-average-pooled
    to a scalar ("squeeze"), passed through a 64->4->64 bottleneck MLP
    ending in a sigmoid ("excitation"), and the resulting per-channel
    gate scales its feature map. Returns raw logits.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # halves spatial dims; reused after each conv
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 10)
        # Squeeze: collapse each channel's 7x7 map to a single value.
        self.avg_all_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck 64 -> 4 -> 64 with a sigmoid gate.
        self.SE = nn.Sequential(
            nn.Linear(64, 4, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(4, 64, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map a (batch, 1, 28, 28) image tensor to (batch, 10) logits."""
        x = self.pool(F.relu(self.conv1(x)))  # -> (B, 32, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # -> (B, 64, 7, 7)
        b, c, _, _ = x.size()
        y = self.avg_all_pool(x).view(b, c)   # squeeze  -> (B, 64)
        y = self.SE(y).view(b, c, 1, 1)       # excitation -> per-channel gates
        x = x * y.expand_as(x)                # re-weight the channel maps
        # Flatten keeping the batch dim explicit: unlike view(-1, 64*7*7),
        # this raises immediately on an unexpected input size instead of
        # silently regrouping samples across the batch dimension.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Instantiate the attention model and its training machinery.
net = CNN()
criterion = nn.CrossEntropyLoss()  # applied to raw logits
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Histories for loss/accuracy samples taken during the run.
train_accs = []
train_loss = []
test_accs = []

# Use CUDA device 0 if present; otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = net.to(device)
# Five training epochs; print a 100-batch averaged loss and snapshot the
# current batch's loss/accuracy at every print.
# NOTE(review): the paste lost indentation; the metric snapshot is placed
# inside the batch_idx % 100 == 99 branch, matching the variable reuse.
for epoch in range(5):
    running_loss = 0.0
    for batch_idx, batch in enumerate(train_loader, 0):
        inputs = batch[0].to(device)
        labels = batch[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 100 == 99:
            print('[%d,%5d] loss :%.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / 100))
            running_loss = 0.0
            train_loss.append(loss.item())
            correct = 0
            total = 0
            _, predicted = torch.max(outputs.data, 1)
            total = labels.size(0)
            correct = (predicted == labels).sum().item()  # correct predictions
            train_accs.append(100 * correct / total)
print('Finished Training')

# Save the trained parameters for the evaluation phase below.
PATH = './mnist_net.pth'
torch.save(net.state_dict(), PATH)
# Rebuild the model on the CPU, load the saved weights, and compute the
# overall test-set accuracy.
test_net = CNN()
test_net.load_state_dict(torch.load(PATH))

correct = 0
total = 0
with torch.no_grad():  # evaluation only — no gradient tracking needed
    for data in test_loader:
        images, labels = data
        outputs = test_net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
# Per-class accuracy over the test set.
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = test_net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)
        # BUG FIX: the original looped `for i in range(10)`, counting only
        # the first 10 samples of each 64-sample batch (and it would
        # IndexError on a final batch smaller than 10). Count every sample.
        for i in range(labels.size(0)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %d : %2d %%' % (
        i, 100 * class_correct[i] / class_total[i]))
print("all done")