ai模型训练思考与疑问?

ai模型训练思考与疑问

目的

  • 识别图片特征的ai模型
    • 数据集规模
      • 训练数据集:约 10,000 张图片
      • 验证数据集:约 2,000 张图片
      • 测试数据集:约 2,000 张图片
    • 模型架构
      • 使用卷积神经网络(CNN)作为基础架构,用于从图片中提取特征并进行分类。
    • 数据预处理
      • 图片尺寸标准化:将所有图片调整为相同的尺寸,以便输入模型。
      • 数据增强:采用随机旋转、缩放、平移等技术增加训练数据的多样性。
    • 模型优化
      • 学习率调度:使用学习率调度策略以优化训练过程。
      • 正则化:采用 L2 正则化等技术避免过拟合。
      • 提前停止:在验证集上监控模型性能,并在性能不再提高时停止训练。
    • 评估指标
      • 成功率:识别测试数据集中公章的准确率。
      • 其他指标:如精确率、召回率和 F1 分数等。

分析

  • 需要算力

    由于我们只需要做到简单的图片识别,所以量不会很大。

  • 需要的训练材料

    • 涉及问题:深度学习模型通常在图像的不同区域提取特征,并且期望输入具有相同的尺寸。通过将所有图片调整为相同的尺寸,可以确保模型提取的特征在不同图片之间具有一致性,从而提高模型的泛化能力和性能。所以我们只能手工提取具有相同尺寸的图片,进行ai训练。

    • 提出问题:是否可以调整图片的灰阶?例如,公章是红色的,那么我们就可以实现修改图片的灰阶实现只保留红色的部分,加快学习进度呢?

提出问题

  • 要想快速完成训练成果,需要合理调用GPU进行数据运算,我们该如何调用这方面的算力呢?

  • 假设我们成功调用了gpu,训练出了第一代模型那又该如何进行模型的部署,模型运行的环境是否有特殊的要求,或者我希望他在NPU上运行呢?

  • 什么是卷积神经网络?

    卷积神经网络(Convolutional Neural Network,CNN)是一种专门用于处理具有网格结构的数据,如图像和视频的深度学习模型。CNN 在图像识别、物体检测、语义分割等计算机视觉任务中取得了巨大成功,其主要作用包括:

    1. 特征提取
      • CNN 可以自动学习图像中的特征,从简单到复杂,通过一系列的卷积层和池化层逐渐提取出图像中的高级特征。
      • 卷积层通过滤波器(卷积核)对图像进行卷积操作,从而捕捉到图像中的局部特征,如边缘、纹理等。
      • 池化层则对卷积层输出的特征图进行降采样,减少特征图的尺寸,保留重要特征并减少计算量。
    2. 层级学习
      • CNN 由多个卷积层和池化层组成,层与层之间的连接形成了层级结构,使得模型能够逐级提取和组合特征,从而实现对输入数据的逐步抽象和理解。
      • 低层特征包含更多的局部信息,而高层特征则包含更多的全局和抽象信息。
    3. 参数共享
      • CNN 中的卷积操作采用参数共享的方式,即同一卷积核在图像的不同位置共享权重,这使得模型具有平移不变性,提高了模型的效率和泛化能力。
    4. 空间结构保持
      • CNN 在处理图像时能够保持空间结构,即考虑了像素之间的位置关系,而不是将图像视为独立的像素点,这有助于捕捉到图像中的空间局部信息。
    5. 适应性
      • CNN 可以通过端到端的训练方式,从数据中学习到合适的特征表示和模型参数,无需手工设计特征提取器。

    总的来说,卷积神经网络在图像处理领域具有很强的特征提取能力和模式识别能力,是目前最常用和有效的图像识别模型之一。

基本的代码分析

模型训练

       public static void main(String[] args) {
        try {
            int height = 28;
            int width = 28;
            int channels = 1;
            int batchSize = 64;
            int numClasses = 10;
            int numInputs = height * width;

            // 加载或者新建模型
            MultiLayerNetwork model = null;
            File savedModelFile = new File("model.zip");
            if (savedModelFile.exists()) {
                model = MultiLayerNetwork.load(savedModelFile, true);
                System.out.println("已从文件加载保存的模型:");
            } else {
                model = createModel(numInputs, numClasses);
                System.out.println("创建了新模型:");
            }

            // 加载和预处理数据
            File trainData = new File("path");
            DataSetIterator trainIter = setupData(trainData, batchSize, height, width, channels, numClasses);

            // 训练模型
            int numEpochs = 5;
            for (int i = 0; i < numEpochs; i++) {
                model.fit(trainIter);
            }

            // 保存模型
            model.save(savedModelFile, true);
            System.out.println("将模型保存到文件:");

        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static MultiLayerNetwork createModel(int numInputs, int numClasses) {
        // 网络配置
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Adam(0.001))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(100)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(100).nOut(numClasses)
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .backpropType(BackpropType.Standard)
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(config);
        model.init();
        model.setListeners(new ScoreIterationListener(10));
        return model;
    }

    public static DataSetIterator setupData(File trainData, int batchSize, int height, int width, int channels, int numClasses) {
        DataSetIterator trainIter = null;
        try {
            // 训练数据
            FileSplit trainSplit = new FileSplit(trainData);
            ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
            ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
            trainRR.initialize(trainSplit);
            trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, numClasses);
            trainIter.setPreProcessor(new ImagePreProcessingScaler(0, 1));
        } catch (IOException  e) {
            e.printStackTrace();
        }
        return trainIter;
    }

训练所需的jar包

<dependencies>
    <!-- Deeplearning4j 核心库 -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-beta7</version> <!-- 根据需要更改版本号 -->
    </dependency>

    <!-- Deeplearning4j 数据处理库 -->
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-api</artifactId>
        <version>1.0.0-beta7</version> <!-- 根据需要更改版本号 -->
    </dependency>

    <!-- Deeplearning4j 图像处理库 -->
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-dataimage</artifactId>
        <version>1.0.0-beta7</version> <!-- 根据需要更改版本号 -->
    </dependency>

    <!-- Deeplearning4j MNIST 数据集库 -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-datasets-data-vec</artifactId>
        <version>1.0.0-beta7</version> <!-- 根据需要更改版本号 -->
    </dependency>

    <!-- Deeplearning4j ND4J 库 -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-beta7</version> <!-- 根据需要更改版本号 -->
    </dependency>
</dependencies>

图像分析技术

JavaFX 来将图片转换为灰度图像并只保留红色通道的信息

public class RedChannelExtractor {

    public static void main(String[] args) {
        // 输入图片路径
        String inputImagePath = "input_image.jpg";

        // 提取红色通道
        try {
            BufferedImage inputImage = ImageIO.read(new File(inputImagePath));
            BufferedImage redChannelImage = extractRedChannel(inputImage);
            
            // 保存提取后的图片
            File outputImageFile = new File("red_channel_image.jpg");
            ImageIO.write(redChannelImage, "jpg", outputImageFile);
            System.out.println("Red channel extracted and saved successfully.");
        } catch (IOException e) {
            System.out.println("Error while extracting red channel: " + e.getMessage());
        }
    }

    // 提取红色通道的方法
    private static BufferedImage extractRedChannel(BufferedImage inputImage) {
        // 创建新的 BufferedImage,大小和输入图片一样,但只有一个通道
        BufferedImage redChannelImage = new BufferedImage(inputImage.getWidth(), inputImage.getHeight(), BufferedImage.TYPE_BYTE_GRAY);

        // 提取红色通道的信息
        for (int y = 0; y < inputImage.getHeight(); y++) {
            for (int x = 0; x < inputImage.getWidth(); x++) {
                int rgb = inputImage.getRGB(x, y);
                int red = (rgb >> 16) & 0xFF; // 提取红色通道信息
                // 在新的灰度图像中将红色通道信息写入对应像素位置
                redChannelImage.setRGB(x, y, (red << 16) | (red << 8) | red);
            }
        }

        return redChannelImage;
    }
}

使用 JavaFX 来将图片调整为相同的尺寸需要将获取的图片处理为相同的尺寸

public class ImageResizer {

    public static void main(String[] args) {
        // 输入图片路径和目标尺寸
        String inputImagePath = "input_image.jpg";
        int targetWidth = 200;
        int targetHeight = 200;

        // 调整图片尺寸
        try {
            BufferedImage inputImage = ImageIO.read(new File(inputImagePath));
            Image resizedImage = resizeImage(inputImage, targetWidth, targetHeight);
            
            // 保存调整后的图片
            File outputImageFile = new File("resized_image.jpg");
            ImageIO.write(SwingFXUtils.fromFXImage(resizedImage, null), "jpg", outputImageFile);
            System.out.println("Resized image saved successfully.");
        } catch (IOException e) {
            System.out.println("Error while resizing image: " + e.getMessage());
        }
    }

    // 调整图片尺寸的方法
    private static Image resizeImage(BufferedImage inputImage, int targetWidth, int targetHeight) {
        ImageView imageView = new ImageView(SwingFXUtils.toFXImage(inputImage, null));
        imageView.setFitWidth(targetWidth);
        imageView.setFitHeight(targetHeight);
        imageView.setPreserveRatio(true);
        imageView.setSmooth(true);
        imageView.setCache(true);
        return imageView.snapshot(null, null);
    }
}

使用

 public static void main(String[] args) {
        try {
            // 加载已训练好的模型
            File savedModelFile = new File("model.zip");
            MultiLayerNetwork model = MultiLayerNetwork.load(savedModelFile, true);

            // 加载待预测的图像
            File imageFile = new File("path");
            NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
            INDArray image = loader.asMatrix(imageFile);
            ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
            scaler.transform(image);

            // 进行预测
            INDArray output = model.output(image);

            // 打印预测结果
            System.out.println("打印预测结果:");
            for (int i = 0; i < output.length(); i++) {
                System.out.println("类别 " + i + ": 可能性 = " + output.getDouble(i));
            }

            // 打印最终预测类别
            int predictedClass = model.predict(image)[0];
            System.out.println("预测类别: " + predictedClass);

        } catch (IOException e) {
            e.printStackTrace();
        }
    }
  • 8
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值