【无标题】parseq

文章讲述了如何在不同CUDA版本下安装和配置PyTorch库,包括CUDA11.3和10.2,以及如何使用LMDB格式进行数据集转换和处理。还涉及了使用`read.py`和`test.py`等脚本的问题解决和数据预处理过程。
摘要由CSDN通过智能技术生成

一堆乱七八糟

conda create -n parseq python=3.9 -y
conda activate parseq
# CUDA 11.3
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge

# CUDA 10.2
pip install torch==1.10.1+cu102 torchvision==0.11.2+cu102 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu102/torch_stable.html

# Use specific platform build. Other PyTorch 1.13 options: cu116, cu117, rocm5.2
platform=cu102
# Generate requirements files for specified PyTorch platform
make torch-${platform}

Install the project and core + train + test dependencies. Subsets: [train,test,bench,tune]

pip install -r requirements/core.cpu.txt -e .[train,test]

pip install -r requirements.txt 太难了就放弃吧,缺少module再注意pip就好,如果安装版本报错,可以参考一下 requirements.txt 指明的版本
一般都会有个 pip install -e .安装项目本身

./test.py outputs/<experiment>/<timestamp>/checkpoints/last.ckpt # or use the released weights: ./test.py pretrained=parseq

python test.py work_dirs/parseq-bb5792a6.pt

python read.py work_dirs/parseq-bb5792a6.pt refine_iters:int=2 decode_ar:bool=false --images demo_images/*

做不下去了真的TTTTTTTT哭死
train 和test都运行不了

存档

TypeError: __init__() missing 21 required positional arguments [When running test.py and read.py]
环境按照这个装的,只教了怎么用自己自制的数据集,类型是lmdb,tools里面有转化器

python read.py work_dirs/parseq-bb5792a6.pt refine_iters:int=2 decode_ar:bool=false --images demo_images/*

指定某张显卡,

CUDA_VISIBLE_DEVICES=1 python read.py pretrained=parseq refine_iters:int=2 decode_ar:bool=false --images demo_images/*
python read.py pretrained=parseq --images demo_images/*

python read.py work_dirs/parseq-bb5792a6.pt --images demo_images/* 报错

cp -r "DeepSolo-main/datasets/ocr_en_422k/" parseq-main/demo_images/ocr_en
cp -r "DeepSolo-main/datasets/ocr_zh_230920_381k/" parseq-main/demo_images/ocr_zh

python read.py pretrained=parseq --images demo_images/ocr_en_422k/* --output outputs/ocr_en

python read.py pretrained=parseq --images demo_images/ocr_zh_230920_381k/* --output outputs/ocr_zh
识别不到中文

lmdb格式数据集转换代码


# -*- coding: utf-8 -*-
import argparse
import glob
import io
import os
import pathlib
import threading

import cv2 as cv
import lmdb
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm

# plt.rcParams['font.sans-serif'] = ['SimHei']  # 正常显示中文
# plt.rcParams['axes.unicode_minus'] = False  # 正常显示负号

# Dataset layout: <root>/train_3755 and <root>/test hold one directory per
# character class; LMDB output goes to <root>/lmdb.
root_path = pathlib.Path('/root/autodl-tmp/hwdb')

output_path = os.path.join(root_path, pathlib.Path('lmdb'))
train_path = os.path.join(root_path, pathlib.Path('train_3755'))
val_path = os.path.join(root_path, pathlib.Path('test'))

# Character set: one character per line; also the directory names of the
# per-class image folders. Iterating the file directly replaces the old
# manual readline()/break loop and preserves order.
with open('../character-3755.txt', 'r', encoding='utf-8') as f:
    characters = [line.strip() for line in f]


def write_cache(env, cache):
    """Flush every (key, value) pair in *cache* into the LMDB env in one write transaction.

    Keys are str and are encoded to bytes. Image values are already bytes
    and stored as-is; label values are str and encoded before storage.
    """
    with env.begin(write=True) as txn:
        for key, value in cache.items():
            payload = value if isinstance(value, bytes) else value.encode()
            txn.put(key.encode(), payload)


def create_dataset(env, image_path, label, index):
    """Append all readable images in *image_path* to *env* under *label*.

    Keys continue from ``index + 1`` so successive calls produce a
    contiguous key sequence.

    Args:
        env: open LMDB environment to write into.
        image_path: list of image file paths (all share the same label).
        label: str label stored alongside each image.
        index: number of samples already in the database.

    Returns:
        The number of samples actually written. (Bug fix: the original
        returned ``len(image_path)`` even when files were missing, which
        inflated the caller's running total — and thus the ``num-samples``
        key — and left gaps in the key numbering across calls.)
    """
    cache = {}
    cnt = index + 1
    written = 0
    for image in image_path:
        if not os.path.exists(image):
            print('%s does not exist' % image)
            continue
        with open(image, 'rb') as fs:
            image_bin = fs.read()
        # The .mdb file stores two record kinds, each with its own key:
        # raw image bytes under 'image-XXXXXXXXX', the label under 'label-XXXXXXXXX'.
        cache['image-%09d' % cnt] = image_bin
        cache['label-%09d' % cnt] = label
        cnt += 1
        written += 1
    if len(cache) != 0:
        write_cache(env, cache)
    return written


def show_image(samples):
    """Render up to 20 (image, label) pairs on a 4x5 matplotlib grid."""
    plt.figure(figsize=(20, 10))
    for position, sample in enumerate(samples, start=1):
        plt.subplot(4, 5, position)
        plt.imshow(sample[0])
        # plt.title(sample[1])
        plt.xticks([])
        plt.yticks([])
        plt.axis("off")
    plt.show()


def lmdb_test(root):
    """Sanity-check an LMDB dataset at *root*: decode and print up to the first 5 samples.

    Opens the database read-only, reads the ``num-samples`` count, then
    decodes each image with PIL and prints (total, channel count, label).
    """
    env = lmdb.open(
        root,
        max_readers=1,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False)

    if not env:
        print('cannot open lmdb from %s' % root)
        return

    with env.begin(write=False) as txn:
        raw_count = txn.get('num-samples'.encode())
        # Robustness: a database missing the num-samples key previously
        # crashed with TypeError from int(None); fail with a clear message.
        if raw_count is None:
            print('num-samples key missing in %s' % root)
            return
        n_samples = int(raw_count)

    with env.begin(write=False) as txn:
        samples = []
        for index in range(1, n_samples + 1):
            img_key = 'image-%09d' % index
            img_buf = txn.get(img_key.encode())
            # BytesIO can be constructed from the buffer directly instead of
            # the write()/seek(0) dance.
            buf = io.BytesIO(img_buf)
            try:
                img = Image.open(buf)
            except IOError:
                print('Corrupted image for %d' % index)
                return
            label_key = 'label-%09d' % index
            label = str(txn.get(label_key.encode()).decode('utf-8'))
            print(n_samples, len(img.split()), label)
            samples.append([img, label])
            if index == 5:
                # show_image(samples)
                break


def lmdb_init(directory, out, left, right):
    """Build an LMDB dataset at *out* from character classes ``characters[left:right]``.

    Each class is a sub-directory of *directory* containing PNG images; the
    directory name itself is the label. The LMDB ``map_size`` is estimated
    from the first image of the first class, doubled for safety.
    """
    entries = characters[left:right]
    # Robustness: an empty class range or an empty first class directory
    # previously raised IndexError on entries[0] / image_path[0].
    if not entries:
        print('no character classes in range [%d:%d)' % (left, right))
        return
    image_path = glob.glob(os.path.join(directory, entries[0], '*.png'))
    if not image_path:
        print('no .png images found under %s' % os.path.join(directory, entries[0]))
        return

    pbar = tqdm(entries)
    n_samples = 0

    # Estimate required map size: bytes of one decoded image times images in
    # the first class, times class count, times 2 for headroom. NOTE(review):
    # assumes the first class is representative of all classes — confirm.
    character_count = len(entries)
    image_cnt = len(image_path)
    data_size_per_img = cv.imdecode(np.fromfile(image_path[0], dtype=np.uint8), cv.IMREAD_UNCHANGED).nbytes
    data_size = data_size_per_img * image_cnt
    total_byte = 2 * data_size * character_count

    # Create the LMDB file.
    if not os.path.exists(out):
        os.makedirs(out)
    env = lmdb.open(out, map_size=total_byte)
    for dir_name in pbar:
        image_path = glob.glob(os.path.join(directory, dir_name, '*.png'))
        label = dir_name
        n_samples += create_dataset(env, image_path, label, n_samples)
        pbar.set_description(
            f'character[{left + 1}:{right}]: {label} | nSamples: {n_samples} | total_byte: {total_byte}byte | progressing')

    write_cache(env, {'num-samples': str(n_samples)})
    env.close()


def begin(mode, left, right, valid=False):
    """Build (valid=False) or inspect (valid=True) one LMDB split.

    mode 'train' reads from train_path and names the output ``train_<right>``;
    mode 'test' reads from val_path and names it ``test_<right - left>``.
    Any other mode is a no-op.
    """
    if mode == 'train':
        source, suffix = train_path, str(right)
    elif mode == 'test':
        source, suffix = val_path, str(right - left)
    else:
        return
    path = os.path.join(output_path, pathlib.Path(mode + '_' + suffix))
    if valid:
        print(f"show:{valid},path:{path}")
        lmdb_test(path)
    else:
        lmdb_init(source, path, left=left, right=right)


class MyThread(threading.Thread):
    """Worker thread that builds/inspects one LMDB split via begin()."""

    def __init__(self, mode, left, right, valid):
        super().__init__()
        self.mode = mode    # 'train' or 'test'
        self.left = left    # inclusive class start index
        self.right = right  # exclusive class end index
        self.valid = valid  # True = inspect existing db, False = build it

    def run(self):
        begin(mode=self.mode, left=self.left, right=self.right, valid=self.valid)


if __name__ == '__main__':
    # Split layout (half-open class-index ranges over the 3755 classes):
    #   train_500:  [0, 500)       train_1000: [500, 1000)
    #   train_1500: [1000, 1500)   train_2000: [1500, 2000)
    #   train_2755: [2000, 2755)   train_3755: [2755, 3755)
    #   test_1000:  [2755, 3755)
    parser = argparse.ArgumentParser()

    parser.add_argument("--train", action="store_true", help="generate train lmdb")
    parser.add_argument("--test", action="store_true", help="generate test lmdb")
    parser.add_argument("--all", action="store_true", help="generate all lmdb")
    parser.add_argument("--show", action="store_true", help="show result")
    parser.add_argument("--start", type=int, default=0, help="class start from where,default 0")
    parser.add_argument("--end", type=int, default=3755, help="class end from where,default 3755")

    args = parser.parse_args()
    show = args.show

    if args.train:
        print(f"args: mode=train, [start:end)=[{args.start}:{args.end})")
        begin(mode='train', left=args.start, right=args.end, valid=show)
    if args.test:
        print(f"args: mode=test, [start:end)=[{args.start}:{args.end})")
        begin(mode='test', left=args.start, right=args.end, valid=show)
    if args.all:
        # Predefined shard boundaries: five train shards followed by one
        # test shard. starts[i] + steps[i] is the exclusive end of shard i.
        starts = [0, 500, 1000, 1500, 2000, 2755]
        steps = [500, 500, 500, 500, 755, 1000]
        specs = ['5*train', '1*test']
        threads = []
        shard = 0
        for spec in specs:
            count, mode = spec.strip().split("*")
            for _ in range(int(count)):
                left = starts[shard]
                right = left + steps[shard]
                if show:
                    # Inspection is serialized; building runs one thread per
                    # shard. (Removed an unused threading.Lock that was
                    # created but never acquired.)
                    begin(mode=mode, left=left, right=right, valid=show)
                else:
                    thread = MyThread(mode=mode, left=left, right=right, valid=show)
                    threads.append(thread)
                    thread.start()
                shard += 1

        for t in threads:
            t.join()

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值