一堆乱七八糟
conda create -n parseq python=3.9 -y
conda activate parseq
# CUDA 11.3
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
# CUDA 10.2
pip install torch==1.10.1+cu102 torchvision==0.11.2+cu102 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu102/torch_stable.html
# Use specific platform build. Other PyTorch 1.13 options: cu116, cu117, rocm5.2
platform=cu102
# Generate requirements files for specified PyTorch platform
make torch-${platform}
Install the project and core + train + test dependencies. Subsets: [train,test,bench,tune]
pip install -r requirements/core.cpu.txt -e .[train,test]
pip install -r requirements.txt 太难了就放弃吧,缺少module再注意pip就好,如果安装版本报错,可以参考一下 requirements.txt 指明的版本
一般都会有个 pip install -e .
安装项目本身
./test.py outputs/<experiment>/<timestamp>/checkpoints/last.ckpt # or use the released weights: ./test.py pretrained=parseq
python test.py work_dirs/parseq-bb5792a6.pt
python read.py work_dirs/parseq-bb5792a6.pt refine_iters:int=2 decode_ar:bool=false --images demo_images/*
做不下去了真的TTTTTTTT哭死
train 和test都运行不了
存档
TypeError: __init__() missing 21 required positional arguments: [When running test.py and read.py]
环境按照这个装的,只教了怎么用自己自制的数据集,类型是lmdb,tools里面有转化器
python read.py work_dirs/parseq-bb5792a6.pt refine_iters:int=2 decode_ar:bool=false --images demo_images/*
指定某张显卡,
CUDA_VISIBLE_DEVICES=1 python read.py pretrained=parseq refine_iters:int=2 decode_ar:bool=false --images demo_images/*
python read.py pretrained=parseq --images demo_images/*
python read.py work_dirs/parseq-bb5792a6.pt --images demo_images/* 报错
cp -r "DeepSolo-main/datasets/ocr_en_422k/" parseq-main/demo_images/ocr_en
cp -r "DeepSolo-main/datasets/ocr_zh_230920_381k/" parseq-main/demo_images/ocr_zh
python read.py pretrained=parseq --images demo_images/ocr_en_422k/* --output outputs/ocr_en
python read.py pretrained=parseq --images demo_images/ocr_zh_230920_381k/* --output outputs/ocr_zh
识别不到中文
lmdb格式数据集转换代码
# -*- coding: utf-8 -*-
import argparse
import glob
import io
import os
import pathlib
import threading
import cv2 as cv
import lmdb
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm
# plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK glyphs in plots
# plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly

# On-disk layout of the HWDB handwriting dataset.
root_path = pathlib.Path('/root/autodl-tmp/hwdb')
output_path = os.path.join(root_path, pathlib.Path('lmdb'))
train_path = os.path.join(root_path, pathlib.Path('train_3755'))
val_path = os.path.join(root_path, pathlib.Path('test'))

# Character set: one class label per line; list index == class index.
with open('../character-3755.txt', 'r', encoding='utf-8') as f:
    characters = [line.strip() for line in f]
def write_cache(env, cache):
    """Flush a dict of key/value pairs into the LMDB env in one write txn.

    Values that are already ``bytes`` (image blobs) are stored verbatim;
    ``str`` values (labels, counters) are utf-8 encoded first. Keys are
    always encoded.
    """
    with env.begin(write=True) as txn:
        for key, value in cache.items():
            payload = value if isinstance(value, bytes) else value.encode()
            txn.put(key.encode(), payload)
def create_dataset(env, image_path, label, index):
    """Append one class worth of images to the LMDB.

    env: open lmdb environment.
    image_path: list of image file paths belonging to this class.
    label: the class label stored for every image in the list.
    index: number of samples already written (keys continue from index+1).

    Returns the number of samples ACTUALLY written. (Bug fix: the original
    returned len(image_path), counting files skipped by the existence
    check, which produced gaps in the 1-based key numbering across classes
    and an inflated final 'num-samples'.)
    """
    cache = {}
    cnt = index + 1  # LMDB keys are 1-based and run contiguously across classes
    written = 0
    for image in image_path:
        if not os.path.exists(image):
            print('%s does not exist' % image)
            continue
        with open(image, 'rb') as fs:
            image_bin = fs.read()
        # The .mdb database holds two record kinds, each with its own key:
        # raw image bytes under image-*, and the label string under label-*.
        cache['image-%09d' % cnt] = image_bin
        cache['label-%09d' % cnt] = label
        cnt += 1
        written += 1
    if len(cache) != 0:
        write_cache(env, cache)
    return written
def show_image(samples):
    """Preview up to 20 (image, label) pairs on a 4x5 matplotlib grid."""
    plt.figure(figsize=(20, 10))
    for slot, (img, _label) in enumerate(samples, start=1):
        plt.subplot(4, 5, slot)
        plt.imshow(img)
        # plt.title(_label)
        plt.xticks([])
        plt.yticks([])
        plt.axis("off")
    plt.show()
def lmdb_test(root):
    """Sanity-check an LMDB dataset: decode and print its first 5 samples.

    Opens the database read-only, reads the stored 'num-samples' counter,
    then walks samples 1..5 decoding each image with PIL and printing
    (total, channel count, label). Bails out on the first corrupt image.
    """
    env = lmdb.open(
        root,
        max_readers=1,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False)
    if not env:
        print('cannot open lmdb from %s' % root)
        return
    with env.begin(write=False) as txn:
        n_samples = int(txn.get('num-samples'.encode()))
    with env.begin(write=False) as txn:
        samples = []
        for index in range(1, n_samples + 1):
            img_buf = txn.get(('image-%09d' % index).encode())
            buf = io.BytesIO(img_buf)
            try:
                img = Image.open(buf)
            except IOError:
                print('Corrupted image for %d' % index)
                return
            label = str(txn.get(('label-%09d' % index).encode()).decode('utf-8'))
            print(n_samples, len(img.split()), label)
            samples.append([img, label])
            if index == 5:
                # show_image(samples)
                # samples = []
                break
def lmdb_init(directory, out, left, right):
    """Build an LMDB dataset from the character classes characters[left:right].

    directory: root containing one sub-directory of .png images per class.
    out: destination LMDB directory (created if missing).
    left/right: half-open slice into the global `characters` list.

    Raises ValueError when the slice is empty or the first class directory
    contains no .png files (the original crashed with a bare IndexError on
    `entries[0]` / `image_path[0]` in those cases).
    """
    entries = characters[left:right]
    if not entries:
        raise ValueError(f'empty character slice [{left}:{right})')
    first_images = glob.glob(os.path.join(directory, entries[0], '*.png'))
    if not first_images:
        raise ValueError('no .png files found for class %r under %s' % (entries[0], directory))
    # Estimate the map size from the first decoded image of the first class:
    # 2 * (bytes per decoded image) * (images in first class) * (class count).
    data_size_per_img = cv.imdecode(
        np.fromfile(first_images[0], dtype=np.uint8), cv.IMREAD_UNCHANGED).nbytes
    total_byte = 2 * data_size_per_img * len(first_images) * len(entries)
    if not os.path.exists(out):
        os.makedirs(out)
    env = lmdb.open(out, map_size=total_byte)
    n_samples = 0
    pbar = tqdm(entries)
    for dir_name in pbar:
        image_path = glob.glob(os.path.join(directory, dir_name, '*.png'))
        label = dir_name  # class label == directory name
        n_samples += create_dataset(env, image_path, label, n_samples)
        pbar.set_description(
            f'character[{left + 1}:{right}]: {label} | nSamples: {n_samples} | total_byte: {total_byte}byte | progressing')
    # Persist the final sample count so readers know how far to iterate.
    write_cache(env, {'num-samples': str(n_samples)})
    env.close()
def begin(mode, left, right, valid=False):
    """Run one build-or-inspect job for a class range.

    mode: 'train' or 'test' — selects the source directory and output name.
    left/right: half-open class range [left, right) into `characters`.
    valid: when True, open and print the existing LMDB instead of building.

    NOTE: kept for compatibility with existing datasets — train outputs are
    named 'train_<right>' while test outputs are named 'test_<right-left>'.
    """
    if mode == 'train':
        source, suffix = train_path, str(right)
    elif mode == 'test':
        source, suffix = val_path, str(right - left)
    else:
        return  # unknown mode: the original if/elif silently did nothing
    path = os.path.join(output_path, pathlib.Path(mode + '_' + suffix))
    if not valid:
        lmdb_init(source, path, left=left, right=right)
    else:
        print(f"show:{valid},path:{path}")
        lmdb_test(path)
class MyThread(threading.Thread):
    """Worker thread that runs a single begin() job (one class-range shard)."""

    def __init__(self, mode, left, right, valid):
        super().__init__()
        self.mode = mode
        self.left = left
        self.right = right
        self.valid = valid

    def run(self):
        begin(mode=self.mode, left=self.left, right=self.right, valid=self.valid)
if __name__ == '__main__':
    # Shard plan (class indices are half-open [start, end) into `characters`):
    #   train_500 : classes [0, 500)
    #   train_1000: classes [500, 1000)
    #   train_1500: classes [1000, 1500)
    #   train_2000: classes [1500, 2000)
    #   train_2755: classes [2000, 2755)
    #   train_3755: classes [2755, 3755)
    #   test_1000 : classes [2755, 3755)
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", action="store_true", help="generate train lmdb")
    parser.add_argument("--test", action="store_true", help="generate test lmdb")
    parser.add_argument("--all", action="store_true", help="generate all lmdb")
    parser.add_argument("--show", action="store_true", help="show result")
    parser.add_argument("--start", type=int, default=0, help="class start from where,default 0")
    parser.add_argument("--end", type=int, default=3755, help="class end from where,default 3755")
    args = parser.parse_args()
    start = args.start
    end = args.end
    show = args.show
    if args.train:
        print(f"args: mode=train, [start:end)=[{start}:{end})")
        begin(mode='train', left=start, right=end, valid=show)
    if args.test:
        print(f"args: mode=test, [start:end)=[{start}:{end})")
        begin(mode='test', left=start, right=end, valid=show)
    if args.all:
        # Fixed shard plan: five train shards followed by one test shard
        # (see the table above). `m` encodes "<count>*<mode>" specs.
        s = [0, 500, 1000, 1500, 2000, 2755]
        step = [500, 500, 500, 500, 755, 1000]
        m = ['5*train', '1*test']
        threads = []
        # (removed an unused `threading.Lock()` the original created and never used)
        mode_index = 0
        for spec in m:
            count, shard_mode = spec.strip().split("*")
            for _ in range(int(count)):
                if show:
                    # Inspection is cheap — run shards sequentially.
                    begin(mode=shard_mode, left=s[mode_index],
                          right=s[mode_index] + step[mode_index], valid=show)
                else:
                    # Build shards in parallel, one worker thread per shard.
                    thread = MyThread(mode=shard_mode, left=s[mode_index],
                                      right=s[mode_index] + step[mode_index], valid=show)
                    threads.append(thread)
                    thread.start()
                mode_index += 1
        for t in threads:
            t.join()