mindspore-ValueError: The data pipeline is not a tree (i.e., one node has

from parameter import args
from mindspore import context,Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
import mindspore.dataset as ds
import mindspore.nn as nn
import mindspore.dataset.vision.c_transforms as c_transforms
import mindspore.dataset.transforms.c_transforms as c
from mindspore import dtype as mstype
from model.net import Net

# 设置运行模式,默认在GPU上进行训练
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.target_device)

sampler = ds.SequentialSampler(num_samples=5)
dataset = ds.Cifar10Dataset(args.train_data_dir, sampler=sampler)


def creat_data(training=True, repeat=1):
    """
    在训练之前,对数据集先进行处理
    :param training: 是否进行训练的标志位
    :param repeat: 是否对数据集进行复制填充
    :return: 返回处理后的数据集
    """
    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0

    # 图像增强:扩充数据集,增强训练数据的泛化能力

    # 将图像以一定的概率进行垂直翻转
    random_flip = c_transforms.RandomHorizontalFlip(prob=0.5)

    #将图像进行正则化,归一化
    normalize_op = c_transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    #将彩色图片的通道进行转换
    swap_channel = c_transforms.HWC2CHW()

    #对label进行格式转换,将其转换成指定类型格式
    type_cast_op = c.TypeCast(mstype.int32)

    trans_opt =[]
    #针对训练使用不同操作,对数据集进行不同操作增强
    if training:
        trans_opt =[random_flip]
    trans_opt +=[swap_channel,normalize_op,swap_channel]

    #将标签转换成指定类型的格式
    label_deal = dataset.map(operations=type_cast_op,input_columns='label')
    #对数据集图像进行增强
    dataset_deal = label_deal.map(operations=trans_opt,input_columns='image')

    #将数据集进行打乱处理
    dataset_deal = dataset_deal.shuffle(buffer_size=10)
    #将数据集中的数据分批取出
    dataset_deal = dataset_deal.batch(batch_size=args.batch_size,drop_remainder=True)
    #将数据集进行复制一次
    dataset_deal = dataset_deal.repeat(repeat)

    return dataset_deal


# 定义损失函数,采用交叉熵损失函数,输出损失函数的平均值
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
#引入定义的网络模型
net =Net()
# 定义优化器
optimizer = nn.Adam(params=net.get_parameters(),learning_rate=args.lr)
#训练过程的模型接口
model = Model(network=net,loss_fn=loss,optimizer=optimizer,metrics={'acc'})
data_batch_size = creat_data().get_batch_size()

#训练权重保存设置,每隔data_batch_size步保存一次模型,最多保存20个权重文件
cnfig_ck = CheckpointConfig(save_checkpoint_steps=data_batch_size,keep_checkpoint_max=20)
cekpot = ModelCheckpoint(prefix='training_weight',directory='./model/weight',config=cnfig_ck)
loss_cb = LossMonitor()
model.train(epoch=args.epoch,train_dataset=creat_data(),callbacks=[cekpot,loss_cb])
Traceback (most recent call last):
  File "D:/pythonpro/img_classify/train.py", line 79, in <module>
    model.train(epoch=args.epoch,train_dataset=creat_data(),callbacks=[cekpot,loss_cb])
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\train\model.py", line 612, in train
    dataset_size = train_dataset.get_dataset_size()
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 1496, in get_dataset_size
    runtime_getter = self.__init_size_getter()
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 1438, in __init_size_getter
    ir_tree, api_tree = self.create_ir_tree()
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 174, in create_ir_tree
    ir_tree = dataset.parse_tree()
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in parse_tree
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in <listcomp>
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in parse_tree
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in <listcomp>
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in parse_tree
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in <listcomp>
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in parse_tree
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in <listcomp>
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in parse_tree
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 230, in <listcomp>
    ir_children = [d.parse_tree() for d in self.children]
  File "E:\Anaconda3\envs\mindspore\lib\site-packages\mindspore\dataset\engine\datasets.py", line 229, in parse_tree
    raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
ValueError: The data pipeline is not a tree (i.e., one node has 2 consumers)
第一次使用mindspore,个人使用的是cifar10数据集,用的是自定义的网络,
训练出现这个错误 The data pipeline is not a tree (i.e., one node has 2 consumers),还请帮忙解答一下
自定义的网络如下
import mindspore.nn as nn

class Net(nn.Cell):
    def __init__(self):
        super(Net,self).__init__()
        layer=[
            nn.Conv2d(in_channels=3,out_channels=32,kernel_size=5,stride=1),
            nn.MaxPool2d(kernel_size=2,stride=1),
            nn.Conv2d(in_channels=32,out_channels=32,kernel_size=5,stride=1),
            nn.MaxPool2d(kernel_size=2,stride=1),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size=5,stride=1),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Dense(1024,64),
            nn.Dense(64,10)
        ]
        self.build_block = nn.SequentialCell(layer)


    def construct(self,x):
        x = self.build_block(x)
        return x

可以参考下mindspore的FAQ进行检查:[url]https://www.mindspore.cn/docs/faq/zh-CN/r1.5/data_processing.html[/url] Q: 当错误提示 “The data pipeline is not a tree (i.e., one node has 2 consumers)” 应该怎么检查? A: 上述错误通常是脚本书写错误导致,具体发生在下面这种场景;正常情况下数据处理pipeline中的操作是依次串联的,下面的异常场景中dataset1有两个消费节点 dataset2和dataset3,就会出现上述错误。 ``` dataset2 = dataset1.map(***) dataset3 = dataset1.map(***) ``` 正确的写法如下所示,dataset3是由dataset2进性数据增强得到的,而不是在dataset1基础上进行数据增强操作得到。 ``` dataset2 = dataset1.map(***) dataset3 = dataset2.map(***) ```

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值