pytorch错误记录(陆续更新)

这里之前想记录一些错误,结果好多错误改好后直接改下一个了,谁还记录啊,以后这里方法代码tricks和错误合集,我比较懒,不解释,尽量粘贴完整代码

tricks

1
一般将原始文件读取进行transform处理只能在cpu上,任务需要,我提取260万图片需要一天以上,等不及,所以找到了以下方法。具体参考代码和主要参考1主要参考2


import os
import torch
from torchvision.models import resnet18
import torch.nn as nn
from PIL import Image
from torchvision.transforms.functional import to_tensor, normalize
import warnings
import numpy as np
import torch
from imageio import imread
from torchvision import transforms
import tqdm
import torchvision.transforms as T
from torchvision.io import read_image

warnings.filterwarnings("ignore")
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])


transforms = torch.nn.Sequential(
    T.Resize((224, 224)),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
def extract_poster_features(videos, svae_path):
    # Loading the model
    # from models.PosterV2_7cls import pyramid_trans_expr2

    # model = pyramid_trans_expr2(img_size=224, num_classes=7).cuda()
    # print(model)

    # state_dict_path = 'checkpoint/affectnet-7-model.pth'
    # state_dict_path = "checkpoint/raf-db-model_best.pth"

    from models.fusion_model import Fusion
    model = Fusion(num_classes=8)
    state_dict_path = "/modedel.pth"

    print(f'Loading the model from {state_dict_path}.')
    state_dict = torch.load(state_dict_path, map_location='cpu')
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.cuda()

    for v in tqdm.tqdm(videos, desc="extract Poster V2 features"):
        name = os.path.basename(v)
        input = []
        features = []
        path = '/home/xuesongzhang/ABAWdataset/af-feature/' + name + '.csv'
        # print(path)
        images = sorted(os.listdir(v))  # 一个文件夹的所有文件
        if len(images) < 1:
            print(f"no images in {v}")
            continue
        elif len(images) > 5000:
            print(f"too many images in {v}")
            continue
        elif os.path.exists(path):
            continue
        # 读取images中的所有图片,并且resize到224*224, 然后转换成tensor, 最后进行归一化
        time_tamp = []
        print(name)
        for i in images:
            image = read_image(str(os.path.join(v, i))).to(device)
            image = transforms(image.float())
            image = image.unsqueeze(dim=0)
            time_tamp.append(i.split('.')[0])

            with torch.no_grad():
                tt = model(image)
                features.append(tt)
        features = torch.stack(features)
        features = features.view(features.size(0), -1)
        print(features.shape)
        features = features.cpu().detach().numpy()

        torch.cuda.empty_cache()

        # 将feaures保存到csv文件中,第一列为name,第二列为time_tamp 后面的列为features,并添加表头
        data = np.hstack((np.array(
            [f'[{name}]'] * features.shape[0]).reshape(-1, 1), np.array(time_tamp).reshape(-1, 1), features))
        data = np.vstack((np.array(
            ["name", "time_stamp"] + [f"feature_{i}" for i in range(features.shape[1])]), data))
        np.savetxt(os.path.join(
            svae_path, f"{name}.csv"), data, delimiter=",", fmt="%s")


if __name__ == "__main__":
    root = "/home/xmages"
    videos = os.listdir(root)
    videos = sorted(videos)
    videos = [os.path.join(root, v) for v in videos]

    # svae_path = "/data/chenyin/project/dataset/Hume/features/ResNet18-AffectNet-HQ"
    svae_path = "/"
    if not os.path.exists(svae_path):
        os.makedirs(svae_path)
    extract_poster_features(videos, svae_path)


2

查找包含某个字符的文件名,(含所有子文件)

import os

file = r'/home/xuesongzhang/ABAWdataset/cropped-align-images'

for root, dirs, files in os.walk(file):
    for file in files:
        path = os.path.join(root, file)
        re = "DS"
        if re in path:
            print(path)

报错

untimeError: size mismatch, m1: [18432 x 2048], m2: [310 x 510] at /tmp/pip-req-build-nP0IoL/aten/src/THC/generic/THCTensorMathBlas.cu:290
错误是因为大小不匹配,那怎样才算匹配呢?
m1: [a x b], m2: [c x d] b和c要相等
而不是,a x b改成c x d
感谢连接

报错

The command “dot” is required to be in your path
解决方式;很奇葩,先用apt-get install graphviz,再用pip install graphviz

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值