# 0019-pytorch-celeba人脸属性分类模板代码 (PyTorch CelebA face-attribute classification template code)

# -*- coding: utf-8 -*-
"""
Created on 2019/10/14 0:11
@author: Johnson
Email:593956670@qq.com
@software: PyCharm
"""



###########################参考链接###########################
# https://blog.csdn.net/think_three/article/details/88930203

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torchvision.transforms as transforms
import torch.utils.data as Data
from PIL import Image
import numpy as np
import os

# Dataset locations. The defaults are hard-coded absolute paths on a `W:` drive;
# set CELEBA_IMG_ROOT / CELEBA_TRAIN_TXT to point at another copy of CelebA
# without editing the source.
img_root = os.environ.get('CELEBA_IMG_ROOT', 'W:/pic_data/face/CelebA/Img/img_align_celeba')
train_txt = os.environ.get('CELEBA_TRAIN_TXT', 'W:/pic_data/face/CelebA/Img/train10000.txt')
batch_size = 2


def default_loader(path):
    """Open the image at *path* and return it converted to RGB.

    Best-effort: if the file cannot be opened or decoded, a message is printed
    and ``None`` is returned, so callers must check for ``None``.
    """
    try:
        # The context manager releases the file handle; convert() decodes the
        # pixels and returns a new image that does not depend on the file.
        with Image.open(path) as img:
            return img.convert('RGB')
    except Exception:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
        # SystemExit are not swallowed by the loader.
        print("Can not open {0}".format(path))
        return None


class myDataset(Data.Dataset):
    """CelebA-style attribute dataset read from a whitespace-separated list file.

    Each usable line of *img_txt* has exactly 41 tokens: an image filename
    followed by 40 attribute values written as ``-1`` or ``1``. Values are
    mapped to class indices (``-1`` -> 0, ``1`` -> 1). Lines with a different
    token count, or containing any other attribute value, are skipped so that
    every stored label vector has exactly 40 entries.

    Args:
        img_dir: directory the filenames are joined onto.
        img_txt: path of the attribute list file.
        transform: optional callable applied to each loaded image.
        loader: callable mapping a path to an image; ``default_loader``
            returns ``None`` on failure.
    """

    def __init__(self, img_dir, img_txt, transform=None, loader=default_loader):
        value_to_label = {'-1': 0, '1': 1}
        img_list = []
        img_labels = []

        with open(img_txt, 'r') as fp:
            for line in fp:
                fields = line.split()
                if len(fields) != 41:
                    continue
                try:
                    labels = [value_to_label[value] for value in fields[1:]]
                except KeyError:
                    # Unknown attribute value: skip the whole line instead of
                    # storing a short label vector that would break batching.
                    continue
                img_list.append(fields[0])
                img_labels.append(labels)

        self.imgs = [os.path.join(img_dir, file) for file in img_list]
        self.labels = img_labels
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``(img, label)`` where *label* is an int64 tensor of 40 class indices.

        Raises:
            RuntimeError: if the image cannot be loaded or transformed. Before this
                change, such failures produced an untransformed or ``None`` image
                that only failed later, inside batch collation.
        """
        img_path = self.imgs[index]
        label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
        img = self.loader(img_path)
        if img is None:
            raise RuntimeError('Cannot open image: {}'.format(img_path))
        if self.transform is not None:
            try:
                img = self.transform(img)
            except Exception as err:
                raise RuntimeError('Cannot transform image: {}'.format(img_path)) from err
        return img, label


# Preprocessing pipeline: resize so the shorter edge is 40 px, keep the central
# 32x32 patch, randomly mirror horizontally, convert to a tensor, then normalize
# each RGB channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(40),
        transforms.CenterCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Training split and its shuffled mini-batch loader.
train_dataset = myDataset(img_dir=img_root, img_txt=train_txt, transform=transform)
train_dataloader = Data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)


# print(len(train_dataset))
# print(len(train_dataloader))
def make_conv():
    """Build one per-attribute convolutional trunk.

    Three identical stages of 3x3 convolution (stride 1, padding 1), ReLU and
    2x2 max-pooling, widening the channels 3 -> 16 -> 32 -> 64. Every pool
    halves the spatial size, so a 32x32 input gives a 64x4x4 feature map.
    """
    stages = []
    for in_ch, out_ch in ((3, 16), (16, 32), (32, 64)):
        stages += [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
    return nn.Sequential(*stages)


def make_fc():
    """Build one per-attribute fully connected head.

    Maps a flattened 64x4x4 feature map through hidden sizes 128 and 64 (ReLU
    after each) to 2 output logits, with dropout (p=0.5) before the last layer.
    """
    in_features = 64 * 4 * 4
    layers = [
        nn.Linear(in_features, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(64, 2),
    ]
    return nn.Sequential(*layers)


class face_attr(nn.Module):
    """Multi-head face-attribute classifier.

    Each attribute gets its own convolutional trunk (``make_conv``) and fully
    connected head (``make_fc``). Neither is shared between attributes, and each
    head outputs 2 logits. Submodules are registered as ``attr{i}_layer1`` /
    ``attr{i}_layer2`` in the same interleaved order as the original hand-written
    version, so ``state_dict`` keys and parameter order are unchanged.

    Args:
        num_attrs: number of attributes (independent trunk/head pairs); 40 by default.
    """

    def __init__(self, num_attrs=40):
        super(face_attr, self).__init__()
        self.num_attrs = num_attrs
        for i in range(num_attrs):
            # Assigning an nn.Module via setattr registers it as a named submodule.
            setattr(self, 'attr{}_layer1'.format(i), make_conv())
            setattr(self, 'attr{}_layer2'.format(i), make_fc())

    def forward(self, x):
        """Run every attribute branch on *x* and return the list of per-attribute logits.

        Each list entry has shape (batch, 2). The flattened trunk output must
        match ``make_fc``'s 64*4*4 inputs; presumably x is (batch, 3, 32, 32),
        as produced by the module-level transform (TODO confirm for other inputs).
        """
        # Models pickled from the original class have no ``num_attrs`` attribute;
        # fall back to its fixed count of 40 branches.
        num_attrs = getattr(self, 'num_attrs', 40)
        out_list = []
        for i in range(num_attrs):
            trunk = getattr(self, 'attr{}_layer1'.format(i))
            head = getattr(self, 'attr{}_layer2'.format(i))
            feat = trunk(x)
            out_list.append(head(feat.view(feat.size(0), -1)))
        return out_list


module = face_attr()
# print(module)


optimizer = optim.Adam(module.parameters(), lr=0.001, weight_decay=1e-8)

# One CrossEntropyLoss per attribute head (each head predicts 2 classes).
loss_list = [nn.CrossEntropyLoss() for _ in range(40)]
for Epoch in range(50):
    all_correct_num = 0
    for ii, (img, label) in enumerate(train_dataloader):
        # Tensors are used directly: ``Variable`` wrapping is a no-op since PyTorch 0.4.
        output = module(img)
        optimizer.zero_grad()
        # Summing the per-attribute losses and calling backward() once gives the
        # same gradients as one backward() per attribute, in a single pass.
        loss = 0
        for i, loss_func in enumerate(loss_list):
            target = label[:, i]
            loss = loss + loss_func(output[i], target)
            _, predict = torch.max(output[i], 1)
            # .item() replaces the pre-0.4 ``.data[0]`` idiom, which raises an
            # IndexError on 0-dim tensors in current PyTorch.
            all_correct_num += (predict == target).sum().item()
        loss.backward()
        optimizer.step()

    Accuracy = all_correct_num * 1.0 / (len(train_dataset) * float(len(loss_list)))
    print('Epoch ={0},all_correct_num={1},Accuracy={2}'.format(Epoch, all_correct_num, Accuracy))

    # Checkpoint the whole module after every epoch (overwrites the same file).
    torch.save(module, 'W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')
# Disabled evaluation code: everything below is a bare string literal and never runs.
# NOTE(review): before re-enabling, replace ``correct_num.data[0]`` with
# ``correct_num.item()`` (indexing a 0-dim tensor raises an IndexError in
# PyTorch >= 0.4). Consider running the loop under ``torch.no_grad()``.
'''
test_txt = 'W:/pic_data/face/CelebA/Img/test1000.txt'
test_dataset = myDataset(img_dir=img_root,img_txt = test_txt,transform= transform)
test_dataloader = Data.DataLoader(test_dataset,batch_size = batch_size,shuffle=True)
module = torch.load('W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')
module.eval()
all_correct_num = 0
for ii,(img,label) in enumerate(test_dataloader):

    img = Variable(img)
    label = Variable(label)
    output = module(img)
    for i in range(40):
        _,predict = torch.max(output[i],1)
        correct_num = sum(predict==label[:,i])
        all_correct_num += correct_num.data[0]          
Accuracy =  all_correct_num *1.0/(len(test_dataset)*40.0)
print('all_correct_num={0},Accuracy={1}'.format(all_correct_num,Accuracy))
'''
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值