import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torchvision import datasets,transforms
from torch.autograd import Variable
BATCH_SIZE = 64
# Directory shared by both MNIST splits.
DATA_ROOT = '~/data/'

# MNIST images converted to float tensors by ToTensor. Only the training split
# requests a download; the test split is constructed without download=True.
train_dataset = datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root=DATA_ROOT,
    train=False,
    transform=transforms.ToTensor(),
)

# Training batches are shuffled every epoch; evaluation order is fixed.
train_loader = Data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = Data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
class Net(nn.Module):
    """Two-conv MNIST classifier that also exposes inputs for a CAM-style map.

    ``forward(x)`` returns a 4-tuple:

    * log-probabilities over 10 classes (``log_softmax`` along dim 1);
    * the max-pooled ``conv2`` feature map, i.e. the tensor that is flattened
      into the classifier head;
    * the ``conv1`` weight parameter (the same object as ``self.conv1.weight``);
    * the argmax indices of that final 2x2 pooling, so a caller can undo it
      with ``MaxUnpool2d``.

    The hard-coded flatten to 320 features (20 channels x 4 x 4) assumes
    1x28x28 inputs: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.
    """

    def __init__(self):
        super().__init__()
        # Registration order and attribute names are kept as-is: pickled
        # checkpoints and state_dict keys depend on them.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        # Returns (pooled, indices) so the pooling can be inverted later.
        self.Max_pool = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, x):
        # Stage 1: conv -> 2x2 max-pool -> ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Stage 2: conv -> 2x2 max-pool, keeping the argmax indices.
        # NOTE(review): there is no ReLU after this pooling, unlike stage 1.
        # Adding one would change the outputs of any model already trained
        # with this forward — confirm whether the omission is intended.
        pooled, pool_indices = self.Max_pool(self.conv2(stage1))
        hidden = F.relu(self.fc1(pooled.view(-1, 320)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1), pooled, self.conv1.weight, pool_indices
# Set to True to train from scratch and overwrite ./model.pth. The
# visualisation code further down loads that file.
TRAIN = False


def train(model, loader, num_epochs=8, lr=0.001):
    """Train ``model`` with plain SGD and report per-epoch averages.

    Args:
        model: module whose forward returns ``(log_probs, ...)``, as ``Net`` does.
        loader: iterable of ``(data, label)`` batches.
        num_epochs: number of passes over ``loader``.
        lr: SGD learning rate.

    Returns:
        ``(losses, accs)``: for each epoch, the mean per-batch loss and the
        mean per-batch accuracy.
    """
    # The model already outputs log-probabilities, so NLLLoss is the matching
    # criterion. CrossEntropyLoss would re-apply log_softmax, which is
    # redundant because log_softmax is idempotent.
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    losses, accs = [], []
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        epoch_acc = 0.0
        n_batches = 0
        for data, label in loader:
            out, _, _, _ = model(data)
            loss_value = criterion(out, label)
            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()
            epoch_loss += float(loss_value)
            pred = out.argmax(1)
            epoch_acc += int((pred == label).sum()) / data.shape[0]
            n_batches += 1
        # Guard against an empty loader instead of dividing by zero.
        n_batches = max(n_batches, 1)
        losses.append(epoch_loss / n_batches)
        accs.append(epoch_acc / n_batches)
        print(f"epoch: {epoch}")
        print(f"loss: {losses[-1]}")
        print(f"accuracy: {accs[-1]}")
    return losses, accs


model = Net()
if TRAIN:
    train(model, train_loader)
    # Saves the whole pickled module, which is the form the loader below expects.
    torch.save(model, './model.pth')
def _load_trusted_model(path):
    """Unpickle a whole ``nn.Module`` saved with ``torch.save(model, path)``.

    SECURITY: this unpickles the file, which can execute arbitrary code.
    Only load checkpoints you produced yourself.
    """
    try:
        # torch>=2.6 defaults to weights_only=True, which refuses to load a
        # full pickled module, so opt out explicitly.
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # Older torch (<1.13) has no weights_only parameter.
        return torch.load(path, map_location='cpu')


model = _load_trusted_model('./model.pth')
model.eval()

one_image = train_dataset[0][0].unsqueeze(0)
# Pure inference: no autograd graph is needed for the visualisation.
with torch.no_grad():
    _, cam_one, cam_weight, indices = model(one_image)
    # Undo the model's final 2x2 pooling, scattering each pooled value back
    # to its argmax position (other positions become zero).
    cam_one = F.max_unpool2d(cam_one, indices, kernel_size=2, stride=2)
    print(cam_one.shape)
    # Sum over the channel dimension to get a single 2-D map per image.
    matrix = torch.sum(cam_one, dim=1)
    print(matrix.shape)
    # Add a leading dim so interpolate sees (N, C, H, W), then resize to 28x28.
    matrix = F.interpolate(matrix.unsqueeze(0), size=(28, 28), mode='nearest')
    # Min-max normalise to [0, 1]; a constant map would otherwise divide by zero.
    span = matrix.max() - matrix.min()
    if span > 0:
        matrix = (matrix - matrix.min()) / span
    else:
        matrix = torch.zeros_like(matrix)
    print(matrix.shape)

unloader = transforms.ToPILImage()
image = unloader(matrix.squeeze(0))
image.save('example.jpg')
image_raw = unloader(train_dataset[0][0])
# Previously `image` was saved here, so raw.jpg duplicated the heatmap.
image_raw.save('raw.jpg')