比赛是来自于Google Brain - Ventilator Pressure Prediction。
引言
在部分章节中,我可能不会说得很详细,毕竟我的毕业论文还没发表,待发表过后会择期填补完整。
需要注意的是,本人并没有参加这个比赛,而是对这个比赛的某个金牌获胜者的方案进行一次复现。进而应用在本人的毕业设计中去,但请放心本人的毕业论文中已经对该作者进行了引用。
在方案的选择上,首先需要刨除那些数据泄露的方案,这些方案虽然有可取之处但对于实际应用并没有什么意义。其次需要去找那些有提供GitHub网址的方案,毕竟根据思路进行方案的复现略微困难。
最后是对方案进行改进或者是使它更加适用于自己的生产应用。
方案介绍
我所选择的方案是JUN KODA金牌方案PyTorch LSTM with TensorFlow-like initialization。该方案使用的是Pytorch框架和LSTM算法。
数据观察
字段 | 含义 | 备注 |
---|---|---|
id | 整个文件中的全局唯一时间步长标识符 | 整体数据的计数 |
breath_id | 全局唯一的呼吸时间步长 | 周期呼吸的计数 |
R | 气道阻力(cmH2O/L/S) | 指单位压力改变时所引起的肺容积的改变,它代表了胸腔压力改变对肺容积的影响 |
C | 肺顺应性(mL/cmH2O) | 物理上,这是每一次压力变化的体积变化。直觉上,人们可以想象同样的气球例子。我们可以通过改变气球胶乳的厚度来改变C,更高的C有更薄的胶乳,更容易吹 |
time_step | 实际的时间戳 | |
u_in | 吸气电磁阀的控制输入 | 范围从0到100,类似于一个滑动变阻器 |
u_out | 探测电磁阀的控制输入 | 0或1,0是吸气,1是呼气 |
Pressure | 呼吸回路中测得的气道压力 | 单位为cmH2O |
。。。(择期填补)
特征工程
。。。(择期填补)
网络搭建
。。。(择期填补)
损失值定义
。。。(择期填补)
开始训练
。。。(择期填补)
改进
在平常的模型训练上,我们一般都是将数据一次性准备齐全,然后丢到网络中去训练。同理,在数据预测上,我们同样是将数据一次性准备齐全,然后丢到模型中去预测。如以下代码所示:
test = pd.read_csv(di + 'test.csv', nrows=n)
features = create_features(test)
features = rs.transform(features)
X_test = features.reshape(-1, 80, features.shape[-1])
y_test = np.zeros(len(features)).reshape(-1, 80)
w_test = 1 - test.u_out.values.reshape(-1, 80)
dataset_test = Dataset(X_test, y_test, w_test)
loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size)
y_pred_folds = np.zeros((len(test), 5), dtype=np.float32)
for ifold in range(5):
model = Model(input_size)
model.to(device)
filename = '/kaggle/input/pytorchlstmwithtensorflowlikeinitialization/' \
'model%d.pth' % ifold
model.load_state_dict(torch.load(filename, map_location=device))
model.eval()
y_preds = []
for x, y, _ in loader_test:
x = x.to(device)
with torch.no_grad():
y_pred = model(x).squeeze()
y_preds.append(y_pred.cpu().numpy())
y_preds = np.concatenate(y_preds, axis=0)
y_pred_folds[:, ifold] = y_preds.flatten()
submit.pressure = np.mean(y_pred_folds, axis=1)
submit.to_csv('submission.csv', index=False)
print('submission.csv written')
可以看到,代码将无数周期的呼吸数据丢到了模型中去,这显然不合理,呼吸机的使用场景是实时的。所以我们要将这代码修改一下。