系列文章目录
文章目录
一、依赖反转 Inversion of Control
- 设计一个玩玩具场景的伪程序,①有各种各样的玩具; ②可以玩各种玩具;③玩具的生成由玩具工厂生成
1. 直接设计一个无依赖反转的类
class Car:
def play(self):
print("Vroom! I'm a toy car.")
class Doll:
def play(self):
print("Hello! I'm a lovely doll.")
class ToyFactory:
def produce_toy(self, toy_type): #问题:如果添加新的玩具,需要修改这个方法
if toy_type == "Car":
return Car()
elif toy_type == "Doll":
return Doll()
class Playroom:
def __init__(self):
self.toy_factory = ToyFactory() #问题:如果修改ToyFactory类,这里也需要修改,例如传参
def play_with_toys(self):
car = self.toy_factory.produce_toy("Car")
doll = self.toy_factory.produce_toy("Doll")
car.play()
doll.play()
p = Playroom()
p.play_with_toys()
存在的问题:
- 直接在一个类中,进行了实例化另外一个类
- 如果需要添加新的玩具,需要对
ToyFactory
进行直接修改 - 如果
ToyFactory
修改,需要传递参数,那么必须修改Platroom
里的代码
2. 依赖反转简单版本(解决问题3)
- 简单版本只解决了上面的3的问题,因为不是在
Playroom
里进行实例化,而是外面。所以,ToyFactory
不管任何改变,只要Playroom
里的两个方法没有变,这个程序就可以正常运行
class Car:
def play(self):
print("Vroom! I'm a toy car.")
class Doll:
def play(self):
print("Hello! I'm a lovely doll.")
class ToyFactory:
def produce_toy(self, toy_type): #问题:如果添加新的玩具,需要修改这个方法
if toy_type == "Car":
return Car()
elif toy_type == "Doll":
return Doll()
class Playroom:
def __init__(self, toy_factory: ToyFactory): #传递一个工厂实例
self.toy_factory = toy_factory
def play_with_toys(self):
car = self.toy_factory.produce_toy("Car")
doll = self.toy_factory.produce_toy("Doll")
car.play()
doll.play()
toy_factory = ToyFactory()
p = Playroom(toy_factory)
p.play_with_toys()
未解决问题:
2. 如果需要添加新的玩具,需要对ToyFactory
进行直接修改
3. 如果ToyFactory
里对produce_toy
这个方法进行了修改,那么Playroom
里的方法
3. 依赖反转进阶版本:注册机制(解决问题2)
在工厂模式中,注册机制的主要目标是避免每次增加新产品类型时都需要修改工厂类,从而实现开闭原则(对扩展开放,对修改封闭)。通过注册机制,新的产品类型可以在运行时动态注册到工厂中,而工厂类本身不需要修改。
class Car:
def play(self):
print("Vroom vroom! I'm a toy car.")
class Doll:
def play(self):
print("Hi there! I'm a lovely doll.")
#添加新的玩具
class Train:
def play(self):
print("Choo choo! I'm a toy train.")
#更改一个不违反开闭原则的方法
class ToyFactory:
def __init__(self):
self.toy_types = {}
def register_toy(self, toy_type:str, toy_cls:object):
self.toy_types[toy_type] = toy_cls
def produce_toy(self, toy_type):
return self.toy_types[toy_type]()
#palyroom
class Playroom:
def __init__(self, toy_factory):
self.toy_factory = toy_factory
def play_with_toys(self):
car = self.toy_factory.produce_toy("Car")
doll = self.toy_factory.produce_toy("Doll")
train = self.toy_factory.produce_toy("Train")
car.play()
doll.play()
train.play()
#测试
toy_factory = ToyFactory()
toy_factory.register_toy("Car", Car)
toy_factory.register_toy("Doll", Doll)
toy_factory.register_toy("Train", Train)
p = Playroom(toy_factory)
p.play_with_toys()
未解决问题:
3. 如果ToyFactory
里对produce_toy
这个方法进行了修改,那么Playroom
里的方法无法使用
4. 依赖反转进阶版本:适配器模式(解决问题3)
该版本是最最终的版本,添加一个适配器,无论ToyFactory
里的方法如何修改,Playroom
只需要调用自己的方法即可;并且该版本将产品进行了抽象
from abc import ABC, abstractmethod
# 定义玩具的抽象接口
class Toy(ABC):
@abstractmethod
def play(self):
pass
# 实现具体的玩具类
class Car(Toy):
def play(self):
print("Vroom! I'm a toy car.")
class Doll(Toy):
def play(self):
print("Hello! I'm a lovely doll.")
class Train(Toy): # 新增的玩具类
def play(self):
print("Choo Choo! I'm a toy train.")
# 定义工厂的抽象接口
class ToyFactory(ABC):
@abstractmethod
def produce_toy(self, toy_type):
pass
# 实现具体的工厂类,但方法名改为 create_pro,并使用注册机制
class ConcreteToyFactory:
def __init__(self):
self._creators = {}
def register_toy(self, toy_type, creator):
self._creators[toy_type] = creator
def create_pro(self, toy_type):
creator = self._creators.get(toy_type)
if not creator:
raise ValueError(f"Unknown toy type: {toy_type}")
return creator()
# 适配器类,使 ConcreteToyFactory 符合 ToyFactory 接口
class ToyFactoryAdapter(ToyFactory):
def __init__(self, concrete_toy_factory):
self.concrete_toy_factory = concrete_toy_factory
def produce_toy(self, toy_type):
return self.concrete_toy_factory.create_pro(toy_type)
# 高层模块依赖于抽象工厂,而不是具体实现
class Playroom:
def __init__(self, toy_factory: ToyFactory):
self.toy_factory = toy_factory
def play_with_toys(self):
car = self.toy_factory.produce_toy("Car")
doll = self.toy_factory.produce_toy("Doll")
train = self.toy_factory.produce_toy("Train") # 新增的玩具使用示例
car.play()
doll.play()
train.play()
# 通过适配器将具体的工厂实现传递给 Playroom
concrete_toy_factory = ConcreteToyFactory()
concrete_toy_factory.register_toy("Car", Car)
concrete_toy_factory.register_toy("Doll", Doll)
concrete_toy_factory.register_toy("Train", Train) # 注册新的玩具类
toy_factory_adapter = ToyFactoryAdapter(concrete_toy_factory) #添加适配器
playroom = Playroom(toy_factory_adapter)
playroom.play_with_toys()
总结:
- 非必要不要再一个类里面直接实例化另外一个类
- 如果一个类经常发生变化和改动,那么添加一个适配器,给调用它方法的类
- 简单的工厂模式,最好使用注册机制,这样每次有新的产品不需要修改工厂
二、类方法
1.使用类方法简化数据处理类
不使用类方法:由于是用来测试,没有传递真正的数据,所以每次都要先实例化传递一个空的数据
class DataProcessor:
def __init__(self, data):
self.data = data # take data in from memory
def process_data(self):
# complicated code to process data in memory
...
def from_csv(self, filepath):
self.data = pd.read_csv(filepath)
# Using the class without initial data in memory
processor = DataProcessor(data=None)
processor.from_csv("path_to_your_file.csv")
processor.process_data()
使用类方法: 由于DataProcessor
肯定是必须是处理数据的,是他类本身必须的功能,那么读取数据也是它必须有的方法,所以将所有的读取数据的方法都改成类方法,这样就可以直接读取数据
import pandas as pd
class DataProcessor:
def __init__(self, data):
self.data = data
def process_data(self):
print(f"Processing {self.data}")
@classmethod
def from_csv(cls, filepath):
data = filepath*2 #模拟读取了数据
return cls(data)
@classmethod
def from_parquet(cls, filepath):
data = pd.read_parquet(filepath)
return cls(data)
# Instantiating and using the class with classmethod
processor = DataProcessor.from_csv("path_to_your_file.csv")
processor.process_data()
2. 使用类方法读取配置文件(预设功能)
我们使用配置文件驱动,直接传递类的self也是这个原理,不传递使用默认参数,传递使用传递的参数,如果需要从配置文件读取,可以直接使用类方法直接读取;这样需要单独的将这个方法写在类的外面,因为读取配置文件的方法,应该是类本身应该具有的功能。
class MyXGBModel:
def __init__(self, learning_rate=0.1, n_estimators=100, max_depth=3):
...
def _create_model(self):
...
@classmethod
def from_config_file(cls, file_path):
with open(file_path, 'r') as file:
config = json.load(file)
return cls(**config)
# Usage
model_from_config = MyXGBModel.from_config_file('config.json')
- 例如:现在有一个场景,我们的模型已经创建好了,但是该模型有2个模式的配置文件,我们可以根据我们需要选择不同的配置文件,来运行这个类.
class MyXGBoostModel:
def __init__(self, learning_rate=0.1, n_estimators=100, max_depth=3):
...
def _create_model(self):
...
@classmethod
def from_config_file(cls, file_path):
...
@classmethod
def quick_start(cls):
default_params = {'learning_rate': 0.05, 'n_estimators': 100, 'max_depth': 4}
return cls(**default_params)
@classmethod
def high_accuracy(cls):
high_acc_params = {'learning_rate': 0.01, 'n_estimators': 500, 'max_depth': 10}
return cls(**high_acc_params)
# Usage
quick_start_model = MyXGBoostModel.quick_start()
high_accuracy_model = MyXGBoostModel.high_accuracy()
3.一个数据库类的例子
import os
from getpass import getpass
class DatabaseConnector:
def __init__(self, account, user, password):
self.account = account
self.user = user
self._password = getpass("enter the DB Password:") # Keep password private
self._connection = None
@classmethod
def development_config(cls):
return cls(
account=os.environ.get("DEV_DB_ACCOUNT"),
user=os.environ.get("DEV_DB_USER"),
password=os.environ.get("DEV_DB_PASSWORD") or getpass("Enter Dev DB Password: ")
)
@classmethod
def production_config(cls):
return cls(
account=os.environ.get("PROD_DB_ACCOUNT"),
user=os.environ.get("PROD_DB_USER"),
password=os.environ.get("PROD_DB_PASSWORD") or getpass("Enter Prod DB Password: ")
)
@property
def params(self):
# Exclude password from the public parameters
return {
"account": self.account,
"user": self.user
}
def initialize_connection(self):
if not self._connection:
# define create_database_connection based on your specific database (e.g. postgres, snowflake, redshift, etc)
self._connection = create_database_connection(self.account, self.user, self._password)
def query_data(self, query):
if not self._connection:
raise Exception("Database connection not initialized")
return execute_query(self._connection, query)
# Dev Database
dev_db_conn = DatabaseConnector.development_config()
dev_db_conn.initialize_connection()
dev_data = dev_db_conn.query_data("SELECT * FROM dev_table_name")
# Prod Database
prod_db_conn = DatabaseConnector.production_config()
prod_db_conn.initialize_connection()
prod_data = prod_db_connector.query_data("SELECT * FROM prod_table_name")