tensorflow2的utils.Sequence

假设文件是这样的

images和labels里面保存的都是.npy数组

images里面的一个数据的shape=[128,128,16,1],labels里面的一个数据的shape=[128,128,16,2],因为是二分类语义分割

data_loader.py

from tensorflow.keras.utils import Sequence
import numpy as np
import math
 
 
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.

 
class seg3D_Sequence(Sequence):
    def __init__(self,file_name_list, batch_size,image_path="images/",
                                                label_path="labels/"):
        self.file_name_list = file_name_list
        self.batch_size = batch_size
        self.image_path = image_path
        self.label_path = label_path        
 
    def __len__(self):
        return math.ceil(len(self.file_name_list) / self.batch_size)
 
    def __getitem__(self, idx):
        self.x = []
        self.y = []
        for file_name in self.file_name_list:
            self.x.append(self.image_path+file_name)
            self.y.append(self.label_path+file_name)

        batch_x = self.x[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
        self.batch_size]
 
        x_re = [np.load(file_path) for file_path in batch_x]
        y_re = [np.load(file_path) for file_path in batch_y]
 
        return np.array(x_re),np.array(y_re)
    
    def on_epoch_end(self):
        np.random.shuffle(self.file_name_list)

 train.py

import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from utils.data_loader import seg3D_Sequence
from tensorflow.keras import Model,Sequential
from tensorflow.keras.layers import Conv3D,Input
import numpy as np


if __name__ == "__main__":
    print('-'*60)
    num_classes = 2
    batch_size = 7
    train_val_split = 0.2
    image_path = "data/images/"
    
    file_name_list = os.listdir(image_path)
    train_name_list = file_name_list[:int(len(file_name_list)*0.8)]
    val_name_list = file_name_list[int(len(file_name_list)*0.8):]
    # print(len(file_name_list)) # 360

    train_data_loader = seg3D_Sequence(train_name_list,batch_size)
    val_data_loader = seg3D_Sequence(val_name_list,batch_size)
    # x,y = data_loader[90]
    # print(x.shape,y.shape)

    model = Sequential()
    model.add(Conv3D(num_classes,1,activation='sigmoid'))
    inputs = Input(shape=[128,128,16,1])
    outputs = model(inputs)
    print(outputs.shape)

    model = Model(inputs,outputs,name='test3d')
    model.summary()

    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy']
                  )

    model.fit(train_data_loader,
                epochs=3,
                validation_data=val_data_loader)


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值