项目结构
假设我们的项目结构如下:
my_project/
|-- dataset/
| |-- __init__.py
| |-- imbalance_cifar.py
| |-- balance_cifar.py
|-- main.py
代码示例
1. dataset/imbalance_cifar.py
# dataset/imbalance_cifar.py
"""Stub implementation of the ``IMBALANCECIFAR10`` dataset class."""


class IMBALANCECIFAR10:
    """Placeholder dataset that can be selected by name from a config.

    It only stores its constructor arguments and returns hard-coded values;
    no data is loaded.

    Args:
        mode: Split name supplied by the caller (e.g. ``"train"`` or ``"valid"``).
        cfg: Configuration mapping; stored as-is on the instance.
    """

    def __init__(self, mode, cfg):
        self.mode = mode
        self.cfg = cfg
        print(f"Initialized IMBALANCECIFAR10 with mode: {mode} and cfg: {cfg}")

    def get_annotations(self):
        """Return the placeholder annotation list (a new list on every call)."""
        return ["annotation1", "annotation2"]

    def get_num_classes(self):
        """Return the number of classes, hard-coded to 10."""
        return 10
2. dataset/balance_cifar.py
# dataset/balance_cifar.py
"""Stub implementation of the ``BALANCECIFAR10`` dataset class."""


class BALANCECIFAR10:
    """Placeholder dataset that can be selected by name from a config.

    It only stores its constructor arguments and returns hard-coded values;
    no data is loaded.

    Args:
        mode: Split name supplied by the caller (e.g. ``"train"`` or ``"valid"``).
        cfg: Configuration mapping; stored as-is on the instance.
    """

    def __init__(self, mode, cfg):
        self.mode = mode
        self.cfg = cfg
        print(f"Initialized BALANCECIFAR10 with mode: {mode} and cfg: {cfg}")

    def get_annotations(self):
        """Return the placeholder annotation list (a new list on every call)."""
        return ["annotation3", "annotation4"]

    def get_num_classes(self):
        """Return the number of classes, hard-coded to 10."""
        return 10
3. dataset/__init__.py
# dataset/__init__.py
"""Dataset package: re-export the dataset classes from their submodules.

Relative imports name the submodule (``.imbalance_cifar``), not the file
(``.imbalance_cifar.py``). With the ``.py`` suffix Python would look for a
submodule called ``py`` inside ``imbalance_cifar`` and raise
ModuleNotFoundError, because ``imbalance_cifar`` is a module, not a package.
"""
from .imbalance_cifar import IMBALANCECIFAR10
from .balance_cifar import BALANCECIFAR10

# Names exported by ``from dataset import *``.
__all__ = ["IMBALANCECIFAR10", "BALANCECIFAR10"]
4. main.py
# main.py
"""Build train/valid datasets from a class name given in the config."""
from dataset import BALANCECIFAR10, IMBALANCECIFAR10

# Dataset classes that a config may select, keyed by class name. Looking the
# name up here (instead of eval()-ing the config string) means a config value
# can only pick one of these classes and can never execute arbitrary Python.
DATASET_REGISTRY = {
    cls.__name__: cls
    for cls in (IMBALANCECIFAR10, BALANCECIFAR10)
}

# Simulated config: the dataset class is given as a string.
cfg = {
    "DATASET": {
        "DATASET": "IMBALANCECIFAR10",
    },
}


def get_dataset_class(name):
    """Return the dataset class registered under *name*.

    Raises:
        ValueError: if *name* is not a key of DATASET_REGISTRY.
    """
    try:
        return DATASET_REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(DATASET_REGISTRY))
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of: {known}"
        ) from None


# Instantiate the configured class dynamically.
dataset_class = get_dataset_class(cfg["DATASET"]["DATASET"])
train_set = dataset_class("train", cfg)
valid_set = dataset_class("valid", cfg)

# Call the dataset API.
annotations = train_set.get_annotations()
num_classes = train_set.get_num_classes()
print("Annotations:", annotations)
print("Number of classes:", num_classes)
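注意:
在本例中,`dataset/__init__.py` 负责把各子模块中的类重新导出到 `dataset` 包的命名空间。如果没有 `__init__.py` 文件(或文件中没有这些导入语句),`from dataset import *` 就不会导入 `IMBALANCECIFAR10` 等类,之后按名称查找类时会失败(NameError)。需要说明的是,从 Python 3.3 起,没有 `__init__.py` 的目录会被当作命名空间包,因此像 `from dataset.imbalance_cifar import IMBALANCECIFAR10` 这样直接导入子模块仍然可以成功。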