MMCV学习——基础篇3(fileio)| 五千字:含示例代码教程

MMCV学习——基础篇3(fileio)

fileio作为MMCV的一个文件操作的核心模块,提供了一套统一的API根据不同的后端实现不同格式文件的序列化(dump)和反序列化(load)。PS: 从v1.3.16之后,MMCV才开始支持不同后端文件的序列化和反序列化,具体细节在#1330可以看到。

1. 基础用法

 MMCV对文件的序列化和反序列化提供了统一的接口,因此调用起来十分简便:

'''Load from disk or dump to disk
'''
import mmcv

# load data from different files
data_json = mmcv.load('test.json')
data_yaml = mmcv.load('test.yaml')
data_pkl = mmcv.load('out_json.pkl')

# load data from a file-like object
with open('test.json', 'r') as file:
    data_json = mmcv.load(file, 'json')

# dump data
mmcv.dump(data_json, 'out_json.pkl')

print('json:', data_json)
print('yaml:', data_yaml)
print('pkl:', data_pkl)
'''
Ouput:
json: {'a': [1, 2, 3], 'b': None, 'c': 'hello'}
yaml: {'a': [1, 2, 3], 'b': None, 'c': 'hello'}
pkl: {'a': [1, 2, 3], 'b': None, 'c': 'hello'}
'''

2. 模块组成

fileio模块中有两个核心组件:一个是负责不同格式文件读写的FileHandler不同后端文件获取的FileClient

  • FileHandler可以根据待读写的文件后缀名自动选择对应的 handler 进行具体解析。
  • FileClient在v1.3.16之后可以根据待读写文件的URI自动选择对应的client进行具体访问。

2.1 FileHandler

 只需要自定义Handler继承BaseFileHandler,并且实现load_from_fileobjdump_to_fileobjdump_to_str这三个方法即可。下面是自定义.log文件加载为list的读取示例:

'''Convert data joined by comma in .log file to list obj mutually.
'''
# here, we need to writer a file handler inherited from BaseFileHandler and register it

@mmcv.register_handler(['txt', 'log'])
class LogFileHandler(mmcv.BaseFileHandler):
	
	str_like = True
	
    def __init__(self) -> None:
        super().__init__()
    
    def load_from_fileobj(self, file, **kwargs):
        res_list = []
        for line in file.readlines():
            res_list.append(line.split(','))
        return res_list
    
    def dump_to_fileobj(self, obj, file, **kwargs):
        for line_data in obj:
            file.write(','.join(line_data))
    
    def dump_to_str(self, obj, **kwargs):
        return str(obj) + ':str'
 
 
# Running script
data_txt = mmcv.load('test.log')
print(data_txt)
mmcv.dump(data_txt, 'out_txt.log')
data_str = mmcv.dump(data_txt, file_format='log')
print(data_str)
'''
Output:
In terminal:
[['a', ' b', ' c\n'], ['e', ' f']]
[['a', ' b', ' c\n'], ['e', ' f']]:str

In test.log:
a, b, c
e, f

In out_txt.log:
a, b, c
e, f
'''

 接下来说明几个需要注意的点:

  • str_like类属性决定了三种方法中file的类型,默认为True,即flie为StringIO类型;如果是False,则file为BytesIO类型。
  • load_from_fileobj方法就是将二进制流转为Python object。
  • dump_to_fileobj方法就是将Python object转为字节流flush到指定后端。
  • dump_to_str方法就是将Python object转为str,在mmcv.dump没有指定file参数时被调用。
  • 前面基础用法中可以直接根据path/uri路径直接load或者dump文件,但是自定义的Handler却不需要实现 load/dump_from_path这样的方法,是因为mmcv.load/dump方法中通过FileClient去解析uri从而选择对应的backend访问指定的文件,并将文件转为file object这样的字节流。

2.2 FileClient

FileClient组件主要是提供统一的不同后台文件访问API,其实做的事情就是将不同后端的文件转换为字节流。同样,自定义后台只需要继承BaseStorageBackend类,实现getget_text方法即可。

class BaseStorageBackend(metaclass=ABCMeta):

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass
  • get方法就是将filepath的文件转为字节流。
  • get_text就是将filepath的文件转为字符串。

 下面是HardDiskBackend类的示例:

class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    _allow_symlink = True

    def get(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with open(filepath, 'rb') as f:
            value_buf = f.read()
        return value_buf

    def get_text(self,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        with open(filepath, 'r', encoding=encoding) as f:
            value_buf = f.read()
        return value_buf

 最后,我们这里尝试编写一段自动选择后端读取和保存Pytorch CheckPoint的程序:

# implement an interface which automatically select the corresponding backend 
import torch
from mmcv.fileio.file_client import FileClient

def load_checkpoint(path):
    file_client = FileClient.infer_client(uri=path)
    with io.BytesIO(file_client.get(path)) as buffer:
        ckpt = torch.load(buffer)
    return ckpt

def save_ckpt(ckpt, path):
    file_client = FileClient.infer_client(uri=path)
    with io.BytesIO() as buffer:
        torch.save(ckpt, buffer)
        file_client.put(buffer.getvalue(), path)

# Running script
ckpt_path = 'https://download.pytorch.org/models/resnet18-f37072fd.pth'
ckpt = load_checkpoint(ckpt_path)  # HttpbHTTPBackend
# print(type(ckpt))
print(ckpt)

save_ckpt(ckpt, 'resnet18-f37072fd.pth')
'''
Output
OrderedDict([('conv1.weight', Parameter containing:
tensor([[[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  ...,  5.6615e-02,
            1.7083e-02, -1.2694e-02], ...
'''

3. 参考资料

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值