降水临近预报_Weather4cast_RainAI代码分享
主程序w4c23
def main():
parser = set_parser()
options = parser.parse_args()
params = update_params_based_on_args(options)
selected_model = params["model"]["model_name"]
if selected_model == "2D_UNET_base":
model = UNetModule
elif selected_model == "SWIN":
model = SWINModule
train(params, options.gpus, options.mode, options.checkpoint, model)
set_parser()
是一个函数,用于设置和返回一个argparse.ArgumentParser
对象
parser.parse_args()
方法来解析命令行参数并将结果存储在options
变量中
def update_params_based_on_args(options):
config_p = os.path.join("configurations", options.config_path)
params = load_config(config_p)
if options.name != "":
print(params["experiment"]["name"])
params["experiment"]["name"] = options.name
if options.epochs is not None:
params["train"]["max_epochs"] = options.epochs
if options.batch_size is not None:
params["train"]["batch_size"] = options.batch_size
if options.num_workers is not None:
params["train"]["n_workers"] = options.num_workers
if options.input_path != "":
params["dataset"]["data_root"] = options.input_path
if options.output_path != "":
params["experiment"]["experiment_folder"] = options.output_path
if options.region_to_predict != "":
params["predict"]["region_to_predict"] = options.region_to_predict
if options.year_to_predict != "":
params["predict"]["year_to_predict"] = options.year_to_predict
if options.submission_out_dir != "":
params["predict"]["submission_out_dir"] = options.submission_out_dir
return params
models
baseModule
具有强度输出和概率输出的模型的基本模块。需要验证和预测实现的抽象类。
BaseModule
的类,它继承自LightningModule
和ABC
。
因为继承自LightningModule
,要重写training_step
、validation_step
、predict_step
、configure_optimizers
方法,详见后续。
ABC
(Abstract Base Class)是一个用于定义抽象基类的元类。抽象基类是不能被实例化的类,它主要用于定义接口和共享方法的规范。通过继承抽象基类,子类需要实现抽象基类中定义的抽象方法
,以满足基类的接口规范。抽象基类可以提供一种约束,确保子类的一致性和可替换性。
if self.probabilistic:
# Store bucket means (but not as model parameter) as the channel dimension of the data
self.register_buffer(
"bucket_means",
torch.tensor(self.buckets.means).view(1, -1, 1, 1, 1),
)
self.bucket_means: torch.Tensor
如果损失函数是概率型的(probabilistic=True
),则代码会使用self.register_buffer
方法注册一个缓冲区(buffer)bucket_means
,用于存储损失函数的桶均值。这里使用torch.tensor
将桶均值转换成张量,并通过view
方法对其进行形状变换,以便后续使用。需要注意的是,注册的缓冲区不会作为模型的参数进行优化。
if model_params["upsample"] == "bilinear":
self.upsample = BilinearUpsample(42, 252, self.forecast_length)
elif model_params["upsample"] == "nearest":
self.upsample = NearestUpsample(42, 252, self.forecast_length)
elif model_params["upsample"] == "ninasr":
self.upsample = NinaSRUpsample(
42, 252, self.forecast_length, self.num_classes
)
elif model_params["upsample"] == "edsr":
self.upsample = EDSRUpsample(
42, 252, self.forecast_length, self.num_classes
)
else:
self.upsample = None
根据model_params["upsample"]
的值选择相应的上采样方法对象赋值给self.upsample
。根据代码片段提供的信息,上采样方法可以是BilinearUpsample
、NearestUpsample
、NinaSRUpsample
或EDSRUpsample
。如果model_params["upsample"]
的值不在这些选项中,self.upsample
将被设置为None
。
@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
abstractmethod
是一个装饰器,用于定义抽象方法。抽象方法是在抽象基类中声明但没有实现的方法,它只有方法的声明部分,没有具体的方法体。抽象方法必须在子类中被重写实现,否则子类也会成为抽象类。通过使用abstractmethod
装饰器,可以明确地表示某个方法是抽象方法。
在这段代码中,forward
方法被定义为抽象方法,即没有具体的实现。抽象方法使用abstractmethod
装饰器进行修饰,表示它是一个需要在子类中被重写实现的方法。子类必须提供forward
方法的具体实现,以满足抽象基类的接口规范。
def augment_batch(self, batch):
"""Apply augmentation on training batches (flips and 90-degrees rotation)"""
# TODO - Change to data loader
if not self.transform:
return batch
input, label, metadata = batch
angle = random.choice([-90, 0, 90, 180])
transformations = [
v2.RandomHorizontalFlip(),
v2.RandomVerticalFlip(),
v2.RandomRotation([angle, angle]),
]
t = random.choice(transformations)
input = t(input).contiguous()
label = t(label).contiguous()
# Transform masks
metadata["input"]["mask"] = t(metadata["input"]["mask"])
metadata["target"]["mask"] = t(metadata["target"]["mask"])
# Transform static data if any
if self.static_data:
metadata["input"]["topo"] = t(metadata["input"]["topo"])
metadata["target"]["topo"] = t(metadata["target"]["topo"])
metadata["input"]["lat-long"] = t(metadata["input"]["lat-long"])
metadata["target"]["lat-long"] = t(metadata["target"]["lat-long"])
return input, label, metadata
augment_batch
方法接受一个batch
参数,表示训练批次数据。该方法的作用是对训练批次数据进行增强操作,包括翻转和旋转。增强操作可以提高模型的鲁棒性和泛化能力,使其能够更好地适应不同的输入样本。
在当前的实现中,首先判断是否需要进行数据增强操作,如果self.transform
为False
,则直接返回原始的批次数据。否则,从批次数据中获取输入、标签和元数据。然后,随机选择一个角度(-90度、0度、90度或180度),并定义一些变换操作,包括随机水平翻转、随机垂直翻转和随机旋转。接下来,从变换操作中随机选择一个变换t
,并将其应用于输入、标签和元数据的对应部分。其中,输入和标签通过调用变换对象的__call__
方法进行转换,并使用contiguous
方法保证数据的连续性。对于元数据中的掩码(mask)数据和静态数据(如果有的话),也需要进行相应的变换操作。最后,返回经过增强操作后的输入、标签和元数据。
def add_static(self, input, metadata):
lat_long = (
metadata["input"]["lat-long"]
.unsqueeze(2)
.repeat(1, 1, self.history_length, 1, 1)
)
topo = (
metadata["input"]["topo"]
.unsqueeze(2)
.repeat(1, 1, self.history_length, 1, 1)
)
input = torch.cat([input, lat_long, topo], dim=1)
return input
add_static
方法接受两个参数,input
表示输入数据,metadata
表示元数据。该方法的作用是将静态数据(lat_long
和topo
)添加到输入数据中。
首先,代码从metadata
中获取了lat_long
和topo
数据。这些数据可能是二维张量,表示地理坐标和地形信息。然后,通过使用unsqueeze
方法在适当的维度上添加一个维度,以便进行重复复制。使用repeat
方法将lat_long
和topo
在相应的维度上进行重复,以匹配输入数据的形状。接下来,使用torch.cat
方法将输入数据、lat_long
和topo
在维度1上进行连接,将它们合并成一个更大的输入张量。最后,返回合并后的输入数据。
def training_step(self, batch):
batch = self.augment_batch(batch)
input, label, metadata = batch
# Add static data to input if required
if self.static_data:
input = self.add_static(input, metadata)
input = self.transform_input(input)
prediction = self.forward(input)
if self.upsample:
prediction = self.upsample(prediction)
mask = metadata["target"]["mask"]
loss = self.loss_fn(prediction, label, mask)
self.log("train/loss", loss, sync_dist=True)
return loss
training_step
方法接受一个batch
参数,表示训练批次数据。该方法的作用是执行一次训练步骤,包括数据增强、添加静态数据、输入转换、模型前向传播、上采样、计算损失和记录训练损失。
首先,代码调用augment_batch
方法对批次数据进行增强操作,得到增强后的批次数据。然后,从增强后的批次数据中获取输入、标签和元数据。接下来,根据是否需要添加静态数据的设置,判断是否需要将静态数据添加到输入中。如果需要添加静态数据,调用add_static
方法将静态数据添加到输入数据中,得到添加了静态数据的输入。然后,调用transform_input
方法对输入数据进行转换,得到转换后的输入数据。接着,调用forward
方法对转换后的输入数据进行模型的前向传播,得到预测结果。如果定义了上采样方法(self.upsample
不为None
),则对预测结果进行上采样操作。接下来,从元数据中获取目标数据的掩码(mask)。然后,使用损失函数self.loss_fn
计算预测结果与标签之间的损失,传入预测结果、标签和掩码作为参数。最后,使用self.log
方法记录训练损失,并返回损失值。
def validation_step(self, batch, batch_idx) -> ValidationOutput:
input, label, metadata = batch
# Add static data to input if required
if self.static_data:
input = self.add_static(input, metadata)
input = self.transform_input(input)
prediction = self.forward(input)
if self.upsample:
prediction = self.upsample(prediction)
mask = metadata["target"]["mask"]
loss = self.loss_fn(prediction, label, mask)
self.log("val/loss", loss, sync_dist=True)
if self.probabilistic:
# If no softmax, apply as it is required for the metrics (i.e. CRPS)
if self.activation == "none":
prediction = nn.functional.softmax(prediction, dim=1)
probabilities = prediction
intensity = self.integrate(prediction)
else:
probabilities = None
intensity = prediction
return ValidationOutput(intensity=intensity, probabilities=probabilities)
validation_step
方法接受两个参数,batch
表示验证批次数据,batch_idx
表示批次索引。该方法的作用是执行一次验证步骤,包括添加静态数据、输入转换、模型前向传播、上采样、计算损失、记录验证损失和返回验证输出。
首先,代码从验证批次数据中获取输入、标签和元数据。接下来,根据是否需要添加静态数据的设置,判断是否需要将静态数据添加到输入中。如果需要添加静态数据,调用add_static
方法将静态数据添加到输入数据中,得到添加了静态数据的输入。然后,调用transform_input
方法对输入数据进行转换,得到转换后的输入数据。接着,调用forward
方法对转换后的输入数据进行模型的前向传播,得到预测结果。如果定义了上采样方法(self.upsample
不为None
),则对预测结果进行上采样操作。接下来,从元数据中获取目标数据的掩码(mask)。然后,使用损失函数self.loss_fn
计算预测结果与标签之间的损失,传入预测结果、标签和掩码作为参数。接着,使用self.log
方法记录验证损失,并传入"val/loss"
作为日志名称,loss
作为损失值,并设置sync_dist=True
以确保在分布式训练中同步日志。如果模型的损失函数是概率型的(self.probabilistic=True
),则进行一些额外的操作。首先,如果激活函数是"none"(即没有使用激活函数),则将预测结果进行 softmax 操作,因为一些指标(如 CRPS)需要概率分布的预测结果。然后,将预测结果作为概率分布probabilities
,并将预测结果进行积分得到intensity
。最后,返回一个ValidationOutput
对象,包含intensity
和probabilities
。
def predict_step(self, batch, batch_idx=None) -> torch.Tensor:
input, _, metadata = batch
# Add static data to input if required
if self.static_data:
input = self.add_static(input, metadata)
input = self.transform_input(input)
prediction = self.forward(input)
if self.upsample:
prediction = self.upsample(prediction)
if self.probabilistic:
# If no softmax, apply as it to sum 1
if self.activation == "none":
prediction = nn.functional.softmax(prediction, dim=1)
probabilities = prediction
intensity = self.integrate(prediction)
else:
probabilities = None
intensity = prediction
intensity = intensity[:, :, : self.forecast_length, :, :]
return intensity
首先,代码从预测批次数据中获取输入数据和元数据,忽略了标签数据(_
)。接下来,根据是否需要添加静态数据的设置,判断是否需要将静态数据添加到输入中。如果需要添加静态数据,调用add_static
方法将静态数据添加到输入数据中,得到添加了静态数据的输入。然后,调用transform_input
方法对输入数据进行转换,得到转换后的输入数据。接着,调用forward
方法对转换后的输入数据进行模型的前向传播,得到预测结果。如果定义了上采样方法(self.upsample
不为None
),则对预测结果进行上采样操作。如果模型的损失函数是概率型的(self.probabilistic=True
),则进行一些额外的操作。首先,如果激活函数是"none"(即没有使用激活函数),则将预测结果进行 softmax 操作,以确保预测结果的和为1。然后,将预测结果作为概率分布probabilities
,并将预测结果进行积分得到intensity
。最后,根据预测长度截取intensity
中的相应部分,并返回截取后的intensity
作为预测结果。
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
return optimizer
使用了torch.optim.AdamW
优化器类来创建一个AdamW优化器对象。AdamW
是Adam优化器的一种变体,它在优化过程中引入了权重衰减(weight decay)的正则化项,有助于控制模型的复杂度并提高泛化能力。在创建优化器对象时,传入了两个参数。self.parameters()
表示要优化的模型参数,即模型中所有需要进行梯度更新的参数。lr=self.lr
和weight_decay=self.weight_decay
分别指定了学习率和权重衰减的数值,这些数值是在模型初始化时从参数中获取的。最后,将创建的优化器对象返回。
losses
交叉熵和均方误差计算,对应概率输出和强度输出。
callbacks
callbacks文件夹应该放回调代码就可以了,不知道为什么把metrics代码也放这里。
log
用于在PyTorch Lightning框架中记录和计算各种指标(metrics)的值
init
def __init__(self, num_leadtimes, probabilistic, buckets, logging):
super().__init__()
self.num_leadtimes = num_leadtimes
self.probabilistic = probabilistic
if buckets != "none":
self.buckets = BUCKET_CONSTANTS[buckets]
else:
self.buckets = None
self.logging = logging
self.thresholds = [0.2, 1, 5, 10, 15]
接收参数num_leadtimes(leading time steps)、probabilistic(是否概率性指标)、buckets(用于概率性指标的桶大小)、logging(指标记录的方式)。
- num_leadtimes:leading time steps。
- probabilistic:一个布尔值,表示是否使用概率性指标。
- buckets:一个字符串,表示概率性指标中用于分桶的参数。如果不需要分桶,则为"none"。
- logging:一个字符串,表示指标记录的方式,可以是"tensorboard"或"wandb"。
- thresholds:一个列表,包含阈值的值。这些阈值将用于计算关键成功指数(Critical Success Index,CSI)
from dataclasses import dataclass
from typing import List
@dataclass
class Bucket:
idx: int
mean: float
max: float
weight: float
@dataclass
class BucketConstants:
buckets: List[Bucket]
means: List[float]
weights: List[float]
boundaries: List[float]
ranges: List[float]
num_buckets: int
# Custom buckets used for classification when using mm/h
_buckets_mmh = [
Bucket(idx=0, mean=0, max=0.08, weight=0.5107),
Bucket(idx=1, mean=0.12, max=0.16, weight=0.6014),
Bucket(idx=2, mean=0.2, max=0.25, weight=0.627),
Bucket(idx=3, mean=0.32, max=0.4, weight=0.6295),
Bucket(idx=4, mean=0.51, max=0.63, weight=0.631),
Bucket(idx=5, mean=0.81, max=1, weight=0.6359),
Bucket(idx=6, mean=1.3, max=1.6, weight=0.6472),
Bucket(idx=7, mean=2.0, max=2.5, weight=0.6667),
Bucket(idx=8, mean=3.25, max=4, weight=0.6901),
Bucket(idx=9, mean=5.15, max=6.3, weight=0.7298),
Bucket(idx=10, mean=8.1, max=10, weight=0.7823),
Bucket(idx=11, mean=13, max=16, weight=0.8428),
Bucket(idx=12, mean=20.5, max=25, weight=0.9084),
Bucket(idx=13, mean=32.5, max=40, weight=0.9617),
Bucket(
idx=14, mean=45, max=128, weight=1.0
), # Max is 128 as defined by preprocessing
]
def getBucketObject(buckets_list):
return BucketConstants(
buckets=buckets_list,
means=[b.mean for b in buckets_list],
weights=[b.weight for b in buckets_list],
boundaries=[b.max for b in buckets_list[:-1]],
ranges=[
buckets_list[i].max - buckets_list[i - 1].max
if i > 0
else buckets_list[i].max
for i in range(len(buckets_list))
],
num_buckets=len(buckets_list),
)
BUCKET_CONSTANTS = {
"mmh": getBucketObject(_buckets_mmh),
"test": getBucketObject(_buckets_test),
"w4c23_1": getBucketObject(_buckets_w4c23_1),
"w4c23_2": getBucketObject(_buckets_w4c23_2),
}
创建和管理不同的桶(Bucket)对象,并将其存储在BUCKET_CONSTANTS字典中。通过调用getBucketObject函数,可以根据桶列表获取相应的BucketConstants对象。这样做的目的是为了方便地创建和使用不同的桶,并将其关联到特定的名称,以供其他代码使用。
dataclasses 模块提供了一个装饰器 @dataclass,用于方便地创建和操作数据类(data class),它自动为类的属性生成相应的方法(如构造函数、属性访问方法、比较方法等),使得创建和操作数据对象更加简洁和方便。
from dataclasses import dataclass
@dataclass
class Person:
name: str
age: int
occupation: str
# # Code for checking if a metric can be optimized
# check_forward_full_state_property(
# metrics.MeanSquaredError,
# input_args={
# "prediction": torch.Tensor([0.5, 2.5]),
# "label": torch.Tensor([1.0, 2.0]),
# "mask": torch.zeros([2], dtype=bool),
# },
# )
被注释掉的代码是用于检查一个指标是否可以进行优化的示例代码。它使用torchmetrics库中的check_forward_full_state_property函数来检查均方误差(MeanSquaredError)指标是否可以进行优化。函数的输入参数为一个字典,包含了预测值(prediction)、标签值(label)和掩码(mask)。通过检查指标的前向计算是否可以成功执行,可以确保指标的正确性和可用性。
_threshold_str
def _threshold_str(self, threshold):
"""Remove .0 and change . by -"""
return f"{threshold:g}".replace(".", "-")
该段代码定义了一个名为"_threshold_str"的私有方法,用于处理阈值(threshold)的字符串表示。
该方法接受一个阈值参数,将其转换为字符串表示。转换过程包括以下步骤:
- 使用"{threshold:g}"将阈值转换为一般格式的字符串表示,去除多余的零和小数点。
- 使用.replace(“.”, “-”)将字符串中的小数点替换为短横线。
“g” 是格式化字符串中的一种格式化选项,用于表示通用格式。它会根据阈值的类型自动选择合适的表示方式,并去除多余的零和小数点。具体来说,对于整数类型的阈值,它会显示为普通整数的形式,如 5、10、100 等。而对于浮点数类型的阈值,它会显示为一般的浮点数格式,如 0.5、1.0、2.5 等。在这个过程中,多余的零和小数点会被去除。
最后,该方法返回处理后的字符串表示形式。
该方法的作用是将阈值转换为特定的字符串表示形式,可能是为了后续的指标命名或其他需要使用特定格式的字符串的目的。由于该方法是私有方法(以单个下划线开头),它在类外部不可直接访问,只能在类内部被调用。
setup
def setup(self, trainer, pl_module, stage):
# Setup scalar metrics
scalar_metrics = {}
scalar_metrics["mse"] = metrics.MeanSquaredError()
scalar_metrics["mae"] = metrics.MeanAverageError()
for threshold in self.thresholds:
csi = metrics.CriticalSuccessIndex(threshold=threshold)
scalar_metrics[f"csi_{self._threshold_str(threshold)}"] = csi
scalar_metrics["avg_csi"] = metrics.AverageCriticalSuccessIndex(
thresholds=self.thresholds
)
if self.probabilistic:
scalar_metrics["crps"] = metrics.ContinuousRankedProbabilityScore(
self.buckets
)
# Create metric collections and put metrics on module to automatically place on correct device
val_scalar_metrics = torchmetrics.MetricCollection(scalar_metrics)
pl_module.val_metrics = val_scalar_metrics.clone(prefix="val/")
# Lead time metrics
lead_time_metrics = {}
lead_time_metrics[f"mse"] = metrics.MeanSquaredError(
num_leadtimes=self.num_leadtimes
)
for threshold in self.thresholds:
csi = metrics.CriticalSuccessIndex(
threshold=threshold, num_leadtimes=self.num_leadtimes
)
lead_time_metrics[f"csi_{self._threshold_str(threshold)}"] = csi
lead_time_metrics["avg_csi"] = metrics.AverageCriticalSuccessIndex(
thresholds=self.thresholds, num_leadtimes=self.num_leadtimes
)
pl_module.lead_time_metrics = torchmetrics.MetricCollection(lead_time_metrics)
在setup
方法中,主要进行了以下操作:
- 设置标量指标(scalar metrics):
- 创建一个空字典
scalar_metrics
,用于存储标量指标。 - 向
scalar_metrics
字典中添加均方误差(MeanSquaredError)和平均绝对误差(MeanAverageError)指标。 - 使用阈值列表(
self.thresholds
)循环遍历,为每个阈值创建关键成功指数(CriticalSuccessIndex)指标,并将其添加到scalar_metrics
字典中。在添加时,指标的名称使用了f"csi_{self._threshold_str(threshold)}"
的格式,其中self._threshold_str(threshold)
将阈值转换为特定的字符串表示形式。 - 添加平均关键成功指数(AverageCriticalSuccessIndex)指标到
scalar_metrics
字典中,其中的阈值使用了阈值列表(self.thresholds
)。 - 如果
self.probabilistic
为True,则添加连续排名概率评分(ContinuousRankedProbabilityScore)指标到scalar_metrics
字典中,其中的桶(buckets)参数使用了self.buckets
。
- 创建一个空字典
- 创建指标集合(
MetricCollection
)并将指标放入模块(pl_module)中:MetricCollection
是torchmetrics
的一个方法,接收字典输入,创建指标集合。- 使用
scalar_metrics
字典创建标量指标集合(val_scalar_metrics
)。 - 使用
val_scalar_metrics.clone(prefix="val/")
创建一个带有前缀的克隆集合,前缀为"val/"。 - 将克隆的标量指标集合赋值给模块的
val_metrics
属性,用于在验证过程中记录和计算指标。
- 设置引导时间指标(lead time metrics):
- 创建一个空字典
lead_time_metrics
,用于存储引导时间指标。 - 向
lead_time_metrics
字典中添加均方误差(MeanSquaredError)指标,其中的引导时间数量使用了self.num_leadtimes
。 - 使用阈值列表(
self.thresholds
)循环遍历,为每个阈值创建引导时间关键成功指数(CriticalSuccessIndex)指标,并将其添加到lead_time_metrics
字典中。在添加时,指标的名称使用了f"csi_{self._threshold_str(threshold)}"
的格式,其中self._threshold_str(threshold)
将阈值转换为特定的字符串表示形式。 - 添加平均关键成功指数(AverageCriticalSuccessIndex)指标到
lead_time_metrics
字典中,其中的阈值使用了阈值列表(self.thresholds
)和引导时间数量(self.num_leadtimes
)。 - 将引导时间指标集合(
lead_time_metrics
)赋值给模块的lead_time_metrics
属性,用于在验证过程中记录和计算引导时间指标。
- 创建一个空字典
总的来说,setup
方法主要用于设置回调函数中的指标,包括标量指标和引导时间指标。它创建了相应的指标对象,并将它们放入模块中,以便在训练过程中使用和记录。
on_validation_batch_end
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
):
"""Called after each validation batch with scalar and lead time metrics"""
_, label, metadata = batch
pl_module.val_metrics(outputs, label, metadata["target"]["mask"])
pl_module.lead_time_metrics(outputs, label, metadata["target"]["mask"])
- 获取批次数据:从
batch
参数中解包获取到三个值,即_
(不使用)、label
和metadata
。这些值通常代表了模型输出、标签和元数据等。 - 计算标量指标:通过调用模块(
pl_module
)的val_metrics
指标集合,传递模型输出、标签和目标掩码(metadata["target"]["mask"]
),来计算标量指标的值。 - 计算引导时间指标:通过调用模块的
lead_time_metrics
指标集合,传递模型输出、标签和目标掩码,来计算引导时间指标的值。
通过调用指标集合的方法,可以将模型的输出、标签和目标掩码传递给指标集合,以便计算相应的指标值。这些指标值将用于后续的记录和评估过程。
on_validation_epoch_end
def on_validation_epoch_end(self, trainer, pl_module):
# Log validation scalar metrics
pl_module.log_dict(
pl_module.val_metrics, on_step=False, on_epoch=True, sync_dist=True
)
# Compute and log lead time metrics
lead_time_metrics = pl_module.lead_time_metrics.compute()
lead_time_metrics_dict = {}
wandb_data = []
for metric_name, arr in lead_time_metrics.items():
# Add to logging dictionary
for leadtime, value in enumerate(arr):
lead_time_metrics_dict[f"val_time/{metric_name}_{leadtime+1}"] = value
# Save to file (tensorboard)
if self.logging == "tensorboard":
file_path = os.path.join(
pl_module.logger.log_dir, f"val_lead_time_{metric_name}.pt"
)
torch.save(arr.cpu(), file_path)
# Generate table for wandb
elif self.logging == "wandb":
columns = ["metric"] + [f"t_{i+1}" for i in range(len(arr))]
wandb_data.append([metric_name] + arr.tolist())
# Save table in wandb
if self.logging == "wandb":
pl_module.logger.log_table(
key="leadtimes", columns=columns, data=wandb_data
)
# Save lead time metrics over time
pl_module.log_dict(
lead_time_metrics_dict, on_step=False, on_epoch=True, sync_dist=True
)
pl_module.lead_time_metrics.reset()
- 记录和保存标量指标:
- 使用
pl_module.val_metrics
指标集合,通过调用模块的log_dict
方法,将标量指标的值记录到日志中。 - 设置
on_step=False
和on_epoch=True
,以确保在验证周期结束时记录指标的值。 - 使用
sync_dist=True
来同步跨多个设备的指标值。
- 使用
- 计算和记录引导时间指标:
- 使用
pl_module.lead_time_metrics
指标集合的compute
方法,计算引导时间指标的值。 - 创建一个空字典
lead_time_metrics_dict
,用于存储引导时间指标的名称和值。 - 创建一个空列表
wandb_data
,用于存储生成表格所需的数据。 - 遍历引导时间指标集合中的每个指标和对应的值:
- 将指标的名称和对应的值添加到
lead_time_metrics_dict
字典中,以便后续的记录和保存。 - 如果
self.logging
为"tensorboard",则将引导时间指标的值保存到文件中,文件名为val_lead_time_{metric_name}.pt
。 - 如果
self.logging
为"wandb",则生成一个表格所需的数据,其中包括指标名称和对应的值。
- 将指标的名称和对应的值添加到
- 如果
self.logging
为"wandb",则将生成的表格数据使用pl_module.logger.log_table
方法保存到wandb中,其中的key
表示表格的唯一标识,columns
表示表格的列名,data
表示表格的数据。
- 使用
- 记录引导时间指标的值:
- 使用
pl_module
的log_dict
方法,将引导时间指标的名称和值记录到日志中。 - 设置
on_step=False
和on_epoch=True
,以确保在验证周期结束时记录指标的值。 - 使用
sync_dist=True
来同步跨多个设备的指标值。
- 使用
- 重置引导时间指标集合:
- 使用
pl_module.lead_time_metrics
指标集合的reset
方法,重置引导时间指标的状态,以便在下一个验证周期开始时重新计算。
- 使用
metrics
整体介绍
- 继承自
torchmetrics
中的Metric
类,重写了full_state_update
和higher_is_better
两个属性、update
和compute
两个方法。
在类的定义中,full_state_update
被设置为False
,表示不需要完全状态更新;higher_is_better
被设置为True
,表示指标的值越高越好。
在PyTorch的Metric
类中,通常会定义一些状态变量,用于保存指标计算过程中的中间结果。这些状态变量可以在每次更新指标时被更新。而完全状态更新是指每次更新指标时,都会将所有的状态变量进行更新。然而,并不是所有的指标都需要进行完全状态更新。有些指标的计算只依赖于最近一次更新的状态,而不需要考虑之前的状态。在这种情况下,可以将full_state_update
设置为False
,以优化计算性能。这次计算的CSI指标跟之前的状态就无关,因此不需要完全状态更新。
在update
方法中,接受了三个参数prediction
、label
和mask
,用于更新指标的计算。根据阈值列表和预测结果,将预测结果转换为二进制形式,并根据reduce_time
的值进行不同的操作。
在compute
方法中,计算了平均关键成功指数(CSI),即真阳性(true positives)除以真阳性和假预测(false guesses)之和的平均值。
init
- 接下来,根据传入的参数
thresholds
和num_leadtimes
的值,选择不同的默认值和设置self.reduce_time
的值。- 如果
num_leadtimes
为None
或者等于1,表示只有一个时间步,那么默认值default
将被设置为一个形状为(len(thresholds),)
的全零张量,并且self.reduce_time
将被设置为True
,表示需要减少时间维度。 - 如果
num_leadtimes
大于1,表示有多个时间步,那么默认值default
将被设置为一个形状为(len(thresholds), num_leadtimes)
的全零张量,并且self.reduce_time
将被设置为False
,表示不需要减少时间维度。 - 如果
num_leadtimes
小于等于0,则会抛出ValueError
异常,提示num_leadtimes
必须大于0。
- 如果
- 将传入的
thresholds
参数赋值给self.thresholds
属性,以便在后续的计算中使用。 - 通过调用
self.add_state
方法,将名为"true_positives"和"false_guesses"的状态变量添加到指标类中。这两个状态变量的默认值都是通过default.clone()
来创建的,同时设置了分布式合并函数dist_reduce_fx
为"sum"- 关于
dist_reduce_fx
,Metric类中使用分布式合并函数的目的是支持在分布式计算环境中进行指标的计算和合并,在分布式计算环境中,通常有多个计算节点或进程同时进行计算任务。每个节点或进程都可能独立地计算指标的一部分,并生成局部的状态变量。为了得到整体的指标结果,需要将各个节点或进程上计算得到的状态变量进行合并。
- 关于
update
更新状态变量。
首先,根据阈值列表self.thresholds
,使用enumerate
函数遍历阈值列表的索引和值,因为CSI
指标的计算在不同thresholds
下是不同的。
接下来,将预测结果prediction
的强度(intensity
)赋给变量pred
。
然后,将pred
和label
转换为二进制形式。将pred
中大于等于当前thresholds
的元素设置为真(True),其余为假(False)。同样,将label
中大于等于当前thresholds
的元素设置为真,其余为假。
接着,根据self.reduce_time
的值进行不同的操作。
- 如果
self.reduce_time
为True,表示只有一个时间步,那么将根据mask
对pred
和lab
进行掩码操作,即将掩码为真(True)的位置从pred
和lab
中剔除。 - 如果
self.reduce_time
为False,表示有多个时间步,那么通过重新排列张量的维度,将pred
和lab
的时间维度放到最后的位置,即将形状由"b c t h w"变为"(b c h w) t"。同时,对mask
进行相同的重新排列操作,并使用torch.logical_and
函数将pred
和lab
与掩码取反(~m)进行逻辑与操作,以将掩码位置视为真(True)。这样可以保留其他维度的信息并考虑掩码。
最后,根据预测结果和标签计算真阳性(true positives)和假预测(false guesses)的总数。使用torch.logical_and
函数计算pred
和lab
的逻辑与,得到同时为真的位置,然后使用sum(dim=0)
对每个时间步的结果进行求和,将结果累加到self.true_positives[i]
中。使用(pred != lab)
进行逻辑不等于操作,得到不一致的位置,然后使用sum(dim=0)
对每个时间步的结果进行求和,将结果累加到self.false_guesses[i]
中。
通过循环遍历阈值列表和计算真阳性和假预测的总数,update
方法更新了指标类中的状态变量。
compute
根据状态变量计算最终指标。
utils
buckets
各种分箱策略。
config
继承自yaml库的SafeLoader类,用于解析YAML文件(里没事各种参数设定)。
data_utils
用于各种数据处理。使用的情况有:
train.py:
from w4c23.utils.data_utils import get_cuda_memory_usage, tensor_to_submission_file
sampler.py:
from w4c23.utils.data_utils import get_file
w4c_dataloader.py:
from w4c23.utils.data_utils import *
sampler
数据集中样本的抽样策略,实现重要性采样。
w4c_dataloader
读取并归一化大赛数据。
其他项目
checkpoints
保存模型参数。
configurations
保存定义模型的各种参数组合。
data
原始数据。
images
2D U-Net 架构的输出与其输入具有相同的空间维度。这意味着对于大小为 128 x 128 像素的输入序列,通过 U-Net 的前向传播将生成大小为 128 x 128 像素的输出。标签对应于大小为 42 x 42 像素的中心块。因此,为了指导降水临近预报模型,我们采用中央 42 x 42 像素块并上采样到 252 x 252 像素标签。这种裁剪和上采样是在 MetNet [9] 中引入的,这是由于输入和标签的空间分辨率不同所致,如第 3 节中所述。
解释为什么要对标签值上采样。