PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.
Publishing models
PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple hubconf.py file. hubconf.py can have multiple entrypoints. Each entrypoint is defined as a Python function (for example, a pre-trained model you want to publish).
def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...
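As a rough illustration of a hubconf.py with several entrypoints, the sketch below builds a stand-in module in plain Python. The names small_model, large_model and _build are hypothetical, the entrypoints return dictionaries rather than real models, and list_entrypoints is a simplified stand-in for the listing step (public callables only), not the actual torch.hub code.

```python
import types

# Stand-in for a hubconf.py module; real entrypoints would return nn.Module objects.
hubconf = types.ModuleType("hubconf")
hubconf.dependencies = ["torch"]


def _build(name, pretrained):
    # Hypothetical helper: the leading underscore keeps it out of the listing.
    return {"name": name, "pretrained": pretrained}


def small_model(pretrained=False, **kwargs):
    """Hypothetical entrypoint returning a small model."""
    return _build("small_model", pretrained)


def large_model(pretrained=False, **kwargs):
    """Hypothetical entrypoint returning a large model."""
    return _build("large_model", pretrained)


# Attach the functions to the stand-in module, as if they were defined in hubconf.py.
for fn in (_build, small_model, large_model):
    setattr(hubconf, fn.__name__, fn)


def list_entrypoints(module):
    # Keep only callables whose names do not start with an underscore.
    return [
        name
        for name in dir(module)
        if callable(getattr(module, name)) and not name.startswith("_")
    ]


entrypoints = list_entrypoints(hubconf)
print(entrypoints)
```

Running it prints ['large_model', 'small_model']: the underscore-prefixed helper and the non-callable dependencies list are left out.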
How to implement an entrypoint?
Here is a code snippet that specifies an entrypoint for the resnet18 model, expanded from the implementation in pytorch/vision/hubconf.py. In most cases, importing the right function in hubconf.py is sufficient; the expanded version is used here only to show how it works. You can see the full script in the pytorch/vision repo.
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
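To make the roles of the dependencies list and the docstring in the snippet above more concrete, here is a minimal sketch using only the standard library. It assumes a loader would check each dependency for importability and show the docstring as help text. The names missing_dependencies, entrypoint_help and not_an_installed_package_xyz are hypothetical, and resnet18 here is a stub rather than the torchvision model.

```python
import importlib.util
import inspect

# Hypothetical dependency list; the second name is assumed not to be installed.
dependencies = ["json", "not_an_installed_package_xyz"]


def resnet18(pretrained=False, **kwargs):
    """
    Resnet18 model (stub for illustration)
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    return {"arch": "resnet18", "pretrained": pretrained, **kwargs}


def missing_dependencies(names):
    # find_spec returns None when a top-level package cannot be located.
    return [name for name in names if importlib.util.find_spec(name) is None]


def entrypoint_help(fn):
    # inspect.getdoc strips the common indentation from the docstring.
    return inspect.getdoc(fn) or ""


missing = missing_dependencies(dependencies)
help_text = entrypoint_help(resnet18)
print("missing:", missing)
print(help_text)
```

A check like this only needs the packages required to load the model, which is why the list can be shorter than a training environment's requirements.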
- dependencies variable is a list of package names required to load the model. Note this might be slightly different from the dependencies required for training a model.
- args and kwargs are passed along to the real callable function.
- The docstring of the function works as a help message. It explains what the model does and what the allowed positional/keyword arguments are. It’s highly recommended to add a few examples here.
- An entrypoint function can either return a model (nn.Module) or auxiliary tools that make the user workflow smoother, e.g. tokenizers.
- Callables prefixed with an underscore are considered helper functions and won’t show up in torch.hub.list().
- Pretrained weights can either be stored locally in the GitHub repo or be loadable by torch.hub.load_state_dict_from_url(). If less than 2GB, it’s recommended to attach it to a