CT-CTA 代码不理解的点

由于是将训练CT数据的模型用来跑MRI,因此有些操作不是很理解,并且也不会,请教哈各位大佬

求助!!!!!!!!!!!!

CT和MRI都有窗的概念,那么MRI中窗宽和窗位怎么调整呢?调整的原则和依据是什么?

Question1

首先使用一个生成模型 netG_A2B 将输入 real_A2 转换成输出 fake_B,这通常是在如图像到图像的转换任务中常见的做法,例如在使用对抗生成网络(GANs)来增强医学图像或改变图像风格的应用中。然后利用函数 to_windowdatareal_B 和生成的 fake_B 通过窗宽(WW)和窗位(WC)调整进行处理,以达到一定的视觉效果或分析目的。之后,代码继续对处理后的图像进行二值化(bbcc),最后再次调整像素值。

每一步都有其特定目的,旨在改善图像的质量或突出某些特征,以便更好地分析或理解图像内容。

     # 生成模型的应用(`self.netG_A2B(real_A2)`)将输入图像 `real_A2` 转换成生成图像 `fake_B`,
     # 这可能是为了增强图像细节,去噪声,或将图像从一种风格转换到另一种风格等。
     fake_B = self.netG_A2B(real_A2)#real_A2
	 fake_B = fake_B.detach().cpu().numpy().squeeze()
	
	 WC=ds.WindowCenter
	 WW=ds.WindowWidth
	 print("test: WC=", WC, " WW=", WW)
	 # test: WC= 230  WW= 565
	
	 # 窗宽和窗位调整(`to_windowdata`)调整图像的灰度范围,使得感兴趣的细节更加突出。
	 # 这在处理医学影像,如CT或MR图像时非常重要,因为不同组织和结构可能只在特定的灰度范围内清晰可见。**
	 b=to_windowdata(real_B,WC,WW)
	 bb=copy.deepcopy(b)
	 
	 # 二值化(`bb[bb<0.3]=0` 和 `bb[bb>=0.3]=1`)有助于突出重要的特征或边缘,并简化后续处理步骤。
	 # 通过设置阈值(例如`0.3`),将图像中的像素分为两类,旨在区分感兴趣的目标和背景或噪声。  
	 bb[bb<0.3]=0
	 bb[bb>=0.3]=1

     # 图像掩膜和过滤(`b=b*bb` 和 `c=c*cc`)使用了二值化图像作为掩膜,只保留原图像中特定的区域或像素值。
     # 这是一种常见的做法,用于去除不关心的部分或降低噪声的影响。
	 b=b*bb

     # 重新调整像素值(`b[b==0]=-1` 和 `c[c==0]=-1`)将特定像素值设置为 `-1`(或其他特定值)可能是为了在后续的处理或分析中标记这些像素,
     # 例如,可能希望在进行统计分析或进一步处理时忽略这些值。
	 b[b==0]=-1

     c=to_windowdata(fake_B,WC,WW)*bb#to_windowdata(fake_B,WC,WW)
     cc=copy.deepcopy(c)
     cc[cc<0.3]=0
     cc[cc>=0.3]=1
     c=c*cc
     c[c==0]=-1

Question2

   newimg = (fake_BB + 1) * 0.5 * 4095
   ds.SeriesInstanceUID = dsa
   # newimg[newimg == 0] = -2000
   if ds[0x0028, 0x0100].value == 16:  # 如果dicom文件矩阵是16位格式
       newimg = newimg.astype(np.int16)  # newimg 是图像矩阵 ds是dcm uint16
   elif ds[0x0028, 0x0100].value == 8:
       newimg = newimg.astype(np.int8)
   else:
       raise Exception("unknow Bits Allocated value in dicom header")
   # ds.dtype=int16
   ds.PixelData = newimg.tobytes()  # 替换矩阵
   shutil.copy(file_path, file_path0)
   shutil.copy(file_path.replace('SE0','SE1'), file_path1)
   """然后再将其pred保存下来"""
   pydicom.dcmwrite(out_path2+name,ds)

Question3:

b=to_windowdata(real_B,WC,WW)

to_windowdata 用于将图像数据进行窗口调整,使医学图像(如CT图像)的特定区域更清晰可见。此调整是通过设定窗宽(Window Width, WW)和窗位(Window Center, WC)实现的,以突出在此范围内的细节,同时抑制范围外的信息。

初始化: 函数接收3个参数:image(原始图像数据,是一个NumPy数组),WC(窗位),以及WW(窗宽)。

这个函数尤其在医学图像处理中非常有用,用于增强图像的特定区域(结构),便于医生或诊断系统分析图像。
整个流程是图像处理中常用的窗口调整技术的典型实现,能有效地改善图像的可视化效果,突出感兴趣的细节。

def to_windowdata(image,WC,WW):
    print("to_windowdata1: image=",type(image), image.shape) 
    # to_windowdata1: image= <class 'numpy.ndarray'> ( 512, 512)
    
    # 对图像进行预处理
    image = (image + 1) * 0.5 * 4095
    # 将图像的像素值进行缩放,此举使像素值分布在0和4095之间
    image[image == 0] = -2000
    # 并对0像素值进行特殊处理(置为-2000)以代表背景,
    image=image-1024
    # 然后从所有像素值中减去1024以调整范围。
    print("to_windowdata2: image=",type(image), image.shape) 
    # to_windowdata2: image= <class 'numpy.ndarray'> ( 512, 512)

    # 窗口调整:
    center = WC #40 400//60 300
    width = WW# 200
    """计算窗口的最小值(win_min)和最大值(win_max),基于提供的窗宽和窗位。
若WC或WW输入为不符预期(如列表或元组等),则在except块中处理,通过取索引0的值来适配。"""
    try:
        win_min = (2 * center - width) / 2.0 + 0.5
        win_max = (2 * center + width) / 2.0 + 0.5
        # print("to_windowdata 正常情况: win_min=", win_min, " win_max=",win_max)
        # to_windowdata 正常情况: win_min= -52.0  win_max= 513.0
    except:
        # print(WC[0])
        # print(WW[0])
        center = WC[0]  # 40 400//60 300
        width = WW[0]  # 200
        win_min = (2 * center - width) / 2.0 + 0.5
        win_max = (2 * center + width) / 2.0 + 0.5
        # print("to_windowdata 异常情况: win_min=", win_min, " win_max=", win_max)


    # 像素值重映射:
    # 计算映射因子(dFactor),这个因子用来将图像像素值映射到0-255范围内,以适合8位灰度图像表示。基于计算得到的win_min和映射因子,调整图像像素值,超出0-255范围的值会被截断到边界值。
    dFactor = 255.0 / (win_max - win_min)
    image = image - win_min
    image = np.trunc(image * dFactor)
    image[image > 255] = 255
    image[image < 0] = 0
    image=image/255#np.uint8(image)

    # 后处理:图像像素值被标准化到-1到1的范围内,为了后续处理或显示做准备。
    image = (image - 0.5)/0.5
    # print("to_windowdata6: image=", type(image), image.shape) # to_windowdata6: image= <class 'numpy.ndarray'> ( 512, 512)
    return image


 """try:正常情况下,程序计划执行的语句。
       except:程序异常是执行的语句。
       else:程序无异常即try段代码正常执行后会执行该语句。
       finally:不管有没有异常,都会执行的语句。"""

Queation 4:

生成器(Generator)的神经网络模型,用于图像处理中的生成对抗网络(GAN)等任务。

在初始化函数__init__中,生成器接受input_nc个通道的输入和output_nc个通道的输出,同时定义了残差块的数量n_residual_blocks。接着定义了生成器的结构,包括初始卷积块、下采样、残差块、上采样和最终的输出层。

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()
        # Initial convolution block
        # 初始卷积块(model_head):通过反射填充层(nn.ReflectionPad2d)进行填充,然后使用7x7的卷积核对输入进行卷积
        # 进行实例归一化(nn.InstanceNorm2d)和ReLU激活函数。这一部分用于提取输入图像的基本特征
        model_head = [nn.ReflectionPad2d(3),
                      nn.Conv2d(input_nc, 64, 7),
                      nn.InstanceNorm2d(64),
                      nn.ReLU(inplace=True)]

        # Downsampling  下采样
        # 通过两个3x3的卷积核进行下采样,即将特征图的尺寸减半
        # 同时通道数增加一倍,并使用实例归一化和ReLU激活函数。这有助于提取更高级别的特征。
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):    
            model_head += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                           nn.InstanceNorm2d(out_features),
                           nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks 残差块(model_body):
        # 定义n_residual_blocks个残差块(ResidualBlock),每个残差块的输入和输出具有相同的维度
        # 通过跳跃连接来学习残差映射,有助于避免深度网络训练中出现的梯度消失或爆炸问题。
        model_body = []
        for _ in range(n_residual_blocks):
            model_body += [ResidualBlock(in_features)]


        # Upsampling 上采样:
        # 通过两个3x3的转置卷积核进行上采样,即将特征图的尺寸放大一倍
        # 同时通道数减半,并使用实例归一化和ReLU激活函数。
        model_tail = []
        out_features = in_features // 2
        for _ in range(2):
            model_tail += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                           nn.InstanceNorm2d(out_features),
                           nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2

        # Output layer
        # 输出层(model_tail):通过反射填充层和7x7的卷积核进行填充,然后使用tanh激活函数输出最终的图像。
        model_tail += [nn.ReflectionPad2d(3),
                       nn.Conv2d(64, output_nc, 7),
                       nn.Tanh()]

        self.model_head = nn.Sequential(*model_head)
        self.model_body = nn.Sequential(*model_body)
        self.model_tail = nn.Sequential(*model_tail)

    def forward(self, x):
        x = self.model_head(x)
        x = self.model_body(x)
        x = self.model_tail(x)

        return x

Queation 5

定义了一个鉴别器(Discriminator)的神经网络模型,用于生成对抗网络(GAN)中。

在初始化函数__init__中,鉴别器接受input_nc个通道的输入。在初始化网络结构时,通过一系列的卷积层和LeakyReLU激活函数来构建鉴别器的模型。
LeakyReLU激活函数在负数部分有一个小的斜率,有助于解决传统ReLU激活函数中的“神经元死亡”问题。

class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        # 4x4的卷积核进行卷积,步幅为2,填充为1,输出通道为64,然后经过LeakyReLU激活函数。
        model = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                 nn.LeakyReLU(0.2, inplace=True)]

        #  4x4的卷积核进行卷积,步幅为2,填充为1,输出通道为128,然后经过实例归一化(nn.InstanceNorm2d)和LeakyReLU激活函数。
        model += [nn.Conv2d(64, 128, 4, stride=2, padding=1),
                  nn.InstanceNorm2d(128),
                  nn.LeakyReLU(0.2, inplace=True)]
                  
        # 4x4的卷积核进行卷积,步幅为2,填充为1,输出通道为256,然后经过实例归一化和LeakyReLU激活函数。
        model += [nn.Conv2d(128, 256, 4, stride=2, padding=1),
                  nn.InstanceNorm2d(256),
                  nn.LeakyReLU(0.2, inplace=True)]
                  
        # 4x4的卷积核进行卷积,填充为1,输出通道为512,然后经过实例归一化和LeakyReLU激活函数。
        model += [nn.Conv2d(256, 512, 4, padding=1),
                  nn.InstanceNorm2d(512),
                  nn.LeakyReLU(0.2, inplace=True)]

        # FCN classification layer
        # 在最后的卷积层后,没有使用激活函数,直接输出一个通道,用于进行最终的二元分类。
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        # 在前向传播函数forward中,输入图像x经过模型的所有层后,进行了平均池化(F.avg_pool2d)和展平操作,最终输出一个一维的结果,用于表示输入图像是真实图像(1)还是生成图像(0)。
        # self.model = nn.Sequential(*model)表示经过所有层
        x = self.model(x)
        # Average pooling and flatten
        # x=F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
        # print(x.size())
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

Queation 6:

使用PyTorch实现的神经网络模型类,模型名称为Reg。模型的构造较为复杂,涉及到了ResUnet模块的使用。

在初始化方法__init__中,首先调用了父类的初始化方法,并设置了模型所需的一些参数,如输入图像的高度和宽度、输入通道数等。然后创建了一个ResUnet模块的实例self.offset_map,并将其发送到GPU设备上。同时,还调用了get_identity_grid方法来获取一个标准的网格。

get_identity_grid方法用来生成一个标准的identity网格,即一个标准的坐标网格。在前向传播方法forward中,首先调用self.offset_map来计算图像img_a和img_b之间的偏移变换deformations,然后将结果返回。

需要注意的是,ResUnet模块的具体实现并未提供,因此无法对整个Reg模型的功能进行准确的解释。建议查看ResUnet模块的实现以及相关的文档来更好地理解整个模型的功能和作用。

class Reg(nn.Module):
    def __init__(self,height,width,in_channels_a,in_channels_b):
        super(Reg, self).__init__()
        # height,width=256,256
        # in_channels_a,in_channels_b=1,1
        init_func = 'kaiming'
        init_to_identity = True

        # paras end------------
        self.oh, self.ow = height, width
        self.in_channels_a = in_channels_a
        self.in_channels_b = in_channels_b
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.offset_map = ResUnet(self.in_channels_a, self.in_channels_b, cfg='A', init_func=init_func, init_to_identity=init_to_identity).to(
            self.device)
        self.identity_grid = self.get_identity_grid()

    def get_identity_grid(self):
        x = torch.linspace(-1.0, 1.0, self.ow)
        y = torch.linspace(-1.0, 1.0, self.oh)
        xx, yy = torch.meshgrid([y, x])
        xx = xx.unsqueeze(dim=0)
        yy = yy.unsqueeze(dim=0)
        identity = torch.cat((yy, xx), dim=0).unsqueeze(0)
        return identity

    def forward(self, img_a, img_b, apply_on=None):

        deformations = self.offset_map(img_a, img_b)

        return deformations

Queation 7:

Transformer_2D的神经网络模型类,它是一个继承自nn.Module的子类。这个模型类中只定义了一个forward方法,表示该模型是一个自定义的前向传播模型。

在forward方法中,接收两个输入参数src和flow,分别表示输入的源图像和变换的光流(或其他变换信息)。首先根据flow的shape获取batch size(b)、高度(h)和宽度(w),然后创建一个大小为(h, w)的网格,并将其转换为float32类型。接着将网格grid根据flow的值进行变换,计算出新的位置new_locs。接下来根据新的位置信息对new_locs进行归一化处理,并对其进行排列和维度交换,以便后续进行grid sample操作。

在经过位置变换后,调用了PyTorch中的F.grid_sample函数,该函数的作用是根据变换后的坐标信息将原图src进行采样,生成经过变换后的图像warped。最后将warped作为函数的返回值。

需要注意的是,该模型中的具体逻辑是对输入的图像进行2D空间的变换,变换的信息通过flow表示。通过对输入图像的采样,可以得到经过变换后的图像。此外,在实际应用中,还需要考虑梯度的传播和后向传播等问题,这里的代码片段并未包含完整的模型逻辑和训练过程,因此这些方面的细节可能需要在其他部分的代码中进行补充。

class Transformer_2D(nn.Module):
    def __init__(self):
        super(Transformer_2D, self).__init__()
    # @staticmethod
    def forward(self,src, flow):
      
        b = flow.shape[0]#torch.Size([1, 2, 512, 512])
        h = flow.shape[2]
        w = flow.shape[3]
        size = (h,w)

        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = grid.to(torch.float32)
        # grid = grid.repeat(b,1,1,1).cuda()
        grid = grid.repeat(b, 1, 1, 1).cpu()
        new_locs = grid+flow#torch.Size([1, 2, 512, 512])
        shape = flow.shape[2:]
        
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        new_locs = new_locs.permute(0, 2, 3, 1)
        new_locs = new_locs[..., [1 , 0]]#torch.Size([1, 512, 512, 2])
        #提供一个input的Tensor以及一个对应的flow-field网格(比如光流,体素流等),然后根据grid中每个位置提供的坐标信息(这里指input中pixel的坐标),将input中对应位置的像素值填充到grid指定的位置,得到最终的输出。
        warped = F.grid_sample(src,new_locs,align_corners=True,padding_mode="border")
        # ctx.save_for_backward(src,flow)
        return warped

Queation 8:

    def train(self):
        print("enter training")
        for epoch in range(self.config['epoch'] + 1, self.config['n_epochs'] + 1 + self.config['decay_epoch']):
            if epoch > self.config['n_epochs']:
                self.update_learning_rate()
            self.dataloader, self.logger = self.updata_dataloader()
            # for i in range(5):
            #     self.dataloader = self.updata_dataloader()
            #     print(len(self.dataloader))
            for i, batch in enumerate(self.dataloader):
                # Set model input
                real_A2 = Variable(self.input_A2.copy_(batch['A2']))  ###注意generator输入是[1,3,512,512]
                real_B2 = Variable(self.input_B2.copy_(batch['B2']))
                print("Hd_Trainer_x1: real_A2=", real_A2.shape, "real_B2=", real_B2.shape)
                """这里的A2和 B2分别是平扫和增强的A1 = (368, 300)   A2 = torch.Size([1, 512, 512])
                   Hd_Trainer_x1: real_A2= torch.Size([4, 1, 512, 512])  real_B2= torch.Size([4, 1, 512, 512])"""

                real_BB2 = copy.deepcopy(real_B2) # 使用deepcopy()复制列表lis_之后,直接改变二维列表中的值 d,不会影响到源列表lis_D
                # for c in range(1):#设置每个epoch生成器迭代次数
                self.optimizer_R_A.zero_grad()
                self.optimizer_G.zero_grad()
                #### regist sys loss
                fake_B = self.netG_A2B(real_A2)  ### 根据real_A2生成fake_B
                print("Hd_Trainer_x1: fake_B=", type(fake_B), fake_B.shape)
                # Hd_Trainer_x1: fake_B= <class 'torch.Tensor'> torch.Size([4, 1, 512, 512])
                """# 不太清楚这里的self.R_A和SysRegist_A2B的作用是啥???"""
                Trans = self.R_A(fake_B, real_B2)  # torch.Size([1, 2, 512, 512])
                SysRegist_A2B = self.spatial_transform(fake_B, Trans)

                SM_loss = self.config['Smooth_lamda'] * smooothing_loss(Trans)  ####smooth loss
                SR_loss = self.config['Corr_lamda1'] * self.L1_loss(SysRegist_A2B, real_B2)  ###SR                
                
                pred_fake0 = self.netD_B(fake_B) # 判别器
                adv_loss = self.config['Adv_lamda1'] * self.MSE_loss(pred_fake0, self.target_real)
                toal_loss = SM_loss + adv_loss  + SR_loss  # 是否需要取平均???
                toal_loss.backward()

                self.optimizer_R_A.step()
                self.optimizer_G.step()
                self.optimizer_D_B.zero_grad()

                with torch.no_grad():
                    fake_B = self.netG_A2B(real_A2)  ####real_A2
                # fake_B = torch.cat((fake_B1, fake_B2), 0)
                pred_fake0 = self.netD_B(fake_B)
                pred_real = self.netD_B(real_BB2)
                # loss_D_B = self.config['Adv_lamda1'] * (
                #             self.criterionGAN(pred_fake0, False) + self.criterionGAN(pred_real, True)) / D
                
                """ 为什么要做这两个操作呢?
                self.MSE_loss(pred_fake0, self.target_fake): 
                self.MSE_loss(pred_real, self.target_real): """
                loss_D_B = self.config['Adv_lamda1'] * self.MSE_loss(pred_fake0, self.target_fake) + self.config[
                    'Adv_lamda1'] * self.MSE_loss(pred_real, self.target_real) 
                loss_D_B.backward()
                self.optimizer_D_B.step()
                ###################################
                self.logger.log({'loss_D_B': loss_D_B, },
                                images={'real_A': real_A2, 'real_B': real_BB2,
                                        'fake_B': fake_B})  # ,'SR':SysRegist_A2B
                if (i + 1) % 40000 == 0:
                    st = str(0) + '_' + str(int(1 + i / 40000))
                    torch.save(self.netG_A2B.state_dict(),
                               self.config['save_root'] + "netG_A2B_x_" + st + ".pth")
                    torch.save(self.R_A.state_dict(),
                               self.config['save_root'] + "R_A_x_" + st + ".pth")
                    torch.save(self.netD_B.state_dict(),
                               self.config['save_root'] + "netD_B_x_" + st + ".pth")
            ############val###############
            if epoch%5==0:#batch>1 use small data to validate in  training
                with torch.no_grad():
                    SSIM = 0
                    PSNR=0
                    num = 0
                    for i, batch in enumerate(self.val_data):
                        real_A1 = Variable(self.input_A.copy_(batch['A2']))  ###注意generator输入是[1,3,512,512]
                        # real_A2 = Variable(self.input_A.copy_(batch['A2']))  ###注意generator输入是[1,3,512,512]
                        real_B2 = Variable(self.input_B2.copy_(batch['B2'])).detach().cpu().numpy().squeeze()
                        fake_B= self.netG_A2B(real_A1).detach().cpu().numpy().squeeze()#real_A2
                        # real_B2 = (real_B2 * 0.5 + 0.5) * 255
                        # fake_B = (fake_B * 0.5 + 0.5) * 255
                        psnr = self.PSNR(fake_B, real_B2)  # fake_B
                        PSNR += psnr
                        ssim = compare_ssim(fake_B, real_B2, multichannel=True, channel_axis=-1, win_size = 3)
                        SSIM += ssim
                        num += 1
                    print('PSNR:', PSNR / num)
                    print('SSIM:', SSIM / num)
                #         # Save models checkpoints
                if not os.path.exists(self.config["save_root"]):
                    os.makedirs(self.config["save_root"])
                st=str(epoch)+'_'+ str(round(PSNR / num, 4))+'_'+str(round(SSIM / num, 4))
                torch.save(self.netG_A2B.state_dict(),
                           self.config['save_root'] + "netG_A2B_x_"+st+ "b.pth")
                torch.save(self.R_A.state_dict(),
                           self.config['save_root'] + "R_A_x_" +st+"b.pth")
                torch.save(self.netD_B.state_dict(),
                           self.config['save_root'] + "netD_B_x_"+st+"b.pth")
            else:
                if not os.path.exists(self.config["save_root"]):
                    os.makedirs(self.config["save_root"])
                st=str(epoch)
                torch.save(self.netG_A2B.state_dict(),
                           self.config['save_root'] + "netG_A2B_x_" + st + ".pth")
                torch.save(self.R_A.state_dict(),
                           self.config['save_root'] + "R_A_x_" + st + ".pth")
                torch.save(self.netD_B.state_dict(),
                           self.config['save_root'] + "netD_B_x_" + st + ".pth")

Question9 校准器中的ResUnet

class ResUnet(torch.nn.Module):
    def __init__(self, nc_a, nc_b, cfg, init_func, init_to_identity):
        super(ResUnet, self).__init__()

        act = down_activation[cfg]
        self.ndown_blocks = 5
        self.nup_blocks = 5
        assert self.ndown_blocks >= self.nup_blocks
        in_nf = nc_a + nc_b  # in_nf = 1 + 1
        conv_num = 1
        skip_nf = {}
        for out_nf in ndf[cfg]:
            setattr(self, 'down_{}'.format(conv_num),
                    DownBlock(in_nf, out_nf, 3, 1, 1, activation=act, init_func=init_func, bias=True,
                              use_resnet=use_down_resblocks[cfg], use_norm=False))
            skip_nf['down_{}'.format(conv_num)] = out_nf
            in_nf = out_nf
            conv_num += 1
        conv_num -= 1

        if use_down_resblocks[cfg]:
            self.c1 = Conv(in_nf, 2 * in_nf, 1, 1, 0, activation=act, init_func=init_func, bias=True, use_resnet=False, use_norm=False)
            self.t = ((lambda x: x) if resnet_nblocks[cfg] == 0 else ResnetTransformer(2 * in_nf, resnet_nblocks[cfg], init_func))
            self.c2 = Conv(2 * in_nf, in_nf, 1, 1, 0, activation=act, init_func=init_func, bias=True, use_resnet=False, use_norm=False)

        # ------------- Up-sampling path
        act = up_activation[cfg]
        for out_nf in nuf[cfg]:
            setattr(self, 'up_{}'.format(conv_num),
                    Conv(in_nf + skip_nf['down_{}'.format(conv_num)], out_nf, 3, 1, 1, bias=True, activation=act,
                         init_fun=init_func, use_norm=False, use_resnet=False))
            in_nf = out_nf
            conv_num -= 1

        if refine_output[cfg]:
            self.refine = nn.Sequential(ResnetTransformer(in_nf, 1, init_func),
                                        Conv(in_nf, in_nf, 1, 1, 0, use_resnet=False, init_func=init_func,
                                             activation=act,
                                             use_norm=False))
        else:
            self.refine = lambda x: x

        self.output = Conv(in_nf, 2, 3, 1, 1, use_resnet=False, bias=True,
                           init_func=('zeros' if init_to_identity else init_func), activation=None,
                           use_norm=False)


    def forward(self, img_a, img_b):
        x = torch.cat([img_a, img_b], 1)
        skip_vals = {}
        conv_num = 1

        while conv_num <= self.ndown_blocks:
            x, skip = getattr(self, 'down_{}'.format(conv_num))(x)
            skip_vals['down_{}'.format(conv_num)] = skip
            conv_num += 1

        if hasattr(self, 't'):
            x = self.c1(x)
            x = self.t(x)
            x = self.c2(x)

        conv_num -= 1
        while conv_num > (self.ndown_blocks - self.nup_blocks):
            s = skip_vals['down_{}'.format(conv_num)]
            x = F.interpolate(x, (s.size(2), s.size(3)), mode='bilinear')
            x = torch.cat([x, s], 1)
            x = getattr(self, 'up_{}'.format(conv_num))(x)
            conv_num -= 1

        x = self.refine(x)
        x = self.output(x)
        return x

这段代码定义了一个名为 ResUnet 的类,它继承自 PyTorch 的 torch.nn.Module。这个类实现了一个基于残差连接的 U-Net 网络结构,用于图像分割等任务。

让我们逐步解释这个类的重要部分:

  1. __init__ 方法:这是类的初始化方法,用于定义网络的结构。在这个方法中:

    • 首先设置了一些网络的参数,比如下采样块数量和上采样块数量。
    • 然后通过循环构建了下采样路径,其中使用了 DownBlock 类来构建每个下采样块,并将这些块保存在类实例中。
    • 如果使用了下采样的 ResNet 块,会构建额外的卷积层 c1tc2
    • 接着构建了上采样路径,使用 Conv 类构建每个上采样块。
    • 最后,根据配置选择是否添加细化层以及输出层的构建。
  2. forward 方法:这个方法定义了数据在网络中的传播过程。主要步骤包括:

    • 将输入的图像 img_aimg_b 连接起来作为网络的输入 x
    • 通过下采样路径,将输入数据逐步下采样,并保存对应的跳跃连接的特征。
    • 如果定义了变量 t,则应用额外的卷积操作。
    • 通过上采样路径,将数据上采样,并与对应的跳跃连接的特征拼接起来。
    • 最后将数据经过细化层和输出层,最终返回输出。

整体来说,这个类实现了一个 U-Net 网络结构,通过残差连接实现了信息的跳跃传递,可以用于图像分割等任务。在实际使用时,需要根据代码中涉及到的其他类和变量进行定义和初始化。

Question10: GAN Loss 和 Smoothing Loss

class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.Tensor):#torch.cuda.FloatTensor
        super(GANLoss, self).__init__()
        self.target_real = Variable(tensor(1, 1).fill_(1.0), requires_grad=False)
        self.target_fake = Variable(tensor(1, 1).fill_(0.0), requires_grad=False)
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def __call__(self, input, target_is_real):
        if isinstance(input[0], list):#input[0]
            loss = 0
            #print(len(input))
            w=[1.8,0.2]
            # w = [1, 1]
            i=0
            for input_i in input:
                #print(len(input))
                x = input_i[-1]
                pred= F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
                if target_is_real:
                    loss += self.loss(pred, self.target_real)*w[i]
                else:
                    loss += self.loss(pred, self.target_fake)*w[i]
                i=i+1
            return loss
        else:
            x=input[-1]
            pred = F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
            if target_is_real:
                loss= self.loss(pred, self.target_real)
            else:
                loss= self.loss(pred, self.target_fake)
            return loss
        

def smooothing_loss(y_pred):
    dy = torch.abs(y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :])
    dx = torch.abs(y_pred[:, :, :, 1:] - y_pred[:, :, :, :-1])

    dx = dx*dx
    dy = dy*dy
    d = torch.mean(dx) + torch.mean(dy)
    grad = d 
    return d


def read_dicom(file_path):
    ds = pydicom.dcmread(file_path.replace('../../../','../../'), force=True)  # 读取头文件
    image2=(ds.pixel_array).astype(np.int)

    #sitk读取的数值比pydicom读取的数值小1024
    image2[image2<0]=0
    image2=image2/4095
    image2 = (image2 - 0.5)/0.5
    return image2

这段代码定义了一个名为 GANLoss 的类,用于计算 GAN(生成对抗网络) 的损失函数。同时还定义了一个名为 smoothing_loss 的函数用于计算平滑损失,以及一个名为 read_dicom 的函数用来读取 DICOM 文件。

让我们逐步解释这些部分的作用:

  1. GANLoss 类:

    • __init__ 方法初始化了 GANLoss 类。其中,use_lsgan 参数用于指定是否使用 LSGAN(最小二乘 GAN) 损失,target_real_labeltarget_fake_label 用于定义真实标签和虚假标签的目标值,默认为 1.0 和 0.0。tensor 参数用于指定张量类型,默认为 torch.Tensor
    • 在初始化过程中,根据是否使用 LSGAN,选择了使用均方误差损失(nn.MSELoss)或二元交叉熵损失(nn.BCELoss)。
    • __call__ 方法重载了 () 运算符,用于计算 GAN 损失。根据输入是否为列表,计算不同的损失值。对于每个输入,首先对输入进行处理,并计算输出的损失值。最后根据 target_is_real 标志选择真实标签还是虚假标签,并返回相应的损失值。
  2. smoothing_loss 函数:

    • 这个函数用于计算平滑损失。它计算了输入张量在垂直和水平方向上的差值,然后对这些差值进行平方处理,并计算其平均值作为平滑损失。
    • 最后返回平滑损失值 grad
  3. read_dicom 函数:

    • 这个函数用于读取 DICOM 文件,首先使用 pydicom 库读取 DICOM 文件的像素数据,并将其转换为 numpy 数组。
    • 然后将像素数据进行一系列的归一化操作,将像素值映射到 [-1, 1] 的范围内,并返回处理后的图像数据。

综合来看,这些函数和类主要用于 GAN 损失的计算和图像处理。GANLoss 类用于计算生成对抗网络中的损失,smoothing_loss 函数用于计算平滑损失,而 read_dicom 函数用于读取 DICOM 文件并对图像数据进行预处理。

11. 数据加载

def read_ori_w(file_path):#read_dicom_mw
    file_path=file_path.replace('../../../', '../../')
    dicom = sitk.ReadImage(file_path)
    data1 = np.squeeze(sitk.GetArrayFromImage(dicom))
    data=data1+1024
    # ds = pydicom.dcmread(file_path, force=True)  # 读取头文件
    # data=(ds.pixel_array).astype(np.int)
    # data1=data-1024
    #if "C+" in st:#宽对窄
    center =50# ds.WindowCenter 50
    width = 400#ds.WindowWidth # 400
    win_min = (2 * center - width) / 2.0 + 0.5#-149.5
    win_max = (2 * center + width) / 2.0 + 0.5#250.5
    dFactor = 255.0 / (win_max - win_min)#把窗内
    image = data1 - win_min #sitk读取的数值比pydicom读取的数值小1024
    # image=data1+149.5
    image1 = np.trunc(image * dFactor)#dFactor
    image1[image1>255]=255
    image1[image1<0]=0
    image1=image1/255#np.uint8(image)
    image1 = (image1 - 0.5)/0.5

    image2=data#sitk读取的数值比pydicom读取的数值小1024
    image2[image2<0]=0#-2000->0
    image2=image2/4095
    image2 = (image2 - 0.5)/0.5
    ######

    # image1=(image1*2-1)*255
    # image2=(image2*2-1)*255
    # plt.subplot(2, 2, 1)
    # plt.imshow(image1*255, cmap='gray')#,vmin=0,vmax=255
    # plt.subplot(2, 2, 2)
    # plt.imshow(image2*255, cmap='gray')#,vmin=0,vmax=255
    # plt.show()

    return image1,image2

#真实平扫头部数据
def read_dicom(file_path):
    ds = pydicom.dcmread(file_path.replace('../../../','../../'), force=True)  # 读取头文件
    image2=(ds.pixel_array).astype(np.int)

    #sitk读取的数值比pydicom读取的数值小1024
    image2[image2<0]=0
    image2=image2/4095
    image2 = (image2 - 0.5)/0.5
    return image2

实际上,在代码中用的也是image2的值,跟read_dicom相比,仅仅是使用skit读取dicom,并且对其进行+10242,不知道是啥原因?

def read_ori_w(file_path):#read_dicom_mw
    file_path=file_path.replace('../../../', '../../')
    dicom = sitk.ReadImage(file_path)
    data1 = np.squeeze(sitk.GetArrayFromImage(dicom))
    data=data1+1024
    
    image2=data#sitk读取的数值比pydicom读取的数值小1024
    image2[image2<0]=0#-2000->0
    image2=image2/4095
    image2 = (image2 - 0.5)/0.5
    
    # image1=(image1*2-1)*255
    # image2=(image2*2-1)*255
    # plt.subplot(2, 2, 1)
    # plt.imshow(image1*255, cmap='gray')#,vmin=0,vmax=255
    # plt.subplot(2, 2, 2)
    # plt.imshow(image2*255, cmap='gray')#,vmin=0,vmax=255
    # plt.show()

    return image2
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值