"Rigid and Non-rigid Motion Artifacts Reduction in X-ray CT using Attention Module": Code Reproduction and a Brief Walkthrough

For the paper itself, see my previous post.

Code: github

Task: motion artifact removal in CBCT.

Running the code

1. Download the repo from github and unzip it.

2. Create two folders in the repo's root directory: log and data.

3. Download the weights into log and the data into data.

Weights: MEGA, or log - Google Drive

Data: data - Google Drive

4. Run:

python train.py
python test.py

5. Check the results.

Brief walkthrough

Both train.py and the model are actually quite simple.

train.py

First, two placeholders: as you can see, one for the input and one for the target.

    # Placeholders
    input_ph = tf.placeholder(tf.float32, shape=[None, None, None, 1])
    target_ph = tf.placeholder(tf.float32, shape=[None, None, None, 1])
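
The `[None, None, None, 1]` shape means `(batch, height, width, channels)`: any batch size and slice resolution, single-channel (grayscale) images. A minimal NumPy sketch of shaping one CT slice to fit these placeholders (the 256x256 size is an illustrative assumption, not from the repo):

```python
import numpy as np

# A single grayscale slice, reshaped to (batch, height, width, channels)
# so it matches the [None, None, None, 1] placeholder layout.
slice_2d = np.random.rand(256, 256).astype(np.float32)
batch = slice_2d[np.newaxis, :, :, np.newaxis]
print(batch.shape)   # (1, 256, 256, 1)
```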

Next the network is loaded; we will look inside it shortly.

    # Deblur with network
    deblur_output = network(input_ph)

Two losses, consistent with the paper: L1 + VGG (perceptual).


    # L1 loss
    loss_l1 = 1e2 * tf.reduce_mean(abs(deblur_output - target_ph))  # L1 loss

    # VGG loss
    loss_vgg = tf.zeros(1, tf.float32)
    target_resize = convert_tensor(target_ph)
    vgg_t = vgg16.Vgg16()
    vgg_t.build(target_resize)

    target_feature = [vgg_t.conv3_3, vgg_t.conv4_3]
    output_resize = convert_tensor(deblur_output)
    vgg_o = vgg16.Vgg16()
    vgg_o.build(output_resize)

    output_feature = [vgg_o.conv3_3, vgg_o.conv4_3]
    for f, f_ in zip(output_feature, target_feature):
        loss_vgg += 5*1e-5 * tf.reduce_mean(tf.subtract(f, f_) ** 2, [1, 2, 3])  # Perceptual(vgg) loss

    # Total loss & Optimizer
    loss = loss_l1 + loss_vgg
    opt = tf.train.AdamOptimizer(learning_rate=5*1e-5).minimize(loss, var_list=tf.trainable_variables())
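
The arithmetic of the total loss can be sketched in NumPy with the same weights as train.py (1e2 on the L1 term, 5e-5 on each VGG feature-map MSE); the feature lists stand in for the conv3_3 / conv4_3 activations:

```python
import numpy as np

def total_loss(output, target, out_feats, tgt_feats):
    # Weighted L1 term, as in train.py
    l1 = 1e2 * np.mean(np.abs(output - target))
    # Perceptual term: weighted MSE over each pair of VGG feature maps
    vgg = sum(5e-5 * np.mean((f - f_) ** 2)
              for f, f_ in zip(out_feats, tgt_feats))
    return l1 + vgg

# Identical images and features give zero loss; a constant 0.01 error
# gives an L1 term of 1e2 * 0.01 = 1.0.
```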

Checkpoint restore (the same saver writes a checkpoint at the end of each epoch below):

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir=model_path)
    if ckpt:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('* Training model loaded: ' + ckpt.model_checkpoint_path)
    else:
        print('* Training model does not exist or load failed.')

The training loop, which runs the optimizer and returns the losses:

    while epoch < end_Epoch:
        for train_index in range(nTrain):

            random_index = random.randint(1, nTrain) - 1
            input_image = train_file[random_index:random_index+1, :, :, :]
            target_image = target_file[random_index:random_index+1, :, :, :]
            loss_, l1_, vgg_, _ = sess.run([loss, loss_l1, loss_vgg, opt], feed_dict={input_ph: input_image, target_ph: target_image})

            if not (train_index + 1) % (nTrain // 10):  # integer division keeps the modulus valid on Python 3
                print("Epoch:[%3d/%d] Batch:[%5d/%5d] - Loss:[%4.4f] L1:[%4.4f] VGG:[%4.4f]"
                      % (epoch, end_Epoch, (train_index+1), nTrain, loss_, l1_, vgg_))

        epoch += 1
        saver.save(sess, "%s/model.ckpt" % model_path)

        if epoch == end_Epoch:
            break
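
Note the sampling scheme: the batch size is 1, and each step draws a random slice with replacement, sliced as `index:index+1` to keep the batch dimension. A small NumPy sketch (the `nTrain` value and array sizes are assumptions):

```python
import numpy as np
import random

nTrain = 8
train_file = np.random.rand(nTrain, 32, 32, 1).astype(np.float32)

# Same sampling as train.py: uniform in [0, nTrain - 1] ...
random_index = random.randint(1, nTrain) - 1
# ... and sliced (not indexed) so the result keeps shape (1, H, W, 1)
input_image = train_file[random_index:random_index + 1]
```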

model.py

Looking inside:

def network(input_img):
    with tf.variable_scope('deblur'):
        net = slim.conv2d(input_img, 64, [5, 5], activation_fn=None)
        for layer_ in range(10):
            net = AttBlock(net)
        deblur_img = slim.conv2d(net, 1, [5, 5], activation_fn=None)
        return deblur_img

The network: a conv layer, 10 attention blocks, then a final conv layer.

# AttBlock
def AttBlock(rgap_input):
    temp = slim.conv2d(rgap_input, 64, [5, 5], activation_fn=None)
    temp = tf.nn.relu(temp)
    res = slim.conv2d(temp, 64, [5, 5], activation_fn=None)
    res_gap = global_average_pooling(res)
    rgap_output = rgap_input + res_gap
    return rgap_output

The attention block: conv, ReLU, conv, attention pooling, then a residual connection.
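
The data flow of the block can be sketched in NumPy; `conv1` and `conv2` below are stand-ins for the two 5x5 `slim.conv2d` layers (identity functions here, purely for illustration):

```python
import numpy as np

def global_average_pooling(x):
    # Attention pooling: scale each channel by its own spatial mean
    return x * x.mean(axis=(1, 2), keepdims=True)

def att_block(x, conv1, conv2):
    # conv -> ReLU -> conv -> attention pooling -> residual add
    t = np.maximum(conv1(x), 0.0)
    res = conv2(t)
    return x + global_average_pooling(res)

# With identity "convs" and a non-negative input, the block reduces to
# x + x * mean(x) per channel: an all-ones input maps to all twos.
identity = lambda x: x
y = att_block(np.ones((1, 2, 2, 1), dtype=np.float32), identity, identity)
```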

# attention module
def global_average_pooling(GAP_input):
    avg_pool = tf.reduce_mean(GAP_input, axis=[1, 2])
    temp = tf.expand_dims(avg_pool, 1)
    temp = tf.expand_dims(temp, 1)
    channel_weight = temp
    GAP_output = tf.multiply(GAP_input, channel_weight)
    return GAP_output

Attention pooling: global average pooling produces one weight per channel, which then rescales the feature map.
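
Numerically, in NHWC layout (as in the TF code) each channel is multiplied by its own global spatial average, so strongly-activated channels are amplified. A tiny worked example:

```python
import numpy as np

# One sample, one channel, 2x2 spatial grid: [[1, 2], [3, 4]]
x = np.arange(1.0, 5.0).reshape(1, 2, 2, 1)
channel_weight = x.mean(axis=(1, 2), keepdims=True)  # spatial mean = 2.5
y = x * channel_weight                               # [[2.5, 5.0], [7.5, 10.0]]
```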

Hard to evaluate. All in all: hard to evaluate.
