prep_blocks
Function
This function preprocesses a sequence of modules (blocks) to prepare them for the forward pass, in particular when the GPU memory cache needs to be cleared between blocks.
Source code:
from functools import partial
from typing import Any, Callable, List

import torch


def prep_blocks(
    blocks: List[Callable],
    clear_cache_between_blocks: bool,
    **kwargs: Any,
) -> List[Callable]:
    """Prepare the blocks for the forward pass."""
    # Bind the shared keyword arguments to every block up front.
    prepared_blocks = [
        partial(block, **kwargs)
        for block in blocks
    ]
    # Clear CUDA's GPU memory cache before each block runs.
    if clear_cache_between_blocks:
        def block_with_cache_clear(block, *args, **kwargs):
            torch.cuda.empty_cache()
            return block(*args, **kwargs)

        prepared_blocks = [partial(block_with_cache_clear, b) for b in prepared_blocks]
    return prepared_blocks
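To see what the prepared blocks actually do, here is a minimal, runnable sketch. It mirrors the structure of prep_blocks, but torch.cuda.empty_cache() is replaced by a hypothetical counter (fake_empty_cache) so the example runs without a GPU; block_a, block_b, and the scale argument are illustrative names, not from the source.

```python
from functools import partial
from typing import Any, Callable, List

# Hypothetical stand-in for torch.cuda.empty_cache(): it just
# counts how many times it was called, so we can observe the wrapper.
cache_clears = {"count": 0}

def fake_empty_cache() -> None:
    cache_clears["count"] += 1

def prep_blocks_sketch(
    blocks: List[Callable],
    clear_cache_between_blocks: bool,
    **kwargs: Any,
) -> List[Callable]:
    """Same structure as prep_blocks, with the CUDA call mocked out."""
    prepared_blocks = [partial(block, **kwargs) for block in blocks]
    if clear_cache_between_blocks:
        def block_with_cache_clear(block, *args, **kwargs):
            fake_empty_cache()
            return block(*args, **kwargs)
        prepared_blocks = [partial(block_with_cache_clear, b) for b in prepared_blocks]
    return prepared_blocks

# Two toy "blocks" that both accept the shared keyword argument `scale`.
def block_a(x, scale=1):
    return x * scale

def block_b(x, scale=1):
    return x + scale

prepared = prep_blocks_sketch([block_a, block_b],
                              clear_cache_between_blocks=True,
                              scale=10)

# Each prepared block now only needs the positional input;
# `scale=10` was already bound by partial.
out_a = prepared[0](3)   # block_a(3, scale=10) -> 30
out_b = prepared[1](3)   # block_b(3, scale=10) -> 13
print(out_a, out_b, cache_clears["count"])
```

Note that the wrapper clears the cache once per block invocation, which is why the counter ends at 2 after calling both prepared blocks.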
Code walkthrough:
1. Function definition
def prep_blocks(
    blocks: List[Callable],