【Database】A file-based data container with Python2/Python3-compatible encoding


0x00 Preface

Lightweight memory for training
I've been training models again lately (alchemy, as we call it). I used to grumble that a mere 20k samples hardly deserves to be called big data; my recent tasks seem to have heard that complaint, and keep hammering the server's memory with datasets like "17 million sentences" and "4,000 documents".
Although I previously wrote an article on a CIR (CorpusIterationReader) class for this kind of problem (eh? where did that article go, did it get eaten... emmmm, I'll repost it someday), that class only supports sequential access with a "file pointer + instance pointer" pivot, which handles "random access after shuffling" poorly; moreover, its "multiple samples per file" design invites conflicts under multiprocessing.
On a tip from my senior cyx, each sample can instead be stored as its own file. (One small caveat, I think: don't absent-mindedly run `ls -all` in that folder, or you'll be waiting a good while, haha.) So, building on his OneFileDB design, I refactored and implemented some utility classes and functions for this scheme, so that my PyTorch-based models can train normally.

Cross-version encoding compatibility
Both old Python 2 projects and new Python 3 projects need this code, so encoding was another big headache. Referring to `convert_to_unicode` in BERT, I studied how `ujson` and `json.JSONEncoder` normalize different encodings into unicode when saving as JSON, and folded the relevant logic in as well.
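As a minimal Python 3 sketch of that idea (the class name `BytesEncoder` is just for illustration here): subclassing `json.JSONEncoder` and overriding `default` lets `json.dumps` accept utf-8 `bytes` values by decoding them first.

```python
import json


class BytesEncoder(json.JSONEncoder):
    # decode utf-8 bytes into str so json.dumps can serialize them
    def default(self, obj):
        if isinstance(obj, bytes):
            return obj.decode('utf-8', 'ignore')
        return json.JSONEncoder.default(self, obj)


sample = {'word': '电'.encode('utf-8')}
# plain json.dumps(sample) raises TypeError: bytes is not JSON serializable
print(json.dumps(sample, cls=BytesEncoder, ensure_ascii=False))
# → {"word": "电"}
```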

0x01 Usage

A json file is usually a list of samples, each stored as a dict.
What the model needs is that, even when samples are plentiful, we can still fetch any sample quickly by index (or by key) to feed the model.

# JSON EXAMPLE
j = [{'info': {'sid': 'test1'},
      'words': [{'id': 'w0', 'word': u'电'.encode('utf-8')},
                {'id': 'w1', 'word': u'话'.encode('utf-8')},
                {'id': 'w2', 'word': '[unused10]'},
                {'id': 'w3', 'word': '0'},
                {'id': 'w4', 'word': '2'},
                {'id': 'w5', 'word': '1'},
                {'id': 'w6', 'word': '-'},
                {'id': 'w7', 'word': '3'},
                {'id': 'w27', 'word': '0'}],
      'entities': [], 'relations': []},
     {'info': {'sid': 'test2'},
      'words': [{'id': 'w0', 'word': u'地'.encode('utf-8')},
                {'id': 'w1', 'word': u'址'.encode('utf-8')}],
      'entities': [], 'relations': []}]

We have three ways to store the data:

  • OneFileDB: single-file storage, no different from reading a whole file in as usual
  • FolderDB: folder storage, where every file in the folder is one sample
  • CFolderDB: encrypted folder storage, a subclass of FolderDB whose samples are compressed and encrypted

# In particular, after reading a json file in as a OneFileDB,
# the member function `transfer_to_folderdb(path=<folder_path>)` produces a FolderDB
db_of = Database('./test.json')  # OneFileDB
db_f1 = db_of.transfer_to_folderdb('./test')  # FolderDB
db_f2 = Database('./test')  # FolderDB
db_cf = Database('./test.cfolder')  # CFolderDB

Using these DBs is the usual fare: writing, subscript reads, iteration, and getting the number of samples.
The Folder-style DBs additionally provide an append function for adding new samples.

db.write(samples=j)
db_f.append(samples=j)
for idx, item in enumerate(db):
    print(idx, item)
print(len(db))
print(db[1])
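Since the whole point is "random access after shuffling", the intended training-time access pattern looks roughly like this. A sketch only: a plain list stands in for a Database instance here, since both expose the same `len()` / `db[i]` protocol.

```python
import random

# a plain-list stand-in for a Database; same len() / db[i] protocol
db = ['s{}'.format(i) for i in range(8)]

order = list(range(len(db)))    # one shuffled index order per epoch
random.seed(0)                  # seeded here only for reproducibility
random.shuffle(order)
epoch = [db[i] for i in order]  # subscript fetches, no sequential pivot
print(epoch)
```

Every sample is visited exactly once per epoch, but in a different order each time, without ever materializing all samples in memory at once when `db` is a real FolderDB.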

0x02 Source Code

The Database main class

# coding: utf-8
# ==========================================================================
#   Copyright (C) 2016-2020 All rights reserved.
#
#   filename : training_dbs_new.py
#   origin   : cyx / caoyixuan
#   author   : chendian / okcd00@qq.com
#   date     : 2020-07-21
#   desc     : An alternative to the original database class (multi-json).
#              can be called as a dict or a list.
# ==========================================================================

class Database(object):
    """
    A unified wrapper for OneFileDB, FolderDB
    """

    def __init__(self, path, samples=None, n_samples=None, read_only=True, load_now=False):
        if samples is not None:
            db = OneFileDB(path, samples, n_samples=n_samples)
        else:
            mode = self.determine_mode(path)
            logging.info('database mode: {}'.format(mode))
            if mode == 'all_samples_one_file':
                db = OneFileDB(path, samples=None, n_samples=n_samples,
                               read_only=read_only, load_now=load_now)
            elif mode == 'one_sample_per_file':
                db = FolderDB(path, n_samples=n_samples,
                              read_only=read_only, load_now=load_now)
            elif mode == 'cfolder':
                db = CFolderDB(path, n_samples=n_samples,
                               read_only=read_only, load_now=load_now)
            else:
                raise ValueError("Unknown mode: {}".format(mode))

        self.db = db
        self.sids = db.sids

    @staticmethod
    def determine_mode(label_path):
        if label_path.endswith('.json'):
            mode = 'all_samples_one_file'
        elif label_path.endswith('.cfolder') or label_path.endswith('.cfolder/'):
            mode = 'cfolder'
        else:  # directory path without postfix
            mode = 'one_sample_per_file'
        return mode

    def write(self, samples):
        return self.db.write(samples)

    def get_by_sid(self, sid):
        return self.db.get_by_sid(sid)

    def __getitem__(self, item):
        if isinstance(item, slice):
            return self.sl(item)
        return self.db[item]

    def sl(self, key):
        start, stop, step = key.indices(len(self))
        for i in range(start, stop, step):
            yield self.db[i]

    def __len__(self):
        return self.db.__len__()

    def __iter__(self):
        return self.db.__iter__()

    def next(self):
        return self.db.next()

    @property
    def all_samples(self):
        return self.db.all_samples


if __name__ == "__main__":
    sd = Database('./test')

The DB base class and its three derivatives

class TrainDBBase(object):
    """
    An immutable dataset once write.
    """

    def write(self, samples):
        """save samples"""
        raise NotImplementedError()

    def get_by_sid(self, sid):
        """get sample by sid"""
        raise NotImplementedError()

    def __getitem__(self, item):
        """ get sample by index in dataset"""
        raise NotImplementedError()

    def __len__(self):
        """return the number of samples in this dataset"""
        raise NotImplementedError()

    def __iter__(self):
        self.n = 0
        return self

    def next(self):
        if self.n == self.__len__():
            raise StopIteration
        n = self.n
        self.n += 1
        return self[n]

    def __next__(self):
        return self.next()

    @property
    def all_samples(self):
        """return all samples in this dataset"""
        return [self[i] for i in range(len(self))]


class FolderDB(TrainDBBase):
    """
    一个sample写到一个文件里,一个DB就是一个文件夹,只能按照文件名进行索引
    NEW: 也可以按下标遍历
    """

    def __init__(self, folder, n_samples=None, read_only=True, load_now=False):
        self.folder = folder
        self.compress = False
        self.n_samples = n_samples
        self.sids = None
        if load_now:
            self.load_register()

    def write(self, samples):
        write_one_sample_per_file(samples, self.folder)

    def append(self, samples):
        append_write_one_sample_per_file(samples, self.folder)

    def get_by_sid(self, sid):
        file_path = path_join(self.folder, sid)
        with open(file_path) as f:
            sample = json.load(f)
        return sample

    def __getitem__(self, index):
        self.load_register()
        sid = self.sids[index]
        return self.get_by_sid(sid)

    def __len__(self):
        self.load_register()
        return len(self.sids)

    def load_register(self):
        if self.sids is not None:
            return
        sids = load_register(self.folder)
        if self.n_samples:
            sids = sids[: self.n_samples]
        self.sids = sids
        assert len(self.sids) == len(set(self.sids)), 'exist duplicated sids'


class CFolderDB(FolderDB):
    """A json-encrypted FolderDB"""
    def write(self, samples):
        write_one_sample_per_file(samples, self.folder, compress=True)

    def get_by_sid(self, sid):
        file_path = path_join(self.folder, sid)
        sample = json_load(path=file_path, mode='r', decrypt=True)
        # sample = json.loads(zlib.decompress(open(file_path, 'rb').read()).decode('utf-8'))
        return sample


class OneFileDB(TrainDBBase):
    """ Single file as a DB"""
    def __init__(self, file_path, samples=None, n_samples=None, read_only=True, load_now=False):
        self.file_path = file_path
        self.sids = None
        self.samples = None
        self.compress = False
        self.sid_to_sample = None
        self.n_samples = n_samples
        if samples is not None:
            self.set_samples(samples)
        else:
            if load_now:
                self.load()

    def write(self, samples):
        json_dump(
            obj_=samples, path=self.file_path,
            mode='w', encrypt=self.compress)

    def get_by_sid(self, sid):
        self.load()
        return self.sid_to_sample[sid]

    def load(self):
        if self.samples is not None:
            return
        samples = json_load(
            path=self.file_path, mode='r',
            decrypt=self.compress)
        self.set_samples(samples)

    def set_samples(self, samples):
        # make a minor database for testing.
        if self.n_samples:
            samples = samples[: self.n_samples]
        self.samples = samples
        self.sids = [s['info']['sid'] for s in self.samples]
        self.sid_to_sample = {s['info']['sid']: s for s in self.samples}

    def transfer_to_folderdb(self, path):
        write_one_sample_per_file(
            answers=self.samples,
            folder=path,
            compress=self.compress)
        return Database(path=path)

    def __getitem__(self, item):
        self.load()
        return self.samples[item]

    def __len__(self):
        self.load()
        return len(self.samples)

Magic Tools

The most troublesome part of this kind of task is compatibility between Python 2 and Python 3, and the most troublesome part of that is encoding: Python 2's unicode is Python 3's str, and Python 2's str is Python 3's bytes. Hence the helper functions below.
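That two-way mapping can be checked in a couple of lines (shown from the Python 3 side):

```python
# Python 3 view of the same text in both representations
u = '电'                   # py3 str   corresponds to py2 unicode
b = u.encode('utf-8')      # py3 bytes corresponds to py2 str
assert b == b'\xe7\x94\xb5'              # the raw utf-8 bytes
assert b.decode('utf-8', 'ignore') == u  # and back again
```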

Imports and dependencies

from __future__ import unicode_literals
from six import PY2, PY3
import logging
import os
import zlib
import numpy as np
from io import open
JSON_MODULE = None

JSON encoding helpers

try:
    # if you have ujson, it will be faster
    # but the calling method is different.
    import ujson as json
    JSON_MODULE = 'ujson'
except ImportError:
    import json
    JSON_MODULE = 'json'


    class JsonBytesEncoder(json.JSONEncoder):
        # json.dumps
        def default(self, obj):
            # if isinstance(obj, np.ndarray):
            #     return obj.tolist()  # for further support.
            if isinstance(obj, bytes):
                return convert_to_unicode(obj)
                # return str(obj, encoding='utf-8')
            return json.JSONEncoder.default(self, obj)


def json_dumps(obj_, encrypt=False):
    if JSON_MODULE == 'json':
        _json_str = json.dumps(
            obj_, cls=JsonBytesEncoder)
    elif JSON_MODULE == 'ujson':
        if int(json.__version__[0]) < 2:
            # standard ujson-1.35 for python2.7
            _json_str = json.dumps(obj_)
        else:  # standard ujson-3.0.0 for python3.6
            _json_str = json.dumps(
                obj_, reject_bytes=False)
    else:
        _json_str = json.dumps(obj_)
    if encrypt:
        return zlib_encrypt(_json_str)
    return _json_str


def json_dump(obj_, path=None, mode='w', stream=None, encrypt=False):
    # the same as json.dump(zlib_encrypt(obj_), open(path, 'w'))
    # use 'w', not 'wb' in python3 for
    # TypeError: a bytes-like object is required, not 'str'
    if encrypt:  # the zlib.compress transfers data into bytes
        mode = 'wb'
    if stream is not None:
        # stream contains path and mode
        stream.write(json_dumps(obj_, encrypt))
    else:
        with open(path, mode) as f:
            f.write(json_dumps(obj_, encrypt))


def json_loads(str_, decrypt=False):
    if decrypt:
        str_ = zlib_decrypt(str_)
    # all kinds of json have the same loads()
    data = json.loads(str_)
    return data


def json_load(path, mode='r', decrypt=False):
    # the same as json.load(open(path, mode))
    if decrypt:  # the zlib.compress transfers data into bytes
        mode = 'rb'
    with open(path, mode) as f:
        obj_ = json_loads(f.read(), decrypt)
    return obj_


def zlib_encrypt(data):
    # return an encrypted string
    if isinstance(data, (list, dict, tuple)):
        j_str = json_dumps(data)  # data-structure to json-str
    else:  # to unicode (py2-unicode or py3-str)
        j_str = convert_to_unicode(data)
    # zlib only allow bytes-like inputs
    return zlib.compress(convert_to_bytes(j_str))


def zlib_decrypt(str_):
    # return a json_str in unicode
    b_str = zlib.decompress(str_)
    return convert_to_unicode(b_str)


def path_join(*args):
    # join unicode-normalized components with the platform separator
    return os.path.join(*[convert_to_unicode(each) for each in args])


def write_data(stream, text, encoding='unicode'):
    # once write **text** into a file, need to know
    # the basestring for py2 and py3 are different
    if encoding in ['unicode', 'u']:
        stream.write(convert_to_unicode(text))
    elif encoding in ['bytes', 'utf-8', 'b']:
        stream.write(convert_to_bytes(text))
    else:  # others
        stream.write(text)


def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if isinstance(text, (int, float)):
        text = '{}'.format(text)
    if PY3:
        if isinstance(text, str):  # py3-str is unicode
            return text
        elif isinstance(text, bytes):  # py3-bytes is py2-str
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif PY2:
        if isinstance(text, str):  # py2-str is py3-bytes
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):  # py2-unicode is py3-str
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def convert_to_bytes(text):
    if PY2 and isinstance(text, str):
        return text
    elif PY3 and isinstance(text, bytes):
        return text
    u_text = convert_to_unicode(text)
    return u_text.encode('utf-8')


def recursive_encoding_unification(cur_node):
    from collections import OrderedDict
    reu = recursive_encoding_unification

    if isinstance(cur_node, (list, tuple)):
        return type(cur_node)(
            [reu(item) for item in cur_node])
    elif isinstance(cur_node, (dict, OrderedDict)):
        return type(cur_node)(
            [(reu(k), reu(v)) for (k, v) in cur_node.items()])
    elif isinstance(cur_node, (int, float)):
        return cur_node
    elif cur_node is None:
        return None
    else:  # str, bytes, unicode
        # only convert leaf-nodes
        return convert_to_unicode(cur_node)


def json_unicode(json_dict):
    return recursive_encoding_unification(json_dict)

File-based DB helpers

def write_one_sample_per_file(answers, folder, compress=False):
    register = ['{}'.format(s['info']['sid']) for s in answers]
    if not os.path.exists(folder):
        os.mkdir(folder)
    with open(path_join(folder, 'register'), 'w') as fw:
        # text-mode writes behave differently between py2 and py3
        write_data(fw, '\n'.join(register))
    for s in answers:
        file_path = path_join(folder, s['info']['sid'])
        json_dump(s, path=file_path, encrypt=compress)


def append_write_one_sample_per_file(answers, folder, compress=False):
    assert os.path.isdir(folder), 'folder should exist if you want to append to existing dataset'
    sids = load_register(folder)
    conflict_sids = set(sids).intersection([s['info']['sid'] for s in answers])
    assert not conflict_sids, 'some sids already exist: {}'.format(list(conflict_sids)[:10])
    new_register = ['{}'.format(s['info']['sid']) for s in answers]
    # appending bytes would be faster, but here we append in text mode
    # ('a' without 'b') to keep the register stored as plain text
    with open(path_join(folder, 'register'), 'a') as fw:
        # text-mode writes behave differently between py2 and py3
        write_data(fw, '\n')
        write_data(fw, '\n'.join(new_register))

    f = 0
    for s in answers:
        try:
            sid_str = convert_to_unicode(s['info']['sid'])
            json_dump(obj_=s, path=path_join(folder, sid_str), encrypt=False)
        except OverflowError:
            logging.warning('{} save error'.format(s['info']['sid']))
            f += 1
            if f > 30:
                break


def load_register(folder, n_samples=None):
    sids = []
    # loading bytes is faster (append with 'ab+', loading with 'rb')
    with open(path_join(folder, 'register'), 'r') as fr:
        if n_samples is None:
            # faster list-construction
            sids = [line.strip().split(',')[-1] for line in fr]
        else:  # a custom n_samples is usually small,
            for line in fr:  # so list-appending is faster here.
                sid = line.strip().split(',')[-1]
                sids.append(sid)
                if len(sids) >= n_samples:
                    break
    return sids


def random_ints(n):
    """return n random ints that are distinct"""
    assert n < 10 ** 9, 'Too many distinct numbers asked.'
    raw_randoms = np.random.randint(0, np.iinfo(np.int64).max, 2 * n)
    uniques = np.unique(raw_randoms)
    while len(uniques) < n:
        r = np.random.randint(0, np.iinfo(np.int64).max, 2 * n)
        # np.stack requires equal shapes; concatenate, then de-duplicate
        uniques = np.unique(np.concatenate([uniques, r]))
    return uniques[:n]