关于collate_fn
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader
def collate_fn(batch):
    """Right-pad variable-length integer sequences with 0 and stack them.

    :param batch: list of lists of ints (one token-id sequence per sample)
    :return: tensor of shape (len(batch), max_len)
    """
    # default=0 keeps an empty batch from raising ValueError in max().
    max_len = max((len(seq) for seq in batch), default=0)
    padded_batch = [seq + [0] * (max_len - len(seq)) for seq in batch]
    return torch.tensor(padded_batch)
class TextDataset(Dataset):
    """Minimal map-style Dataset over an in-memory list of sequences."""

    def __init__(self, texts):
        # Keep a reference to the raw sequences; no copying or validation.
        self.texts = texts

    def __len__(self):
        """Number of samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the idx-th sequence unchanged."""
        return self.texts[idx]
# Four variable-length "token id" sequences to demonstrate padding.
texts = [
    [1, 2, 3],
    [4, 5],
    [6, 7, 8, 9],
    [10],
]

dataset = TextDataset(texts)
# The custom collate_fn pads each batch of 2 to a rectangular tensor.
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
for batch in dataloader:
    print(batch)
输出:
Connected to pydev debugger (build 241.18034.82)
tensor([[1, 2, 3],
[4, 5, 0]])
tensor([[ 6, 7, 8, 9],
[10, 0, 0, 0]])
Process finished with exit code 0
torch.multiprocessing.spawn()
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
# Hyperparameters shared (as module globals) by every DDP worker process.
batch_size = 32
learning_rate = 0.01
epochs = 3
class SimpleModel(nn.Module):
    """Single linear layer mapping flattened 28x28 images to 10 class logits."""

    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        # Flatten everything but the implied batch dimension, then project.
        flat = x.view(-1, 28 * 28)
        return self.fc(flat)
def setup(rank, world_size, master_addr='localhost', master_port='12355', backend='nccl'):
    """Initialize the default distributed process group for this worker.

    :param rank: global rank of this process
    :param world_size: total number of participating processes
    :param master_addr: rendezvous host (default preserves original behavior)
    :param master_port: rendezvous port (default preserves original behavior)
    :param backend: communication backend; 'nccl' requires CUDA GPUs
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    dist.init_process_group(backend, rank=rank, world_size=world_size)
def cleanup():
    """Tear down the default process group created by setup()."""
    dist.destroy_process_group()
def train(rank, world_size):
    """Per-process DDP training entry point (spawned once per GPU).

    Trains SimpleModel on MNIST for `epochs` epochs; `rank` doubles as the
    CUDA device index for this process.
    """
    setup(rank, world_size)
    # Move the model to this process's GPU, then wrap it so gradients are
    # all-reduced across ranks on every backward pass.
    model = SimpleModel().to(rank)
    model = DDP(model, device_ids=[rank])
    criterion = nn.CrossEntropyLoss().to(rank)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Standard MNIST mean/std normalization.
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    # Each rank sees a disjoint shard of the dataset.
    train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    for epoch in range(epochs):
        # Must be called each epoch so shuffling differs between epochs.
        train_sampler.set_epoch(epoch)
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)
            outputs = model(data)
            loss = criterion(outputs, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if batch_idx % 100 == 0:
                print(
                    f"Rank {rank}, Epoch [{epoch + 1}/{epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")
    cleanup()
def main():
    """Launch one DDP training process per available GPU.

    :raises RuntimeError: if no CUDA devices are visible — spawning with
        nprocs=0 would otherwise fail obscurely, and the NCCL backend
        requires GPUs anyway.
    """
    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise RuntimeError("No CUDA devices available; NCCL-based DDP training requires at least one GPU.")
    # spawn() passes the process index (rank) as the first argument of train.
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
输出:
Rank 0, Epoch [1/3], Step [0/938], Loss: 2.4825
Rank 1, Epoch [1/3], Step [0/938], Loss: 2.7278
Rank 1, Epoch [1/3], Step [100/938], Loss: 0.7033
Rank 0, Epoch [1/3], Step [100/938], Loss: 0.8019
Rank 1, Epoch [1/3], Step [200/938], Loss: 0.7423
Rank 0, Epoch [1/3], Step [200/938], Loss: 0.4223
Rank 1, Epoch [1/3], Step [300/938], Loss: 0.2833
Rank 0, Epoch [1/3], Step [300/938], Loss: 0.5531
Rank 1, Epoch [1/3], Step [400/938], Loss: 0.3288
......
python hook
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
    """Tiny MLP (10 -> 5 -> 2) used to demonstrate forward hooks."""

    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        # fc1 -> ReLU -> fc2; the hook demo attaches to fc1.
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
model = SimpleModel()


def forward_hook(module, input, output):
    # Runs right after fc1's forward; shows the tensors flowing through it.
    print(f"Hook: 模块 {module} 的输入 {input},输出 {output}")


# Attach the hook to fc1 only, run one forward pass, then detach it.
hook_handle = model.fc1.register_forward_hook(forward_hook)
input_data = torch.randn(1, 10)
output = model(input_data)
print(f"最终输出: {output}")
hook_handle.remove()
输出:
Hook: 模块 Linear(in_features=10, out_features=5, bias=True) 的输入 (tensor([[ 0.7551, 1.3895, 0.4566, 0.5799, -1.2487, 0.1824, 0.8438, 1.0473,
-0.7047, -0.2592]]),),输出 tensor([[ 0.2386, -0.1723, 0.2275, 0.0239, 0.1021]],
grad_fn=<AddmmBackward0>)
最终输出: tensor([[ 0.0837, -0.1774]], grad_fn=<AddmmBackward0>)
异步保存或者中止训练
import multiprocessing
import time
import torch
def train_model(signal_queue):
    """Dummy training loop that reacts to SAVE/STOP messages from a queue.

    Runs until a STOP signal arrives, 20 epochs elapse, or Ctrl-C.
    """
    epoch = 0
    try:
        while True:
            print(f"Training... Epoch: {epoch}")
            time.sleep(1)
            # Non-blocking check: only consume a message if one is waiting.
            sig = None if signal_queue.empty() else signal_queue.get()
            if sig == "SAVE":
                print(f"Saving model at epoch {epoch}")
                checkpoint = {"epoch": epoch}
                torch.save(checkpoint, f"model_epoch_{epoch}.pt")
                print(f"Model saved at epoch {epoch}")
            elif sig == "STOP":
                print("Stopping training...")
                break
            epoch += 1
            if epoch > 20:
                print("Reached maximum epochs.")
                break
    except KeyboardInterrupt:
        print("Training interrupted manually.")
    finally:
        print("Training finished.")
def monitor_and_send_signal(signal_queue):
    """Read console commands and forward them to the training process."""
    try:
        while True:
            command = input("Enter 'save' to save the model or 'stop' to stop training: ").strip()
            cmd = command.lower()
            if cmd == 'save':
                signal_queue.put("SAVE")
            elif cmd == 'stop':
                # Tell the worker to stop, then stop monitoring too.
                signal_queue.put("STOP")
                break
    except KeyboardInterrupt:
        print("Monitoring interrupted manually.")
    finally:
        print("Monitoring finished.")
def main():
    """Run training in a child process while this process handles console input."""
    signal_queue = multiprocessing.Queue()
    worker = multiprocessing.Process(target=train_model, args=(signal_queue,))
    worker.start()
    # Blocks on stdin until the user sends 'stop' (or Ctrl-C).
    monitor_and_send_signal(signal_queue)
    worker.join()


if __name__ == "__main__":
    main()
装饰器方式异步保存或者中止训练
import time
import threading
import multiprocessing
def monitor_decorator(func):
    """Wrap *func* so an external signal queue can stop it or trigger saves.

    The wrapped function must accept a `stop_event` keyword and must be
    called with a `signal_queue` keyword argument.
    """
    def wrapper(*args, **kwargs):
        sig_q = kwargs.get('signal_queue')
        if sig_q is None:
            raise ValueError("signal_queue is required as a keyword argument")
        stop_event = threading.Event()

        def poll_queue():
            # Poll until told to stop or until the wrapped call finishes.
            while not stop_event.is_set():
                if not sig_q.empty():
                    msg = sig_q.get()
                    if msg == 'STOP':
                        print("Stopping the monitored function.")
                        stop_event.set()
                    elif msg == 'SAVE':
                        print("Saving current state...")
                time.sleep(0.1)

        watcher = threading.Thread(target=poll_queue)
        watcher.start()
        try:
            result = func(*args, **kwargs, stop_event=stop_event)
        finally:
            # Always shut the watcher down, even if func raised.
            stop_event.set()
            watcher.join()
        return result
    return wrapper
@monitor_decorator
def long_running_function(*args, stop_event=None, **kwargs):
    """Count epochs until the monitor sets stop_event or 50 epochs pass."""
    epoch = 0
    while not stop_event.is_set():
        print(f"Running epoch {epoch}...")
        time.sleep(1)
        epoch += 1
        if epoch > 50:
            print("Reached maximum epochs.")
            break
    print("Function completed.")
def monitor_input(signal_queue):
    """Translate console commands into SAVE/STOP queue messages."""
    while True:
        command = input("Enter 'save' to save the state or 'stop' to stop execution: ").strip().lower()
        if command == 'stop':
            signal_queue.put('STOP')
            break
        if command == 'save':
            signal_queue.put('SAVE')
if __name__ == "__main__":
    # Queue shared between this (monitoring) process and the worker process.
    signal_queue = multiprocessing.Queue()
    # The decorated function runs in a child process; the decorator starts
    # its own watcher thread inside that process.
    process = multiprocessing.Process(target=long_running_function, kwargs={'signal_queue': signal_queue})
    process.start()
    monitor_input(signal_queue)
    process.join()
简单的数据加载器SimpleDataLoader
import random
import time
class SimpleDataLoader:
    def __init__(self, data, batch_size=1, shuffle=False, curriculum_learning_enabled=False, post_process_func=None):
        """
        Minimal batch iterator over an in-memory dataset.

        :param data: dataset (list, numpy array, etc.)
        :param batch_size: number of samples per batch
        :param shuffle: reshuffle at the start of each iteration pass
        :param curriculum_learning_enabled: enable curriculum-learning post-processing
        :param post_process_func: optional function applied to each batch
        """
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.curriculum_learning_enabled = curriculum_learning_enabled
        self.post_process_func = post_process_func
        # Only full batches are yielded; a trailing partial batch is dropped.
        self.len = len(data) // batch_size
        self.current_index = 0
        self.data_iterator = None

    def _create_dataloader(self):
        """Build the item iterator for one pass over the data."""
        source = self.data
        if self.shuffle:
            # Shuffle a copy — the previous in-place random.shuffle(self.data)
            # silently mutated the caller's list.
            source = list(self.data)
            random.shuffle(source)
        self.data_iterator = iter(source)

    def __iter__(self):
        """Start a fresh pass over the data."""
        self.current_index = 0
        self._create_dataloader()
        return self

    def __len__(self):
        """Number of full batches per pass."""
        return self.len

    def __next__(self):
        """Return the next full batch, post-processed if enabled."""
        if self.current_index >= len(self.data):
            raise StopIteration
        batch = []
        try:
            for _ in range(self.batch_size):
                batch.append(next(self.data_iterator))
        except StopIteration:
            # Not enough items left for a full batch: end the pass.
            raise StopIteration
        self.current_index += self.batch_size
        if self.curriculum_learning_enabled and self.post_process_func is not None:
            batch = self.post_process_func(batch)
        return batch
def post_process(batch):
    """Example curriculum-learning transform: double every sample."""
    return [item * 2 for item in batch]
if __name__ == '__main__':
    # 20 items with batch_size=6 -> three full batches; [19, 20] is dropped.
    sample = list(range(1, 21))
    loader = SimpleDataLoader(sample, batch_size=6, shuffle=False)
    for batch in loader:
        print(f"Batch: {batch}")
输出:
Batch: [1, 2, 3, 4, 5, 6]
Batch: [7, 8, 9, 10, 11, 12]
Batch: [13, 14, 15, 16, 17, 18]
带缓存队列的数据加载器SimpleDataLoader
import random
import time
import threading
import queue
class SimpleDataLoaderWithCache:
    # Sentinel placed on the queue by the producer to mark end-of-data.
    # The original `cache.empty() and not thread.is_alive()` check in
    # __next__ raced with the producer: a consumer could block forever in
    # cache.get() when the cache was empty but the thread hadn't exited yet.
    _END = object()

    def __init__(self, data, batch_size=1, shuffle=False, curriculum_learning_enabled=False, post_process_func=None,
                 cache_size=10):
        """
        Batch iterator that prefetches batches into a bounded queue from a
        background thread.

        :param data: dataset (list, numpy array, etc.)
        :param batch_size: number of samples per batch
        :param shuffle: reshuffle at the start of each iteration pass
        :param curriculum_learning_enabled: enable curriculum-learning post-processing
        :param post_process_func: optional function applied to each batch
        :param cache_size: maximum number of prefetched batches
        """
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.curriculum_learning_enabled = curriculum_learning_enabled
        self.post_process_func = post_process_func
        # Only full batches are yielded; a trailing partial batch is dropped.
        self.len = len(data) // batch_size
        self.cache_size = cache_size
        self.data_iterator = None
        self.cache = queue.Queue(maxsize=self.cache_size)
        self.stop_event = threading.Event()

    def _create_dataloader(self):
        """Build the item iterator for one pass over the data."""
        source = self.data
        if self.shuffle:
            # Shuffle a copy so the caller's list is not mutated in place.
            source = list(self.data)
            random.shuffle(source)
        self.data_iterator = iter(source)

    def __iter__(self):
        """Start a fresh pass and launch the prefetching thread."""
        self.current_index = 0
        self._create_dataloader()
        self.data_thread = threading.Thread(target=self._fill_cache)
        self.data_thread.daemon = True
        self.data_thread.start()
        return self

    def __len__(self):
        """Number of full batches per pass."""
        return self.len

    def _enqueue(self, item):
        """Put with a timeout so close() can always unblock the producer.

        :return: True if the item was enqueued, False if stopped first.
        """
        while not self.stop_event.is_set():
            try:
                self.cache.put(item, timeout=0.1)
                return True
            except queue.Full:
                continue
        return False

    def _fill_cache(self):
        """Producer thread: build batches and push them into the cache."""
        while not self.stop_event.is_set():
            batch = []
            try:
                for _ in range(self.batch_size):
                    batch.append(next(self.data_iterator))
            except StopIteration:
                # Data exhausted (partial batch dropped, as before).
                break
            self.current_index += self.batch_size
            if self.curriculum_learning_enabled and self.post_process_func is not None:
                batch = self.post_process_func(batch)
            if not self._enqueue(batch):
                return
        # Signal end-of-data so the consumer never blocks on an empty queue.
        self._enqueue(self._END)

    def __next__(self):
        """Return the next prefetched batch; stop at the end-of-data marker."""
        batch = self.cache.get()
        if batch is self._END:
            raise StopIteration
        return batch

    def close(self):
        """Stop the producer thread and wait for it to exit."""
        self.stop_event.set()
        self.data_thread.join()
def post_process(batch):
    """Example curriculum-learning transform: double every sample."""
    return [item * 2 for item in batch]
if __name__ == "__main__":
    # 50 items / batch 12 -> four full (doubled) batches; the leftover
    # partial batch [49, 50] is dropped.
    sample = list(range(1, 51))
    loader = SimpleDataLoaderWithCache(sample, batch_size=12, shuffle=False,
                                       curriculum_learning_enabled=True,
                                       post_process_func=post_process, cache_size=3)
    for batch in loader:
        print(f"Batch: {batch}")
        # Slow consumer: lets the background thread fill the cache ahead.
        time.sleep(0.5)
    loader.close()
输出:
Batch: [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24]
Batch: [26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48]
Batch: [50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72]
Batch: [74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96]
装饰器
import functools
def log_function_call(func):
    """Decorator that prints a line before and after each call to *func*."""
    @functools.wraps(func)
    def inner(*args, **kwargs):
        print(f"Calling function {func.__name__}")
        value = func(*args, **kwargs)
        print(f"Function {func.__name__} finished")
        return value
    return inner
@log_function_call
def say_hello(name):
    """Greet *name* on stdout."""
    print(f"Hello, {name}!")


say_hello("Alice")
使用装饰器的方式记录log
import torch
import torch.nn as nn
import torch.optim as optim
import logging
import functools
# Send INFO-and-above records both to a log file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("training_with_decorator.log"),
        logging.StreamHandler()
    ]
)
def log_forward(func):
    """Decorator for Module.forward that logs its inputs and output."""
    @functools.wraps(func)
    def inner(*args, **kwargs):
        # args[0] is the module instance (self); log only the real inputs.
        logging.info(f"Forward pass: Inputs: {args[1:]}")
        result = func(*args, **kwargs)
        logging.info(f"Forward pass: Output: {result}")
        return result
    return inner
def log_loss(func):
    """Wrap a loss callable so every computed scalar value is logged."""
    @functools.wraps(func)
    def inner(*args, **kwargs):
        value = func(*args, **kwargs)
        # .item() assumes the loss is a scalar tensor.
        logging.info(f"Loss computed: {value.item()}")
        return value
    return inner
def log_backward(func):
    """Log entry/exit around a backward-and-update step."""
    @functools.wraps(func)
    def inner(*args, **kwargs):
        logging.info("Backward pass started")
        out = func(*args, **kwargs)
        logging.info("Backward pass completed")
        return out
    return inner
class SimpleModel(nn.Module):
    """Two-layer MLP (10 -> 5 -> 1) whose forward pass is logged."""

    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    @log_forward
    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Wrap the loss module so each evaluation is logged.
criterion = log_loss(nn.MSELoss())


@log_backward
def backward_and_step(loss):
    """Zero grads, backprop the loss, and apply one SGD step."""
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# Random regression data: 10 samples, 10 features -> 1 target each.
inputs = torch.randn(10, 10)
labels = torch.randn(10, 1)
# Five logged epochs of forward / loss / backward-and-step.
for epoch in range(5):
    logging.info(f"Epoch {epoch + 1} started")
    predictions = model(inputs)
    loss = criterion(predictions, labels)
    backward_and_step(loss)
    logging.info(f"Epoch {epoch + 1} completed\n")
logging.info("Training finished")