DJL——java深度学习框架学习笔记——自定义数据集

自定义数据集

使用 Java 深度学习框架 DJL 进行深度学习时,首先要准备的就是数据集:只有有了数据,才可以对自己的模型进行训练。
我这里做的是人脸检测,使用的数据集是 CelebA(点击下载)。

DJL基础依赖

        <!-- DJL API artifact (ai.djl:api). -->
  		<dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.9.0</version>
        </dependency>
        <!-- Provides the ai.djl.basicdataset package imported by the dataset class below. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.9.0</version>
        </dependency>

        <!-- PyTorch model zoo artifact for DJL. -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>0.9.0</version>
        </dependency>

        <!-- PyTorch engine artifact for DJL. -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.9.0</version>
        </dependency>

        <!--
          PyTorch native library artifact.
          NOTE(review): this version (1.7.1) differs from the 0.9.0 used by the other DJL
          artifacts; the DJL PyTorch engine documentation lists which native version pairs
          with each engine release. Confirm that this combination is supported.
        -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.7.1</version>
        </dependency>

这是 DJL 的基础依赖,只有引入了这些依赖,我们才可以进行深度学习。DJL 由 AWS 开发和维护,底层通过 JNI 调用深度学习引擎(这里是 PyTorch)的 C/C++ 原生库,直接拿来用,它不香吗?

数据集创建

这是官方的介绍

DJL中的数据集代表原始数据和加载过程。RandomAccessDataset实现了Dataset接口,并提供了全面的数据加载功能。RandomAccessDataset还是支持使用索引对数据进行随机访问的基本数据集。您可以通过扩展RandomAccessDataset轻松自定义自己的数据集
官网地址

我这里介绍的主要是创建自定义数据集

创建自定义数据集,需要继承 RandomAccessDataset 类,并重写它的一些方法
废话不多说,直接上代码

package com.face.demo.utlis;

import ai.djl.Application;
import ai.djl.basicdataset.BasicDatasets;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;
import ai.djl.util.Progress;
import cn.hutool.core.io.file.FileReader;
import com.face.demo.pojo.FaceInfo;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import com.sun.imageio.plugins.common.ImageUtil;
import org.apache.commons.csv.CSVRecord;

import java.io.*;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * A {@link RandomAccessDataset} for face detection built from CelebA-style bounding-box
 * annotations.
 *
 * <p>Each record pairs one image (loaded lazily in {@link #get(NDManager, long)}) with a
 * five-element label {@code [classId, x, w, y, h]}, where the last four values are the
 * normalized box values produced by {@link #Normalized(FaceInfo)}.
 *
 * <p>The annotation file is expected to start with two header lines, followed by one
 * whitespace-separated record per line; each record is handed to {@code FaceInfo} for parsing.
 *
 * <p>This class is not thread-safe: {@link #prepare(Progress)} must not be called concurrently.
 */
public class FaceDataSet extends RandomAccessDataset {

    private static final String VERSION = "1.0";
    private static final String ARTIFACT_ID = "face";

    /** Number of leading lines in the annotation file that carry no bounding-box record. */
    private static final int ANNOTATION_HEADER_LINES = 2;

    /** Default image directory, used when {@link Builder#optImageDir(Path)} is not called. */
    private static final String DEFAULT_IMAGE_DIR =
            "C:\\Users\\mzp\\Documents\\img_celeba.7z\\img_celeba\\img_celeba";

    /** Default annotation file, used when {@link Builder#optAnnotationFile(String)} is not called. */
    private static final String DEFAULT_ANNOTATION_FILE =
            "C:\\Users\\mzp\\Documents\\Anno\\list_bbox_celeba.txt";

    private final Usage usage;
    private final Image.Flag flag;
    private final List<Path> imagePaths;
    private final List<float[]> labels;
    private final Resource resource;
    private final Path imageDir;
    private final String annotationFile;
    private boolean prepared;

    /**
     * Creates the dataset from a configured builder. No files are read until
     * {@link #prepare(Progress)} is called.
     *
     * @param builder the builder holding the dataset configuration
     */
    public FaceDataSet(FaceDataSet.Builder builder) {
        super(builder);
        this.usage = builder.usage;
        this.flag = builder.flag;
        this.imageDir = builder.imageDir;
        this.annotationFile = builder.annotationFile;
        this.imagePaths = new ArrayList<>();
        this.labels = new ArrayList<>();
        MRL mrl = MRL.dataset(Application.CV.ANY, builder.groupId, builder.artifactId);
        this.resource = new Resource(builder.repository, mrl, VERSION);
    }

    /**
     * Returns a new builder pre-populated with the default configuration.
     *
     * @return a new {@link Builder}
     */
    public static FaceDataSet.Builder builder() {
        return new FaceDataSet.Builder();
    }

    /**
     * Loads the image at {@code index} from disk and pairs it with its label.
     *
     * @param manager the manager used to allocate the image and label arrays
     * @param index the record index; must fit in an {@code int}
     * @return a record whose data is the image array and whose label has shape {@code (1, 5)}
     * @throws IOException if the image file cannot be read
     * @throws ArithmeticException if {@code index} does not fit in an {@code int}
     */
    @Override
    public Record get(NDManager manager, long index) throws IOException {
        int idx = Math.toIntExact(index);
        NDArray image =
                ImageFactory.getInstance().fromFile(imagePaths.get(idx)).toNDArray(manager, flag);
        NDArray label = manager.create(labels.get(idx));
        // Prepend a leading axis of size 1 so the label becomes (1, 5): one box per image.
        NDArray boxes = label.reshape(new Shape(1).addAll(label.getShape()));
        return new Record(new NDList(image), new NDList(boxes));
    }

    /**
     * Returns the number of records indexed by {@link #prepare(Progress)}.
     *
     * @return the number of image/label pairs; {@code 0} before the dataset is prepared
     */
    @Override
    protected long availableSize() {
        return imagePaths.size();
    }

    /**
     * Reads the annotation file and builds the image-path and label lists. Subsequent calls
     * after a successful run are no-ops.
     *
     * <p>If any record fails, nothing is added to the dataset and it stays unprepared, so the
     * call can be retried.
     *
     * @param progress unused by this implementation
     * @throws IOException if an image referenced by an annotation cannot be read
     * @throws TranslateException declared by the overridden method; not thrown here
     */
    @Override
    public void prepare(Progress progress) throws IOException, TranslateException {
        if (prepared) {
            return;
        }

        List<String> lines = new FileReader(annotationFile).readLines();
        // Tolerate files shorter than the header instead of failing on remove(0).
        int firstRecord = Math.min(ANNOTATION_HEADER_LINES, lines.size());

        // Collect into locals first so a failure half-way leaves the dataset untouched.
        List<Path> paths = new ArrayList<>();
        List<float[]> boxes = new ArrayList<>();
        try {
            for (String line : lines.subList(firstRecord, lines.size())) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // e.g. a trailing blank line at the end of the file
                }
                FaceInfo faceInfo = new FaceInfo(trimmed.split("\\s+"));
                float[] normalized = Normalized(faceInfo);

                float[] labelArray = new float[5];
                // Every record gets class id 0 (a single class; presumably "face").
                labelArray[0] = 0.0f;
                System.arraycopy(normalized, 0, labelArray, 1, 4);

                paths.add(imageDir.resolve(faceInfo.getImage_id()));
                boxes.add(labelArray);
            }
        } catch (UncheckedIOException e) {
            // Normalized cannot declare a checked exception; unwrap it for our callers.
            throw e.getCause();
        }

        imagePaths.addAll(paths);
        labels.addAll(boxes);
        prepared = true;
    }

    /** Builder for {@link FaceDataSet}. */
    public static final class Builder extends BaseBuilder<FaceDataSet.Builder> {
        Repository repository;
        String groupId;
        String artifactId;
        Usage usage;
        Image.Flag flag;
        Path imageDir;
        String annotationFile;

        Builder() {
            this.repository = BasicDatasets.REPOSITORY;
            this.groupId = "ai.djl.basicdataset";
            this.artifactId = ARTIFACT_ID;
            this.usage = Usage.TRAIN;
            this.flag = Image.Flag.COLOR;
            this.imageDir = Paths.get(DEFAULT_IMAGE_DIR);
            this.annotationFile = DEFAULT_ANNOTATION_FILE;
        }

        /** {@inheritDoc} */
        @Override
        public FaceDataSet.Builder self() {
            return this;
        }

        /**
         * Sets the dataset usage.
         *
         * @param usage the usage to record on the dataset
         * @return this builder
         */
        public FaceDataSet.Builder optUsage(Usage usage) {
            this.usage = usage;
            return this;
        }

        /**
         * Sets the repository used to build the dataset resource.
         *
         * @param repository the repository
         * @return this builder
         */
        public FaceDataSet.Builder optRepository(Repository repository) {
            this.repository = repository;
            return this;
        }

        /**
         * Sets the group id of the dataset resource.
         *
         * @param groupId the group id
         * @return this builder
         */
        public FaceDataSet.Builder optGroupId(String groupId) {
            this.groupId = groupId;
            return this;
        }

        /**
         * Sets the artifact id of the dataset resource. A value of the form
         * {@code "group:artifact"} sets both the group id and the artifact id.
         *
         * @param artifactId the artifact id, optionally prefixed with {@code "group:"}
         * @return this builder
         */
        public FaceDataSet.Builder optArtifactId(String artifactId) {
            if (artifactId.contains(":")) {
                String[] tokens = artifactId.split(":");
                this.groupId = tokens[0];
                this.artifactId = tokens[1];
            } else {
                this.artifactId = artifactId;
            }
            return this;
        }

        /**
         * Sets the flag passed to the image-to-array conversion when a record is loaded.
         *
         * @param flag the image flag
         * @return this builder
         */
        public FaceDataSet.Builder optFlag(Image.Flag flag) {
            this.flag = flag;
            return this;
        }

        /**
         * Sets the directory that contains the image files named by the annotation records.
         *
         * @param imageDir the image directory; must not be {@code null}
         * @return this builder
         * @throws IllegalArgumentException if {@code imageDir} is {@code null}
         */
        public FaceDataSet.Builder optImageDir(Path imageDir) {
            if (imageDir == null) {
                throw new IllegalArgumentException("imageDir must not be null");
            }
            this.imageDir = imageDir;
            return this;
        }

        /**
         * Sets the path of the bounding-box annotation file.
         *
         * @param annotationFile the annotation file path; must not be {@code null}
         * @return this builder
         * @throws IllegalArgumentException if {@code annotationFile} is {@code null}
         */
        public FaceDataSet.Builder optAnnotationFile(String annotationFile) {
            if (annotationFile == null) {
                throw new IllegalArgumentException("annotationFile must not be null");
            }
            this.annotationFile = annotationFile;
            return this;
        }

        /**
         * Builds the dataset, installing a {@link ToTensor} pipeline when none was configured.
         *
         * @return the new {@link FaceDataSet}
         */
        public FaceDataSet build() {
            if (this.pipeline == null) {
                this.pipeline = new Pipeline(new ToTensor());
            }
            return new FaceDataSet(this);
        }
    }

    /**
     * Converts a pixel-space bounding box given as {@code (x_1, y_1, width, height)} into
     * values normalized by the image size.
     *
     * <p>The image at {@code faceInfo.getImageURL()} is opened to obtain its width and height.
     *
     * <p>NOTE(review): the slot order {@code [x, w, y, h]} is kept from the previous
     * implementation for compatibility; confirm it matches the layout expected by the loss
     * used for training.
     *
     * @param faceInfo the annotation record to normalize
     * @return a four-element array {@code [centerX, width, centerY, height]}, with the
     *     horizontal values divided by the image width and the vertical values divided by the
     *     image height
     * @throws UncheckedIOException if the image cannot be opened or decoded
     * @throws NumberFormatException if a coordinate in {@code faceInfo} is not a valid float
     */
    public float[] Normalized(FaceInfo faceInfo) {
        File file = new File(faceInfo.getImageURL());
        Image image;
        // try-with-resources: the stream was previously never closed.
        try (InputStream in = new FileInputStream(file)) {
            image = ImageFactory.getInstance().fromInputStream(in);
        } catch (IOException e) {
            // Fail loudly instead of returning null, which only deferred the failure to a
            // NullPointerException in the caller.
            throw new UncheckedIOException("Failed to read image " + file, e);
        }

        float dw = 1.f / image.getWidth();
        float dh = 1.f / image.getHeight();
        float x1 = Float.parseFloat(faceInfo.getX_1());
        float y1 = Float.parseFloat(faceInfo.getY_1());
        float boxWidth = Float.parseFloat(faceInfo.getWidth());
        float boxHeight = Float.parseFloat(faceInfo.getHeight());

        // The box is (left, top, width, height): the centre is the corner plus half the
        // extent on the same axis. The previous code combined x_1 with y_1 and width with
        // height, mixing the two axes.
        float centerX = (x1 + boxWidth / 2.0f) * dw;
        float centerY = (y1 + boxHeight / 2.0f) * dh;
        float w = boxWidth * dw;
        float h = boxHeight * dh;

        return new float[] {centerX, w, centerY, h};
    }
}

  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

老马识途__

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值