PyTorch 实现 Arc Loss(实战)

完整代码见 GitHub 仓库:传送门


一、计算余弦相似度

import torch
import math

def _l2_norm(vec):
    """Return the Euclidean (L2) norm of a 1-D tensor."""
    return torch.sqrt(torch.sum(torch.pow(vec, 2)))


def clamp_cosine(cos_value):
    """Clamp a cosine value into [-1, 1] before it is passed to torch.acos.

    Floating-point rounding can push a computed cosine slightly outside the
    valid range (e.g. 1.0000001), and torch.acos returns NaN for such inputs.
    Clamping guards both the upper and the lower bound.
    """
    return torch.clamp(cos_value, -1.0, 1.0)


# Plain cosine similarity between two vectors
a = torch.tensor([1, 2, 3], dtype=torch.float32)
b = torch.tensor([1, 2, 3], dtype=torch.float32)
cos_alpha = a@b / (_l2_norm(a) * _l2_norm(b))

print(torch.pow(a, 2))  # tensor([1., 4., 9.])
print(torch.sum(torch.pow(a, 2)))  # tensor(14.)
print(torch.sqrt(torch.sum(torch.pow(a, 2))))  # tensor(3.7417)
print(a @ b)  # tensor(14.)
print(cos_alpha)  # tensor(1.0000)  -- similarity value
print(torch.acos(cos_alpha))  # tensor(0.0003)  -- angle in radians
print(math.degrees(torch.acos(cos_alpha)))  # 0.01978234059262607  -- angle in degrees
print("=============")

# Cosine similarity after centering both vectors on the midpoint of their
# combined value range, i.e. (max + min) / 2
a_b = torch.cat((a, b), dim=0)  # tensor([1., 2., 3., 1., 2., 3.])
print(a_b)
min_value = torch.min(a_b)
max_value = torch.max(a_b)
print(min_value)  # tensor(1.)
print(max_value)  # tensor(3.)

mean_value = (max_value + min_value) / 2
print(mean_value)  # tensor(2.)

c = (a - mean_value)  # tensor([-1.,  0.,  1.])
d = (b - mean_value)  # tensor([-1.,  0.,  1.])
print(c)
print(d)

cos_beta = c@d / (_l2_norm(c) * _l2_norm(d))
print(c@d)  # tensor(2.)
print(cos_beta)  # tensor(1.0000)

# Rounding may leave cos_beta marginally above 1, which would make acos NaN;
# clamp it back into the valid domain (this also covers values below -1).
cos_beta = clamp_cosine(cos_beta)
print(cos_beta)  # tensor(1.)
print(torch.acos(torch.tensor(1, dtype=torch.float32)))  # tensor(0.)
print(torch.acos(cos_beta))  # tensor(0.)
print(math.degrees(torch.acos(cos_beta)))  # 0.0

二、定义arc softmax损失函数

import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcNet(nn.Module):
    """Arc-softmax output layer with a learnable class-direction matrix.

    Holds a ``(feature_dim, cls_dim)`` weight matrix ``W``. ``forward``
    returns, for every sample and every class, a margin-penalised softmax
    value (see ``forward`` for the exact formula).

    Args:
        feature_dim: size of each input feature vector.
        cls_dim: number of classes (columns of ``W``).
    """

    def __init__(self, feature_dim=2, cls_dim=10):
        super(ArcNet, self).__init__()
        # Learnable class-direction matrix; training pushes these directions
        # away from the feature vectors of other classes to widen the angle.
        # The weight is placed on CUDA when it is available (the original
        # default) and on CPU otherwise, so construction no longer crashes on
        # machines without a GPU. Module.to()/.cuda() still relocates it.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.W = nn.Parameter(torch.randn(feature_dim, cls_dim).to(device), requires_grad=True)
        # self.W.shape == torch.Size([feature_dim, cls_dim])

    def forward(self, feature, m=1, s=10):
        """Compute the arc-softmax value for every (sample, class) pair.

        Args:
            feature: 2-D tensor of shape ``(batch, feature_dim)``.
            m: additive angular margin, in radians.
            s: scale factor (s = 64 is commonly used for face recognition).

        Returns:
            Tensor of shape ``(batch, cls_dim)``. Every entry lies strictly
            between 0 and 1, and rows do NOT sum to 1, so the resulting
            cross-entropy values look large even near the optimum.
        """
        # L2-normalise each feature row and each weight column so that their
        # product is the cosine of the angle between them.
        x = F.normalize(feature, dim=1)
        w = F.normalize(self.W, dim=0)

        # Dividing by s shrinks the cosine into [-1/s, 1/s], keeping acos away
        # from +/-1 where its gradient blows up.
        # NOTE(review): the margin below is therefore added to acos(cos / s),
        # not to acos(cos); this is kept exactly as in the original.
        cosa = torch.matmul(x, w) / s

        # acos returns radians, not degrees (1 rad = 180 / pi, about 57.3 degrees).
        a = torch.acos(cosa)

        # exp(s * cos(a + m)): margin-penalised term for the class itself.
        exp_margin = torch.exp(s * torch.cos(a + m))
        # exp(s * cosa): plain term; multiplying by s restores the cosine scale.
        exp_plain = torch.exp(s * cosa)

        # For each class j: penalised term of j divided by
        # (sum of plain terms of all classes - plain term of j + penalised term of j).
        arcsoftmax = exp_margin / (torch.sum(exp_plain, dim=1, keepdim=True)
                                   - exp_plain + exp_margin)

        return arcsoftmax


if __name__ == '__main__':
    # Smoke test: run a random batch of 100 two-dimensional features through
    # the layer.
    arc = ArcNet(feature_dim=2, cls_dim=10)

    # Create the input on whatever device the layer's weight lives on instead
    # of hard-coding CUDA.
    feature = torch.randn(100, 2).to(arc.W.device)
    out = arc(feature)
    # print(feature)  # the original feature data

三、搭建网络模型

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from Arcsoftmax import ArcNet

class Net(nn.Module):
    """Convolutional feature extractor followed by an ArcNet output layer.

    Three stages of (Conv-BN-PReLU) x 2 + MaxPool reduce a single-channel
    image to 128 channels of 3x3 (the size comments assume 28x28 inputs, as
    required by the ``128*3*3`` reshape in ``forward``). A linear layer maps
    that to a 2-D feature, which ``ArcNet`` turns into 10 per-class values.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv_layer = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1, 2),  # 28*28
            nn.BatchNorm2d(32),
            nn.PReLU(),
            nn.Conv2d(32, 32, 5, 1, 2),  # 28*28
            nn.BatchNorm2d(32),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),  # 14*14

            nn.Conv2d(32, 64, 5, 1, 2),  # 14*14
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.Conv2d(64, 64, 5, 1, 2),  # 14*14
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),  # 7*7

            nn.Conv2d(64, 128, 5, 1, 2),  # 7*7
            nn.BatchNorm2d(128),
            nn.PReLU(),
            nn.Conv2d(128, 128, 5, 1, 2),  # 7*7
            nn.BatchNorm2d(128),
            nn.PReLU(),
            nn.MaxPool2d(2, 2)  # 3*3

        )
        # 2-D embedding so the features can be plotted directly.
        self.feature = nn.Linear(128*3*3, 2)
        # self.output = nn.Linear(2, 10)
        self.arcsoftmax = ArcNet(2, 10)

    def forward(self, x):
        """Return ``(feature, log_arcsoftmax)`` for a batch of images.

        Args:
            x: image batch; the reshape below requires the convolution stack
                to produce 128 channels of 3x3 per sample.

        Returns:
            A tuple ``(y_feature, y_output)``: the 2-D features of shape
            ``(batch, 2)`` and the log of the ArcNet output of shape
            ``(batch, 10)``.
        """
        y_conv = self.conv_layer(x)
        y_conv = torch.reshape(y_conv, [-1, 128*3*3])
        y_feature = self.feature(y_conv)

        # ArcNet's weight is a parameter of this module, so it is trained
        # together with the rest of the network.
        y_output = torch.log(self.arcsoftmax(y_feature))

        return y_feature, y_output

    def visualize(self, feat, labels, epoch):
        """Save a scatter plot of ``feat`` to ./images/epoch=<epoch>.jpg."""
        self._plot_features(feat, labels, epoch, './images')

    def visualize2(self, feat, labels, epoch):
        """Save a scatter plot of ``feat`` to ./images2/epoch=<epoch>.jpg."""
        self._plot_features(feat, labels, epoch, './images2')

    def _plot_features(self, feat, labels, epoch, out_dir):
        """Plot 2-D features coloured by class label and save the figure.

        Args:
            feat: array indexable as ``feat[mask, column]`` with at least two
                columns.
            labels: array of class ids (0-9) aligned with the rows of ``feat``.
            epoch: integer used in the plot title and the file name.
            out_dir: directory the image is written to; created if missing.
        """
        import os  # local import: this module does not import os at file level

        color = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
                 '#ff00ff', '#990000', '#999900', '#009900', '#009999']
        plt.clf()
        for i in range(10):
            mask = labels == i
            plt.plot(feat[mask, 0], feat[mask, 1], '.', c=color[i])
        plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right')
        plt.title("epoch=%d" % epoch)
        # savefig does not create missing directories; without this the call
        # fails only after a whole epoch of training has already been spent.
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig('%s/epoch=%d.jpg' % (out_dir, epoch))



if __name__ == '__main__':
    # Smoke test: push a random batch of 100 single-channel 28x28 images
    # through the network, on the GPU when one is available and on the CPU
    # otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = Net().to(device)
    a = torch.randn(100, 1, 28, 28).to(device)
    net(a)

四、开始训练数据

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import r2_score

from Net_Model import Net
import os
import numpy as np

if __name__ == '__main__':
    # Train Net on MNIST with the arc-softmax output for 30 epochs. After each
    # epoch the script plots the training features, saves the weights, then
    # evaluates on the test split and plots those features too.
    save_path = "./models/net_arcloss.pth"
    num_epochs = 30

    # Same preprocessing for both splits: to tensor, then (x - 0.5) / 0.5.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, ], std=[0.5, ])
    ])
    train_data = torchvision.datasets.MNIST(root="./MNIST", download=True, train=True,
                                            transform=preprocess)
    test_data = torchvision.datasets.MNIST(root="./MNIST", download=True, train=False,
                                           transform=preprocess)
    train_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=100, num_workers=4)
    test_loader = data.DataLoader(dataset=test_data, shuffle=True, batch_size=100, num_workers=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = Net().to(device)

    if os.path.exists(save_path):
        # map_location lets a checkpoint written on a GPU load on a CPU-only
        # machine (and vice versa).
        net.load_state_dict(torch.load(save_path, map_location=device))
    else:
        print("No Param")

    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # nn.CrossEntropyLoss combines log-softmax and nn.NLLLoss. Net already
    # returns the log of the arc-softmax output, so only NLLLoss is needed.
    loss_fn = nn.NLLLoss()
    # optimizer = torch.optim.Adam(net.parameters())
    # optimizer = torch.optim.SGD(net.parameters(),lr=1e-3, momentum=0.9)
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)
    # optimizer = torch.optim.SGD(net.parameters(),lr=1e-3, momentum=0.9, weight_decay=0.0005)

    for epoch in range(num_epochs):
        # ---- training ----
        net.train()  # BatchNorm uses batch statistics and updates running stats
        feat_loader = []
        label_loader = []
        for i, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)
            feature, output = net(x)
            # feature: (batch, 2), output: (batch, 10)

            loss = loss_fn(output, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Detach before storing: keeping the graph-attached tensor would
            # hold every batch's autograd graph alive until the epoch ends.
            feat_loader.append(feature.detach())
            label_loader.append(y)

            if i % 100 == 0:
                print("epoch:", epoch, "i:", i, "arcsoftmax_loss:", loss.item())

        feat = torch.cat(feat_loader, 0)
        labels = torch.cat(label_loader, 0)
        net.visualize(feat.cpu().numpy(), labels.cpu().numpy(), epoch)

        torch.save(net.state_dict(), save_path)

        # ---- validation ----
        net.eval()  # BatchNorm uses its running statistics and stops updating them
        with torch.no_grad():
            feat_loader2 = []
            label_loader2 = []
            label_list = []
            output_list = []
            for i, (x, y) in enumerate(test_loader):
                x = x.to(device)
                y = y.to(device)
                feature, output = net(x)

                loss = loss_fn(output, y)

                feat_loader2.append(feature)  # kept for plotting
                label_loader2.append(y)

                predicted = torch.argmax(output, 1)  # (batch,)

                label_list.append(y.cpu().numpy().reshape(-1))  # kept for scoring
                output_list.append(predicted.cpu().numpy().reshape(-1))

                if i % 600 == 0:
                    print("epoch:", epoch, "i:", i, "validate_loss:", loss.item())

            feat2 = torch.cat(feat_loader2, 0)
            labels2 = torch.cat(label_loader2, 0)
            net.visualize2(feat2.cpu().numpy(), labels2.cpu().numpy(), epoch)

            # Flatten the per-batch arrays into one vector each. Passing the
            # raw list of arrays makes r2_score treat every batch position as
            # a separate output of a multi-output regression.
            y_true = np.concatenate(label_list)
            y_pred = np.concatenate(output_list)

            r2 = r2_score(y_true, y_pred)
            print("验证集第{}轮, r2_score评估分类精度为:{}".format(epoch, r2))

            # R^2 is a regression score; also report the actual fraction of
            # correctly classified samples.
            accuracy = float(np.mean(y_true == y_pred))
            print("epoch:", epoch, "validate_accuracy:", accuracy)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值