这篇博客用的代码:
import json
import os
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from threading import Condition, Lock
from typing import Dict, Tuple, Union, Any

import torch

from builder import parser_server, parser_clients
from tools.logger import Logger
from tools.utils import clear_cache, same_seeds
class ExperimentLog(object):
    """Nested key/value experiment record, persisted as JSON after every update.

    Keys are dotted paths (``'data.<client>.<round>.<task>'``); each
    :meth:`record` call merges the value into the in-memory tree and rewrites
    the whole file at ``save_path``.
    """

    def __init__(self, save_path: str):
        self.records = {}
        self.save_path = save_path
        # serialises record() calls that arrive from several worker threads
        self.lock = Lock()

    def _update_iter(self, key, value):
        """Insert ``value`` at the dotted path ``key``.

        Missing intermediate levels are created as dicts.  If the leaf already
        exists, ``value`` is appended to a list, added to a set, merged into a
        dict via ``update``, or otherwise replaces the old leaf.
        """
        *parents, leaf = key.split('.')
        node = self.records
        for part in parents:
            node = node.setdefault(part, {})
        if leaf not in node:
            node[leaf] = value
            return
        existing = node[leaf]
        if isinstance(existing, list):
            existing.append(value)
        elif isinstance(existing, set):
            existing.add(value)
        elif isinstance(existing, dict):
            existing.update(value)
        else:
            node[leaf] = value

    def _save_logs(self):
        """Write all records to ``save_path`` as indented JSON.

        The parent directory is created on demand; a bare file name has no
        parent (``dirname`` is ``''``) and is written to the current directory.
        Records are serialised before the file is opened, so a value that is
        not JSON-serialisable raises without truncating the previous log.
        """
        dirname = os.path.dirname(self.save_path)
        if dirname:
            # exist_ok avoids the race between an existence check and creation
            os.makedirs(dirname, exist_ok=True)
        content = json.dumps(self.records, indent=2)
        with open(self.save_path, "w") as f:
            f.write(content)

    def record(self, key, value):
        """Merge ``value`` under dotted ``key`` and persist the log file.

        The lock is held for both the update and the save; using ``with``
        guarantees it is released even if either step raises.
        """
        with self.lock:
            self._update_iter(key, value)
            self._save_logs()
class VirtualContainer(object):
    """Share compute devices between worker threads using per-device slot counts.

    Each device starts with ``parallel`` free slots.  A reservation picks the
    first device that still has a free slot and debits ``count`` slots from
    it; the balance may go negative when ``count`` exceeds what is left, which
    keeps that device busy until the slots are returned.
    """

    def __init__(self, devices: list, parallel: int = 1) -> None:
        super().__init__()
        self.lock = Lock()
        # condition on the same lock so waiting acquirers wake up on release
        self._available = Condition(self.lock)
        self.devices = {device: parallel for device in devices}

    def max_worker(self):
        """Sum of the slots currently free across all devices."""
        return sum(self.devices.values())

    def acquire_device(self, count=1, timeout=None):
        """Reserve ``count`` slots on the first device that has a free slot.

        Blocks until such a device exists instead of returning ``None``.
        Returning ``None`` let callers run on ``device=None`` and then fail in
        ``release_device``.

        :param count: number of slots to debit from the chosen device.
        :param timeout: maximum seconds to wait; ``None`` waits indefinitely.
        :return: the chosen device, or ``None`` if there are no devices at all
            or ``timeout`` expired first.
        """
        with self._available:
            if not self.devices:
                return None
            has_free_slot = lambda: any(cnt > 0 for cnt in self.devices.values())
            if not self._available.wait_for(has_free_slot, timeout):
                return None
            device = next(dev for dev, cnt in self.devices.items() if cnt > 0)
            self.devices[device] -= count
            return device

    def release_device(self, device, count=1):
        """Return ``count`` slots to ``device`` and wake any waiting acquirers.

        :raises KeyError: if ``device`` is not managed by this container (the
            lock is still released).
        """
        with self._available:
            self.devices[device] += count
            self._available.notify_all()

    def possess_device(self, count=1):
        """Return a context manager holding ``count`` slots for its ``with`` body.

        ``with container.possess_device() as device:`` yields the reserved
        device and gives the slots back on exit, even if the body raises.
        """
        class VirtualProcess(object):
            def __init__(self, container) -> None:
                super().__init__()
                self.container = container
                self.device = None

            def __enter__(self):
                self.device = self.container.acquire_device(count)
                return self.device

            def __exit__(self, type, value, trace):
                # nothing was reserved if acquisition returned None
                if self.device is not None:
                    self.container.release_device(self.device, count)
                return

        return VirtualProcess(self)
class ExperimentStage(object):
    """Run one or more federated-learning experiments end to end.

    Intended as a context manager: ``__enter__`` validates the runtime
    environment and ``__exit__`` logs an escaping exception before letting it
    propagate.
    """

    def __init__(self, common_config: Dict, exp_configs: Union[Dict, Tuple[Dict]]):
        """
        :param common_config: settings shared by every experiment; this class
            reads 'device', 'parallel', 'datasets_dir', 'checkpoints_dir' and
            'logs_dir' from it.
        :param exp_configs: a single experiment config dict or a sequence of them.
        """
        self.common_config = common_config
        # normalise to a list so run() can always iterate
        self.exp_configs = [exp_configs] if isinstance(exp_configs, dict) else exp_configs
        self.logger = Logger('stage')
        self.container = VirtualContainer(self.common_config['device'], self.common_config['parallel'])

    def __enter__(self):
        self.check_environment()
        return self

    def __exit__(self, type, value, trace):
        """Log an escaping ``Exception`` and never suppress it.

        Returning False lets the original exception propagate with its
        traceback.  ``raise trace`` would raise a traceback object, which is a
        TypeError.  A truthy return would silently swallow errors such as
        ``KeyboardInterrupt``.
        """
        if type is not None and issubclass(type, Exception):
            self.logger.error(value)
        return False

    def check_environment(self):
        """Validate devices and directories before any experiment runs.

        Exits with status 1 if a configured device cannot hold a tensor or the
        datasets directory is missing.  Only warns when the checkpoint
        directory already contains files.
        """
        # check runtime device: moving a tiny tensor fails fast on unusable devices
        for device in self.common_config['device']:
            try:
                torch.Tensor([0]).to(device)
            except Exception as ex:
                self.logger.error(f'Not available for given device {device}:{ex}')
                raise SystemExit(1)
        # check dataset base path
        datasets_dir = self.common_config['datasets_dir']
        if not os.path.exists(datasets_dir):
            self.logger.error(f'Datasets base directory could not be found with {datasets_dir}.')
            raise SystemExit(1)
        # check checkpoint path: warn only if it is an existing, non-empty directory
        checkpoints_dir = self.common_config['checkpoints_dir']
        if os.path.isdir(checkpoints_dir) and os.listdir(checkpoints_dir):
            self.logger.warn(f'Checkpoint directory {checkpoints_dir} is not empty.')
        self.logger.info('Experiment stage build success.')

    def run(self):
        """Run every configured experiment in sequence."""
        self.logger.debug(self.exp_configs)
        for exp_config in self.exp_configs:
            same_seeds(exp_config['random_seed'])
            # generate log with time-based savepath: <logs_dir>/<exp_name>-<YYYY-mm-dd-HH-MM>.json
            format_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
            log = ExperimentLog(os.path.join(
                self.common_config['logs_dir'],
                f"{exp_config['exp_name']}-{format_time}.json"
            ))
            log.record('config', exp_config)
            self.logger.info(f"Experiment loading succeed: {exp_config['exp_name']}")
            self.logger.info(f"For more details: {log.save_path}")
            # generate server and clients
            server = parser_server(exp_config, self.common_config)
            clients = parser_clients(exp_config, self.common_config)
            for client in clients:
                self.logger.debug(client)
            # initial validation (round 0) of every client before any training
            self._run_clients_parallel(self._process_val, clients, log, 0)
            # simulate communication process
            comm_rounds = int(exp_config['exp_opts']['comm_rounds'])
            for curr_round in range(1, comm_rounds + 1):
                self.logger.info(f'Start communication round: {curr_round:0>3d}/{comm_rounds:0>3d}')
                self._process_one_round(curr_round, server, clients, exp_config, log)
            del server, clients, log

    def _run_clients_parallel(self, worker, clients, log, curr_round):
        """Run ``worker(client, log, curr_round, container)`` for each client in a thread pool.

        ``Future.result()`` re-raises a worker's exception, so the first
        failure to complete is propagated.  Leaving the ``with`` block still
        waits for the remaining tasks.
        """
        with ThreadPoolExecutor(self.container.max_worker()) as pool:
            futures = [
                pool.submit(worker, client, log, curr_round, self.container)
                for client in clients
            ]
            for future in as_completed(futures):
                future.result()

    def _process_one_round(self, curr_round, server, clients, exp_config, log) -> Any:
        """Simulate one communication round.

        The steps are: dispatch server state to a random sample of online
        clients, train them, and validate all clients every ``val_interval``
        rounds. Then upload the online clients' incremental states and let the
        server aggregate.
        """
        exp_opts = exp_config['exp_opts']
        # sample online clients
        online_clients = random.sample(clients, int(exp_opts['online_clients']))
        val_intervals = int(exp_opts['val_interval'])
        # update clients with server state: integrated state first, incremental as fallback
        for client in online_clients:
            if client.client_name not in server.clients.keys():
                server.register_client(client.client_name)
            dispatch_state = server.get_dispatch_integrated_state(client.client_name)
            if dispatch_state is not None:
                client.update_by_integrated_state(dispatch_state)
            else:
                dispatch_state = server.get_dispatch_incremental_state(client.client_name)
                if dispatch_state is not None:
                    client.update_by_incremental_state(dispatch_state)
            server.save_state(
                f'{curr_round}-{server.server_name}-{client.client_name}',
                dispatch_state, True
            )
            del dispatch_state
        # simulate training for each online client
        self._run_clients_parallel(self._process_train, online_clients, log, curr_round)
        # simulate validation for each client; a non-positive interval disables it
        if val_intervals > 0 and curr_round % val_intervals == 0:
            self._run_clients_parallel(self._process_val, clients, log, curr_round)
        # communication with server
        for client in online_clients:
            incremental_state = client.get_incremental_state()
            client.save_state(
                f'{curr_round}-{client.client_name}-{server.server_name}',
                incremental_state, True
            )
            if incremental_state is not None:
                server.set_client_incremental_state(client.client_name, incremental_state)
            del incremental_state
        server.calculate()

    @staticmethod
    @clear_cache
    def _process_train(client, log, curr_round, container):
        """Train ``client`` on its next task while holding one device slot.

        When the task has a non-zero epoch count, training accuracy/loss are
        recorded under ``data.<client_name>.<curr_round>.<task_name>``.
        Errors are logged on the client's logger and re-raised.
        """
        with container.possess_device() as device:
            try:
                task = client.task_pipeline.next_task()
                if task['tr_epochs'] != 0:
                    tr_output = client.train(
                        epochs=task['tr_epochs'],
                        task_name=task['task_name'],
                        tr_loader=task['tr_loader'],
                        val_loader=task['query_loader'],
                        device=device
                    )
                    log.record(f"data.{client.client_name}.{curr_round}.{task['task_name']}", {
                        "tr_acc": tr_output['accuracy'],
                        "tr_loss": tr_output['loss']
                    })
            except Exception as ex:
                client.logger.error(ex)
                raise

    @staticmethod
    @clear_cache
    def _process_val(client, log, curr_round, container):
        """Validate ``client`` on every task in its pipeline.

        Reserves ``container.max_worker()`` slots, i.e. the slots free at call
        time. Rank-1/3/5/10 and mAP are recorded under
        ``data.<client_name>.<curr_round>.<task_name>``.  Errors are logged on
        the client's logger and re-raised.
        """
        with container.possess_device(container.max_worker()) as device:
            try:
                task_pipeline = client.task_pipeline
                for tid in range(len(task_pipeline.task_list)):
                    task = task_pipeline.get_task(tid)
                    cmc, mAP, _ = client.validate(
                        task_name=task['task_name'],
                        query_loader=task['query_loader'],
                        gallery_loader=task['gallery_loaders'],
                        device=device
                    )
                    log.record(f"data.{client.client_name}.{curr_round}.{task['task_name']}", {
                        "val_rank_1": cmc[0],
                        "val_rank_3": cmc[2],
                        "val_rank_5": cmc[4],
                        "val_rank_10": cmc[9],
                        "val_map": mAP,
                    })
            except Exception as ex:
                client.logger.error(ex)
                raise
1.Python时间模块之datetime模块
from datetime import datetime

# Render the current local time as year-month-day-hour-minute text.
now = datetime.now()
format_time = now.strftime('%Y-%m-%d-%H-%M')
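格式化时间:strftime 按格式字符串把 datetime 对象转换成字符串,格式代码与 time 模块的 strftime 相同。例如 %Y(四位年份)、%m(月)、%d(日)、%H(24 小时制的小时)、%M(分钟),所以 '%Y-%m-%d-%H-%M' 会得到形如 2022-05-01-10-30 的字符串。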
from datetime import datetime

# strftime turns a datetime into text following the format string.
format_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
# The raw datetime object, printed with its default string form for comparison.
format_time2 = datetime.now()
for shown in (format_time, format_time2):
    print(shown)
2.python中os库的使用
# Fragment from ExperimentStage.run(): `self`, `exp_config` and `format_time` come from that method.
# The log file path is <logs_dir>/<exp_name>-<format_time>.json.
log = ExperimentLog(os.path.join(
    self.common_config['logs_dir'],
    f"{exp_config['exp_name']}-{format_time}.json"
))
上面这段代码中,common_config 是从 comm_config.yaml 文件读取的配置,这里使用其中 logs_dir 的值:
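Python 字符串前面加 f 是什么意思?这种写法叫 f-string(格式化字符串字面量):花括号 {} 中的表达式会在运行时求值并嵌入字符串。例如 f"{exp_config['exp_name']}-{format_time}.json" 会把实验名和时间嵌入文件名中。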
然后通过 os.path.join 与 logs_dir 拼接成完整的文件路径。
最后组成的文件路径形如 ./logs/fedstil-2022-05-01-10-30.json:
把这个路径作为参数实例化 ExperimentLog 类,最终的效果是
将相关信息以 JSON 格式保存在当前工作目录下的 logs 文件夹内:
- os.path.join(path, *paths):把 path 和 paths 组合起来,返回一个路径字符串
import os

# os.path.join inserts the platform separator between components.  On
# Windows "D:" is a drive prefix, so the result is "D:123"; on POSIX it is "D:/123".
joined_path = os.path.join("D:", "123")
print(joined_path)
- os.path.dirname(path):返回 path 中的目录部分;如果 path 只是一个文件名(不含目录),返回空字符串 ''
- os.path.exists(path):判断 path 对应的文件或目录是否存在,返回 True 或 False
- os.makedirs(path):逐级创建 path 中不存在的目录;传入 exist_ok=True 时,目录已存在也不会报错
def _save_logs(self):
    """Write ``self.records`` to ``self.save_path`` as indented JSON.

    The parent directory is created on demand.  A bare file name has no
    parent (``dirname`` is ``''``, and ``os.makedirs('')`` would raise), so it
    is simply written to the current directory.  The records are serialised
    before the file is opened, so a non-serialisable value raises without
    truncating the previous log.
    """
    dirname = os.path.dirname(self.save_path)
    if dirname:
        # exist_ok avoids the race between an existence check and creation
        os.makedirs(dirname, exist_ok=True)
    content = json.dumps(self.records, indent=2)
    with open(self.save_path, "w") as f:
        f.write(content)
3.python 线程-- 锁
# Fragment from ExperimentStage.run(): `self` and `exp_config` come from that method.
# generate log with time-based savepath
format_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
# The argument is the log file path, saved under logs_dir (./logs/ in this post's config);
# exp_name in that config is fedstil.
log = ExperimentLog(os.path.join(
    self.common_config['logs_dir'],
    f"{exp_config['exp_name']}-{format_time}.json"
))
# The first entry stored is the experiment's own config, under the key 'config'.
log.record('config', exp_config)
class ExperimentLog(object):
    """Nested key/value experiment record, persisted as JSON after every update.

    Keys are dotted paths (``'data.<client>.<round>.<task>'``); each
    :meth:`record` call merges the value into the in-memory tree and rewrites
    the whole file at ``save_path``.
    """

    def __init__(self, save_path: str):
        self.records = {}
        self.save_path = save_path
        # serialises record() calls that arrive from several worker threads
        self.lock = Lock()

    def _update_iter(self, key, value):
        """Insert ``value`` at the dotted path ``key``.

        Missing intermediate levels are created as dicts.  If the leaf already
        exists, ``value`` is appended to a list, added to a set, merged into a
        dict via ``update``, or otherwise replaces the old leaf.
        """
        *parents, leaf = key.split('.')
        node = self.records
        for part in parents:
            node = node.setdefault(part, {})
        if leaf not in node:
            node[leaf] = value
            return
        existing = node[leaf]
        if isinstance(existing, list):
            existing.append(value)
        elif isinstance(existing, set):
            existing.add(value)
        elif isinstance(existing, dict):
            existing.update(value)
        else:
            node[leaf] = value

    def _save_logs(self):
        """Write all records to ``save_path`` as indented JSON.

        The parent directory is created on demand; a bare file name has no
        parent (``dirname`` is ``''``) and is written to the current directory.
        Records are serialised before the file is opened, so a value that is
        not JSON-serialisable raises without truncating the previous log.
        """
        dirname = os.path.dirname(self.save_path)
        if dirname:
            # exist_ok avoids the race between an existence check and creation
            os.makedirs(dirname, exist_ok=True)
        content = json.dumps(self.records, indent=2)
        with open(self.save_path, "w") as f:
            f.write(content)

    def record(self, key, value):
        """Merge ``value`` under dotted ``key`` and persist the log file.

        The lock is held for both the update and the save; using ``with``
        guarantees it is released even if either step raises.
        """
        with self.lock:
            self._update_iter(key, value)
            self._save_logs()
这段代码就是第 2 部分中实例化的 ExperimentLog 类。调用它的 record() 方法时会先获取线程锁,再更新记录并写入文件,最后释放锁,这样多个线程同时写日志时不会互相干扰。
互斥锁 Lock
线程同步能够保证多个线程安全访问竞争资源,最简单的同步机制是引入互斥锁。互斥锁为资源设置一个状态:锁定和非锁定。某个线程要更改共享数据时,先将其锁定,此时资源的状态为“锁定”,其他线程不能更改;直到该线程释放资源,将资源的状态变成“非锁定”,其他的线程才能再次锁定该资源。互斥锁保证了每次只有一个线程进行写入操作,从而保证了多线程情况下数据的正确性。
简单举例:
import time
import threading
from datetime import datetime


def fun(lock, cnt, delay=0.2):
    """Decrement the global ``num`` by one while holding ``lock``.

    The read-modify-write is deliberately split by a ``sleep``.  Without the
    lock, concurrent threads would read the same stale value and lose updates.

    :param lock: the shared ``threading.Lock`` guarding ``num``.
    :param cnt: index of the calling thread, only used in the printed line.
    :param delay: seconds to sleep between reading and writing ``num``
        (0.2 in the original demo).
    """
    global num
    # acquire() blocks by default; with a timeout it returns False on expiry
    acquired = lock.acquire(timeout=3)
    if not acquired:
        print('get lock failed')
        return
    try:
        temp = num
        time.sleep(delay)
        temp -= 1
        num = temp
        print(f"{cnt}:{datetime.now()}")
    finally:
        # release even if the critical section raises; otherwise every other
        # thread would time out waiting for the lock
        lock.release()


num = 10  # global variable shared by the worker threads


def main():
    """Start 10 threads that each decrement ``num`` once, then report the result."""
    global num
    print('主线程开始运行……')
    t_lst = []
    num = 10  # global variable
    lock = threading.Lock()
    for i in range(10):
        t = threading.Thread(target=fun, args=(lock, i))
        t_lst.append(t)
        t.start()
    # join() makes the main thread wait for every worker; without it the
    # prints below can run before any decrement has been written back
    for t in t_lst:
        t.join()
    print(f"main:{datetime.now()}")
    print('num最后的值为:{}'.format(num))
    print('主线程结束运行……')


if __name__ == "__main__":
    main()
这段举例代码中的重点是join()函数与互斥锁 Lock。
首先代码开始运行:
我们想要实现的效果为10个子线程依次修改num的值,每个子线程都num-=1,最终在主线程输出num,预计为0。
t_lst用于保存线程对象。
创建lock对象,然后进行10次循环,实现10个子线程,每个子线程都调用fun方法。
输出符合预期,先运行子线程,子线程全部结束后,输出主线程的时间:
但是如果删除join()后:
删除调用 t.join() 的那一行:
输出结果显示主线程先执行完毕。打印 num=10 时,子线程其实已经启动,但还没有一个完成减 1 的写回,因为每个子线程在读取 num 之后、写回之前都会 sleep 0.2 秒。
join方法
join() 方法的作用是:在调用 join() 的地方,让当前线程(这里是主线程)阻塞,同步地等待被 join 的线程运行结束,之后当前线程才继续往下执行。
join() 必须在 start() 之后调用:对尚未启动的线程调用 join() 会抛出 RuntimeError。
4.Python中logging模块
# Fragment from ExperimentStage.__init__: wrap a logging logger named 'stage' (passed to logging.getLogger).
self.logger = Logger('stage')
import logging

# Configure the root logger once: each record shows time, logger name, level and message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(name)s]: %(levelname)s - %(message)s')


class Logger(object):
    """Named logger wrapper with helpers for training and validation summaries."""

    def __init__(self, actuator: str = 'unknown'):
        # Fetch the logging.Logger registered under `actuator`; this name is
        # what %(name)s prints in the format configured above.
        self.logger = logging.getLogger(actuator)

    def debug(self, msg: str) -> None:
        """Emit ``msg`` at DEBUG level."""
        self.logger.debug(msg)

    def info(self, msg: str) -> None:
        """Emit ``msg`` at INFO level."""
        self.logger.info(msg)

    def warn(self, msg: str) -> None:
        """Emit ``msg`` at WARNING level."""
        self.logger.warning(msg)

    def error(self, msg: str) -> None:
        """Emit ``msg`` at ERROR level."""
        self.logger.error(msg)

    def info_train(self, task_name, device, train_cnt, accuracy, loss, current_epoch=0, total_epoch=0):
        """Log a one-line training summary.

        An ``[epoch/total]`` prefix, zero-padded to three digits, is added only
        when both ``current_epoch`` and ``total_epoch`` are non-zero.
        """
        if current_epoch and total_epoch:
            progress = f"[{current_epoch:0>3d}/{total_epoch:0>3d}] "
        else:
            progress = ""
        summary = (
            f"Train '{task_name}' on {device} with {train_cnt:,} images, "
            f"accuracy: {accuracy:.2%}, loss: {loss:.4f}."
        )
        self.logger.info(progress + summary)

    def info_validation(self, task_name, query_cnt, gallery_cnt, cmc, mAP) -> None:
        """Log a multi-line validation report.

        The report shows Rank-1/3/5/10 (``cmc`` indices 0, 2, 4, 9) and mAP as percentages.
        """
        report_template = (
            "Validation '{}' with {:,} query images on {:,} gallery images:\n"
            "|- Rank-1 : {:.2%}\n"
            "|- Rank-3 : {:.2%}\n"
            "|- Rank-5 : {:.2%}\n"
            "|- Rank-10 : {:.2%}\n"
            "|- mean AP : {:.2%}\n"
        )
        ranks = (cmc[0], cmc[2], cmc[4], cmc[9])
        self.logger.info(report_template.format(task_name, query_cnt, gallery_cnt, *ranks, mAP))
# Get the logger whose name is 'root'; the name must use straight ASCII quotes
# (typographic quotes are a SyntaxError).
logging.getLogger(name='root')
获取一个 Logger 对象:name 是这个 logger 的名称,会出现在日志格式的 %(name)s 字段中。用同一个 name 多次调用 getLogger 得到的是同一个 Logger 实例。