def test_step(self, batch, batch_idx):
    """Run inference on one test batch and dump predictions, masks and inputs as JPGs.

    The batch is unpacked as ``(ct, mask, name)``. ``self.forward`` is applied
    to ``ct``, and ``self.measure`` is called with the batch and the outputs.
    Then, for every sample in the batch, three images are written:

    * ``save_image/``      - argmax over dim 0 of softmax(output, dim=0)
    * ``save_mask_image/`` - the same argmax applied to the mask
    * ``save_ct_image/``   - the input with a leading size-1 dim squeezed away

    Files are named ``{batch_idx}.jpg`` when the batch holds a single sample
    and ``{batch_idx}_{i}.jpg`` (``i`` = index within the batch) otherwise.
    Previously every sample was written to ``{batch_idx}.jpg``, so only the
    last sample of a larger batch survived.

    Args:
        batch: ``(ct, mask, name)`` tuple from the test dataloader; ``name``
            is currently unused.
        batch_idx: index of this batch, used to build the output file names.
    """
    ct, mask, name = batch
    outputs = self.forward(ct)
    self.measure(batch, outputs)

    def _label_map(t):
        # Collapse dim 0 (presumably the class/channel axis -- TODO confirm)
        # into a per-pixel label index. squeeze(0) then drops a leading dim
        # only if it has size 1.
        return torch.argmax(torch.softmax(t, dim=0), dim=0).squeeze(0)

    def _ct_slice(t):
        # Drop a leading size-1 dim (assumed to be a single channel).
        return t.squeeze(0)

    def _save_all(tensors, folder, to_image, write):
        # Write one JPG per sample into `folder`, without name collisions.
        os.makedirs(folder, exist_ok=True)
        single = len(tensors) == 1
        for i, t in enumerate(tensors):
            stem = f'{batch_idx}' if single else f'{batch_idx}_{i}'
            path = os.path.join(folder, f'{stem}.jpg')
            write(path, to_image(t).detach().cpu().numpy())

    _save_all(outputs, 'save_image', _label_map, plt.imsave)
    _save_all(mask, 'save_mask_image', _label_map, plt.imsave)
    _save_all(ct, 'save_ct_image', _ct_slice, imageio.imwrite)
将 PyTorch Lightning 网络生成的张量转换为 JPG 文件并保存
于 2022-05-06 18:29:03 首次发布