特征层面上的生成对抗,market特征 迁移 duke

import torch
import torch.nn as nn
from torchvision import models
from torchvision import datasets,transforms
#import matplotlib.pyplot as plt



class feature_extract_net(nn.Module):
    def __init__(self):
        super(feature_extract_net,self).__init__()
        self.model = models.resnet50(pretrained=False)
        self.model.load_state_dict(torch.load('./model/resnet50.pth'))
        self.model.avgpool = nn.AdaptiveAvgPool2d((1,1))

    def forward(self,x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = x.view(x.size(0), x.size(1))

        return x

class generator(nn.Module):
    def __init__(self):
        super(generator,self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(2048,2048),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(2048),

            nn.Linear(2048,2048),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(2048),

            nn.ReLU(),
        )

    def forward(self,x):
        x = self.gen(x)
        return x

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()

        self.dis = nn.Sequential(
            nn.Linear(2048,1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),

            nn.Linear(1024,256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),

            # nn.Linear(256,128),
            # nn.LeakyReLU(0.2),
            # nn.BatchNorm1d(128),

            nn.Linear(256,1),
            nn.Sigmoid()
        )

    def forward(self,x):
        x=self.dis(x)
        return x

#loss_f = torch.nn.BCEWithLogitsLoss()
# loss_f = torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()

transform_list = [
    transforms.Resize((256, 128), interpolation=3),
    transforms.Pad(10),
    transforms.RandomCrop((256, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

trans = transforms.Compose(transform_list)

image_dataset1 = datasets.ImageFolder('./Market/pytorch/train', trans)
image_dataset2 = datasets.ImageFolder('./duke/pytorch/train', trans)
try_dataset1 = datasets.ImageFolder('./try_data1', trans)
try_dataset2 = datasets.ImageFolder('./try_data2', trans)


dataloader1 = torch.utils.data.DataLoader(image_dataset1, batch_size=16, shuffle=True)
dataloader2 = torch.utils.data.DataLoader(image_dataset2, batch_size=16, shuffle=True)
try_dataloader1 = torch.utils.data.DataLoader(try_dataset1, batch_size=2)
try_dataloader2 = torch.utils.data.DataLoader(try_dataset2, batch_size=2)

net = feature_extract_net()
net.cuda()
net.eval()

gen = generator()
gen.cuda()


dis = discriminator()
dis.cuda()

optimizer_dis = torch.optim.Adam(dis.parameters(),lr=0.0002,betas=(0.5,0.999))
optimizer_gen = torch.optim.Adam(gen.parameters(),lr=0.0002,betas=(0.5,0.999))

batches_track = []
g_loss_list = []
d_loss_list = []
batches_done = 0

epoch_n=1
for epoch in range(epoch_n):
    for  data_group in zip(dataloader1,dataloader2):
        data1 = data_group[0]
        data2 = data_group[1]

        inputs1,labels1 = data1
        inputs2,labels2 = data2

        inputs1 = inputs1.cuda()
        inputs2 = inputs2.cuda()

        batch_s1 = inputs1.size()[0]
        batch_s2 = inputs2.size()[0]
        if batch_s1 != batch_s2:
            break

        inputs1 = net(inputs1)
        inputs2 = net(inputs2)

        valid = torch.Tensor(inputs1.shape[0],1).fill_(1.0)
        valid = valid.cuda()
        fake = torch.Tensor(inputs1.shape[0],1).fill_(0.0)
        fake = fake.cuda()

        fea_gen = gen(inputs1)


        gen_pre = dis(fea_gen)
        real_pre = dis(inputs2)

        g_loss = adversarial_loss(gen_pre,valid)

        optimizer_gen.zero_grad()
        g_loss.backward()
        optimizer_gen.step()

        optimizer_dis.zero_grad()
        real_loss = adversarial_loss( dis(inputs2) ,valid)
        fake_loss = adversarial_loss(dis(fea_gen.detach()),fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_dis.step()

        batches_done = batches_done + 1
        batches_track.append(batches_done)
        g_loss_list.append(g_loss.item())
        d_loss_list.append(d_loss.item())

        print('Epoch {}/{},batches_done {}'.format(epoch+1,epoch_n,batches_done))
        print("discriminator Loss:{:.4f}".format(d_loss.item()))
        print('generator loss:{:.4f}..'.format(g_loss.item()))

torch.save(gen.cpu().state_dict(), 'try_net.pth')

# fig,ax = plt.subplots()
# ax.plot(batches_track,d_loss_list,label='discriminator loss')
# ax.plot(batches_track,g_loss_list,label='generator loss')
# ax.legend()
# plt.show()





 

    正常的生成对抗网络,都是生成图片层级的内容,这个部分的代码是将Market作为source,duke作为target,在特征的层面上进行生成对抗。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值