dataloaders中使用items()以及tqdm,加载预训练参数

1.查看dataloader类型

train_dataset.class_to_idx 从字面上来看,就是类别对应的索引

os.path.abspath(os.path.join(os.getcwd(), "../"))从当前的目录,跳到上一层目录(绝对路径)

data_root = os.path.abspath(os.path.join(os.getcwd(), "../"))  # get data root path
image_path = os.path.join(data_root, "flower_data")  # flower data set path
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

2. items()和item()还是有区别的

2.1 items()

items() 把字典 拆成(key,value)的元组形式

person={'name':'lizhong','age':'26','city':'BeiJing'}
for i in person.items():
    print(i)


结果
('name', 'lizhong')
('age', '26')
('city', 'BeiJing')


cla_dict = dict((val, key) for key, val in flower_list.items())

这句话是构成一个新的字典
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

并保存到json文件中,缩进为4   。 仍然为一个字典

在这里插入图片描述

2.2 item()

这个是把tensor中的标量取出来,就是单个数字

tensor中的元素个数为1 , 才使用 item()
多个元素时,要使用 tolist()

loss.backward()
optimizer.step()

# print statistics
running_loss += loss.item()

3.加载参数列表

numel()返回的是tensor中元素的个数(number of element)

pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}

因为 pre_weights是一个字典,存储了各个网络层的名字和参数的值
通过items()把key,value以元组的方式取出来

net.state_dict() 返回的是一个字典,k为键值,如果这个key对应的value值 的个数和 v中元素个数一致
就把这些字典保存下来,用于预训练网络的初始化
(因为输出层classifier肯定不同的!)

net = MobileNetV2(num_classes=5)

# load pretrain weights
# download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
model_weight_path = "./mobilenet_v2.pth"
assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)
# save :GPU  load:CPU
pre_weights = torch.load(model_weight_path, map_location=device)

# delete classifier weights
# 那个网络中已经有了classfier的初始化
# imagenet分类器的初始化是用不了的
pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

# freeze features weights ,因为命名了features的
for param in net.features.parameters():
   param.requires_grad = False

冻结所有features层的梯度,因为这里是网络的属性,自己命名的。

for param in net.features.parameters():
   param.requires_grad = False

如上,这里net.features.parameters()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

下面还有返回参数。。。
一般strict参数都是true的。

missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

3.1 训练

net.to(device)

# define loss function
loss_function = nn.CrossEntropyLoss()

# construct an optimizer
params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(params, lr=0.0001)

4.tqdm

Tqdm 是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。

在这里插入图片描述
截图来源

train_bar = tqdm(train_loader)
for step, data in enumerate(train_bar):
    images, labels = data
    optimizer.zero_grad()
    logits = net(images.to(device))
    loss = loss_function(logits, labels.to(device))
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.item()

    train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                             epochs,
                                                             loss)

结果:
在这里插入图片描述
对于条形图的颜色,可以通过如下colour参数进行控制

 net.eval()
 acc = 0.0  # accumulate accurate number / epoch
 with torch.no_grad():
     val_bar = tqdm(validate_loader, colour='green')
     for val_data in val_bar:
         val_images, val_labels = val_data
         outputs = net(val_images.to(device))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值