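# Circle Loss reference material:
# https://blog.csdn.net/jacke121/article/details/106637046
# During training, the labels are converted into pairwise similarities.
# At prediction time, the similarity between two labels is given, and the
# model then predicts the similarity between those two labels.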
import os
import torch
from torch import nn, Tensor
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm
from circle_loss import convert_label_to_similarity, CircleLoss
def get_loader(is_train: bool, batch_size: int) -> DataLoader:
    """Build a DataLoader over the MNIST train or test split.

    Args:
        is_train: Passed to ``MNIST`` as ``train`` to select the split; the
            same flag is also used as the loader's ``shuffle`` setting.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` wrapping the MNIST dataset with ``ToTensor`` applied
        as the transform.
    """
    # Dataset files live under ./data; download=True lets torchvision fetch
    # them when needed.
    mnist = MNIST(
        root="./data",
        train=is_train,
        transform=ToTensor(),
        download=True,
    )
    return DataLoader(dataset=mnist, batch_size=batch_size, shuffle=is_train)
class Model(nn.Module):
def __init__(se