Cascade MRI reconstruction: evaluate

CascadeNetwork(
  (block): Sequential(
    (0): ResnetBlock(
      (layers): Sequential(
        (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): ReLU()
        (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): ReLU()
        (8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (1): DataConsistency()
    (2): ResnetBlock(
      (layers): Sequential(
        (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): ReLU()
        (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): ReLU()
        (8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (3): DataConsistency()
    (4): ResnetBlock(
      (layers): Sequential(
        (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): ReLU()
        (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): ReLU()
        (8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (5): DataConsistency()
    (6): ResnetBlock(
      (layers): Sequential(
        (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): ReLU()
        (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): ReLU()
        (8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (7): DataConsistency()
    (8): ResnetBlock(
      (layers): Sequential(
        (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): ReLU()
        (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): ReLU()
        (8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (9): DataConsistency()
    (10): ResnetBlock(
      (layers): Sequential(
        (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
        (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): ReLU()
        (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): ReLU()
        (8): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (11): DataConsistency()
  )
)
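
The printout shows six ResnetBlock / DataConsistency pairs. DataConsistency is not expanded above; in this kind of cascade it typically re-imposes the measured k-space samples after each CNN refinement. A rough sketch of that operation, written with complex numpy arrays for clarity (my own function name, not the project's code):

import numpy as np

def data_consistency_sketch(x_rec, k_und, mask):
    # x_rec: reconstructed image (H, W) complex; k_und: measured k-space; mask: binary sampling mask
    k_rec = np.fft.fft2(x_rec, norm='ortho')       # image -> k-space
    k_dc = (1 - mask) * k_rec + mask * k_und       # keep the measured samples where mask == 1
    return np.fft.ifft2(k_dc, norm='ortho')        # k-space -> image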
        gund = back_format(batch['image'].cpu().detach().numpy())   # zero-filled (undersampled) input
        output = model(batch)
        loss = criterion(output['image'], batch['full'])
        # mse += loss.item() * 1000
        mse += loss.item()
        pred = back_format(output['image'].cpu().detach().numpy())  # network reconstruction
        grnd = back_format(batch['full'].cpu().detach().numpy())    # fully-sampled reference
        base_psnr += complex_psnr(grnd, gund)
        test_psnr += complex_psnr(grnd, pred)
        ssim_score += ssim(pred, grnd)
        if draw:
            pred = pred[0]  # (256, 256)
            grnd = grnd[0]
            err = np.abs(pred - grnd) ** 2
            err -= err.min()
            err /= err.max()
            err *= 255
            err = err.astype(np.uint8)
            # 2 == cv2.COLORMAP_JET; note applyColorMap returns BGR while plt.imsave expects RGB,
            # so the saved colors come out channel-swapped (cv2.cvtColor with COLOR_BGR2RGB would fix that)
            err = cv2.applyColorMap(err, 2)
            plt.imsave(os.path.join(kwargs['output_path'], str(i) + '.png'), err)
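
back_format is not defined in this excerpt; judging from how its output is fed to np.abs and complex_psnr, it presumably merges the two real/imaginary channels of an (N, 2, H, W) array into a complex (N, H, W) array. A guess at what it does:

def back_format(x):
    # assumed behaviour only: (N, 2, H, W) real/imag channels -> complex (N, H, W)
    return x[:, 0] + 1j * x[:, 1]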

Input

fully-sampled image

a = back_format(batch['full'].cpu().detach().numpy())

a = np.abs(np.squeeze(a)) # (256, 256)
plt.imshow(a)
plt.colorbar()
plt.axis(False)


undersampled image

b = back_format(batch['image'].cpu().detach().numpy())
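The plotting code for the undersampled image is not shown, but presumably it mirrors the fully-sampled case:

b = np.abs(np.squeeze(b)) # (256, 256)
plt.imshow(b)
plt.colorbar()
plt.axis(False)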

under sampled k-space

# batch['k'] has shape (1, 256, 256, 2); .permute(0, 3, 1, 2) turns it into (1, 2, 256, 256)
c = back_format(batch['k'].permute(0, 3, 1, 2).cpu().detach().numpy())
c = np.abs(np.squeeze(c))
plt.imshow(np.log10(c+1e-17))
plt.colorbar()
plt.axis(False)


4X sampling mask

d = back_format(batch['mask'].permute(0, 3, 1, 2).cpu().detach().numpy())
d = np.abs(np.squeeze(d))
plt.imshow(np.log10(d+1e-17))
plt.colorbar()
plt.axis(False)

4x acceleration, so 256/4 = 64 sampled k-space lines.
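The 64 lines can be checked directly from the mask array (a quick check, assuming the mask samples whole k-space lines):

print(np.count_nonzero(d.any(axis=1)))   # number of sampled rows
print(np.count_nonzero(d.any(axis=0)))   # number of sampled columns
# one of the two counts is 64 (the other is 256), i.e. a quarter of the 256 lines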

fully-sampled k-space

e1 = batch['full'].permute(0, 2, 3, 1)               # (1, 2, 256, 256) -> (1, 256, 256, 2)
e2 = torch.fft(e1, signal_ndim=2, normalized=True)   # legacy FFT API (PyTorch < 1.8)
e = back_format(e2.permute(0, 3, 1, 2).cpu().detach().numpy())
e = np.abs(np.squeeze(e))
plt.imshow(np.log10(e+1e-17))
plt.colorbar()
plt.axis(False)

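torch.fft(x, signal_ndim=2, normalized=True) is the legacy API from PyTorch < 1.8, which works on a real tensor whose last dimension holds the real and imaginary parts. On current PyTorch the same transform goes through torch.fft.fft2 and complex tensors, roughly:

import torch

e1 = batch['full'].permute(0, 2, 3, 1)          # (1, 2, 256, 256) -> (1, 256, 256, 2)
e1c = torch.view_as_complex(e1.contiguous())    # -> (1, 256, 256) complex
e2c = torch.fft.fft2(e1c, norm='ortho')         # norm='ortho' matches normalized=True
e2 = torch.view_as_real(e2c)                    # back to (1, 256, 256, 2)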

Output

reconstructed image

f = back_format(out['image'].cpu().detach().numpy())
f = np.abs(np.squeeze(f))
plt.imshow(f)
plt.colorbar()
plt.axis(False)


reconstructed k-space

g1 = out['image'].permute(0, 2, 3, 1)                # (1, 2, 256, 256) -> (1, 256, 256, 2)
g2 = torch.fft(g1, signal_ndim=2, normalized=True)   # legacy FFT API (PyTorch < 1.8)
g = back_format(g2.permute(0, 3, 1, 2).cpu().detach().numpy())
g = np.abs(np.squeeze(g))
plt.imshow(np.log10(g+1e-17),cmap='Greys')
plt.colorbar()
plt.axis(False)
plt.show()


calculate PSNR

a = back_format(batch['image'].cpu().detach().numpy())   # zero-filled (undersampled) image
b = back_format(out['image'].cpu().detach().numpy())     # network reconstruction
c = back_format(batch['full'].cpu().detach().numpy())    # fully-sampled reference

base_psnr += complex_psnr(c, a)
test_psnr += complex_psnr(c, b)
def complex_psnr(x, y, peak='normalized'):
    '''
    x: reference image
    y: reconstructed image
    peak: 'normalized' (assumes intensities in [0, 1]) or 'max'
    Note that the complex magnitude |x - y| is taken before squaring.
    Be careful with the argument order, since the peak intensity is taken from
    the reference image (taking it from the reconstruction yields a different value).
    '''
    # mse = np.mean(np.abs(x - y)**2)
    mse = np.square(np.abs(np.subtract(x, y))).mean()
    if peak == 'max':
        return 10*np.log10(np.max(np.abs(x))**2/mse)
    else:
        return 10*np.log10(1./mse)

base_psnr

# Note: `a` here must come from a copy.copy(batch) made before running the model;
# otherwise the reconstruction overwrites batch['image'].
base_square = np.abs(c - a)**2
base_square = np.squeeze(base_square)
plt.imshow(base_square)
plt.colorbar()
plt.axis(False)


base_mse = np.mean(np.abs(c - a)**2)
Out[70]: 0.0014865209

10*np.log10(1./base_mse)
Out[74]: 28.278289906022586

test_psnr

test_square = np.abs(c - b)**2
test_square = np.squeeze(test_square)
plt.imshow(test_square)
plt.colorbar()
plt.axis(False)


test_mse = np.mean(np.abs(c - b)**2)
Out[97]: 7.805396e-05

10*np.log10(1./test_mse)
Out[98]: 41.07605041353657

calculate SSIM

from skimage.metrics import structural_similarity

def ssim(x, y):
    score = 0
    for i in range(x.shape[0]):
        tempx = np.abs(x[i])
        tempy = np.abs(y[i])
        # score += compare_ssim(tempx, tempy)   # old skimage.measure name
        score += structural_similarity(tempx, tempy)
    score /= x.shape[0]
    return score
ssim(pred, grnd)
Out[99]: 0.989175586596038
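
Depending on the scikit-image version, structural_similarity may warn about, or in newer releases require, an explicit data_range for floating-point inputs. If that happens, the call inside the loop can be made explicit, e.g.:

score += structural_similarity(tempx, tempy, data_range=tempy.max() - tempy.min())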

draw error image

base_error

complex_psnr(grnd, gund)
Out[21]: 28.278289906022586
gund = gund[0]
grnd = grnd[0]
base_err = np.abs(gund - grnd) ** 2 # (256, 256)
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)


base_err -= base_err.min()
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)


base_err /= base_err.max()
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)


base_err *= 255 # (256, 256) float32
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)


base_err = base_err.astype(np.uint8) # (256, 256) uint8
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)


base_err = cv2.applyColorMap(base_err, 2)   # 2 == cv2.COLORMAP_JET
plt.imshow(base_err)
plt.colorbar()
plt.axis(False)


test_error

complex_psnr(grnd, pred)
Out[41]: 41.07605041353657
pred = pred[0]
grnd = grnd[0]   # skip this line if grnd was already indexed in the base_error section above
test_err = np.abs(pred - grnd) ** 2 # (256, 256)
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)


test_err -= test_err.min()
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)


test_err /= test_err.max()
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)


test_err *= 255 # (256, 256) float32
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)


test_err = test_err.astype(np.uint8) # (256, 256) uint8
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)


test_err = cv2.applyColorMap(test_err, 2)   # 2 == cv2.COLORMAP_JET
plt.imshow(test_err)
plt.colorbar()
plt.axis(False)

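The min-max normalization and colormap steps are identical for the base and test error maps; they can be collected into one small helper (a sketch, the function name is mine):

def error_map(recon, reference):
    # squared-magnitude error, min-max normalized to [0, 255], then color-mapped
    err = np.abs(recon - reference) ** 2
    err -= err.min()
    err /= err.max()
    err = (err * 255).astype(np.uint8)
    return cv2.applyColorMap(err, cv2.COLORMAP_JET)   # COLORMAP_JET == 2, as used above

base_err = error_map(gund, grnd)
test_err = error_map(pred, grnd)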

model(batch) flow

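Conceptually, model(batch) alternates the ResnetBlocks and DataConsistency layers from the printout at the top. A rough sketch of that loop (assumed structure, not the actual forward method):

def cascade_forward_sketch(x_und, k_und, mask, resnet_blocks, dc_layers):
    # x_und: zero-filled image, k_und: measured k-space, mask: sampling mask
    x = x_und
    for block, dc in zip(resnet_blocks, dc_layers):
        x = x + block(x)         # CNN refinement; the residual add may live inside ResnetBlock itself
        x = dc(x, k_und, mask)   # re-insert the measured k-space samples
    return x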
