evaluate
CascadeNetwork(
(block): Sequential(
(0): ResnetBlock(
(layers): Sequential(
(0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): ReLU()
(6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU()
(8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(1): DataConsistency()
(2): ResnetBlock(
(layers): Sequential(
(0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): ReLU()
(6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU()
(8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(3): DataConsistency()
(4): ResnetBlock(
(layers): Sequential(
(0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): ReLU()
(6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU()
(8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(5): DataConsistency()
(6): ResnetBlock(
(layers): Sequential(
(0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): ReLU()
(6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU()
(8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(7): DataConsistency()
(8): ResnetBlock(
(layers): Sequential(
(0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): ReLU()
(6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU()
(8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(9): DataConsistency()
(10): ResnetBlock(
(layers): Sequential(
(0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): ReLU()
(6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU()
(8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(11): DataConsistency()
)
)
gund = back_format(batch['image'].cpu().detach().numpy())
output = model(batch)
loss = criterion(output['image'], batch['full'])
# mse += loss.item() * 1000
mse += loss.item()
pred = back_format(output['image'].cpu().detach().numpy())
grnd = back_format(batch['full'].cpu().detach().numpy())
base_psnr += complex_psnr(grnd, gund)
test_psnr += complex_psnr(grnd, pred)
ssim_score += ssim(pred, grnd)
if draw:
pred = pred[0] # (256, 256)
grnd = grnd[0]
err = np.abs(pred - grnd) ** 2
err -= err.min()
err /= err.max()
err *= 255
err = err.astype(np.uint8)
err = cv2.applyColorMap(err, 2)
plt.imsave(os.path.join(kwargs['output_path'], str(i) + '.png'), err)
Input
fully-sampled image
a = back_format(batch['full'].cpu().detach().numpy())
a = np.abs(np.squeeze(a)) # (256, 256)
plt.imshow(a)
plt.colorbar()
plt.axis(False)
a
undersampled image
b = back_format(batch[‘image’].cpu().detach().numpy())
under sampled k-space
#因为batch['k']是(1,256,256,2),.permute(0, 3, 1, 2)转成(1,2,256,256)
c=back_format(batch['k'].permute(0, 3, 1, 2).cpu().detach().numpy())
c = np.abs(np.squeeze(c))
plt.imshow(np.log10(c+1e-17))
plt.colorbar()
plt.axis(False)
4X sampling mask
d=back_format(batch['mask'].permute(0, 3, 1, 2).cpu().detach().numpy())
d = np.abs(np.squeeze(d))
plt.imshow(np.log10(d+1e-17))
plt.colorbar()
plt.axis(False)
4倍加速,256/4 = 64条红线
fully k-space
e1 = batch['full'].permute(0, 2, 3, 1)
e2 = torch.fft(e1, signal_ndim=2, normalized=True)
e=back_format(e2.permute(0, 3, 1, 2).cpu().detach().numpy())
e = np.abs(np.squeeze(e))
plt.imshow(np.log10(e+1e-17))
plt.colorbar()
plt.axis(False)
Output
reconstructed image
f = back_format(out['image'].cpu().detach().numpy())
f = np.abs(np.squeeze(f))
plt.imshow(f)
plt.colorbar()
plt.axis(False)
reconstructed k-space
g1 = out['image'].permute(0, 2, 3, 1)
g2 = torch.fft(g1, signal_ndim=2, normalized=True)
g=back_format(g2.permute(0, 3, 1, 2).cpu().detach().numpy())
g = np.abs(np.squeeze(g))
plt.imshow(np.log10(g+1e-17),cmap='Greys')
plt.colorbar()
plt.axis(False)
plt.show()
calculate PSNR
a = back_format(batch['image'].cpu().detach().numpy())
b = back_format(out['image'].cpu().detach().numpy())
c = back_format(batch['full'].cpu().detach().numpy())
base_psnr += complex_psnr(c, a)
test_psnr += complex_psnr(c, b)
def complex_psnr(x, y, peak='normalized'):
'''
x: reference image
y: reconstructed image
peak: normalised or max
Notice that ``abs'' squares
Be careful with the order, since peak intensity is taken from the reference
image (taking from reconstruction yields a different value).
'''
# mse = np.mean(np.abs(x - y)**2)
mse = np.square(np.abs(np.subtract(x, y))).mean()
if peak == 'max':
return 10*np.log10(np.max(np.abs(x))**2/mse)
else:
return 10*np.log10(1./mse)
base_psnr
#注意这里的a,要提前copy.copy(batch),不然重建出来的会覆盖掉battch['image']
base_square = np.abs(c - a)**2
base_square = np.squeeze(base_square)
plt.imshow(base_square)
plt.colorbar()
plt.axis(False)
base_mse = np.mean(np.abs(c - a)**2)
Out[70]: 0.0014865209
10*np.log10(1./base_mse)
Out[74]: 28.278289906022586
test_psnr
test_square = np.abs(c - b)**2
test_square = np.squeeze(test_square)
plt.imshow(test_square)
plt.colorbar()
plt.axis(False)
test_mse = np.mean(np.abs(c - b)**2)
Out[97]: 7.805396e-05
10*np.log10(1./test_mse)
Out[98]: 41.07605041353657
calculate SSIM
def ssim(x, y):
score = 0
for i in range(x.shape[0]):
tempx = np.abs(x[i])
tempy = np.abs(y[i])
# score += compare_ssim(tempx, tempy)
score += structural_similarity(tempx, tempy)
score /= x.shape[0]
return score
ssim(pred, grnd)
Out[99]: 0.989175586596038
draw error image
base_error
complex_psnr(grnd, gund)
Out[21]: 28.278289906022586
gund = gund[0]
grnd = grnd[0]
base_err = np.abs(gund - grnd) ** 2 # (256, 256)
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)
base_err -= base_err.min()
plt.imshow(base_err2)
plt.colorbar()
plt.axis(False)
base_err /= base_err.max()
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)
base_err *= 255 # (256, 256) float32
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)
base_err = base_err.astype(np.uint8) # (256, 256) uint8
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)
base_err = cv2.applyColorMap(base_err, 2)
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)
test_error
complex_psnr(grnd, pred)
Out[41]: 41.07605041353657
pred = pred[0]
grnd = grnd[0]
test_err = np.abs(pred - grnd) ** 2 # (256, 256)
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)
test_err -= test_err.min()
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)
test_err /= test_err.max()
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)
test_err *= 255 # (256, 256) float32
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)
test_err = test_err.astype(np.uint8) # (256, 256) uint8
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)
test_err = cv2.applyColorMap(test_err, 2)
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)
model(batch)流程