Convolutional Text Classification on GPU -- deeplearning4j

A previous post covered training a convolutional text classification model, but that implementation ran on the CPU, which cannot keep up once the dataset gets large. The code is available in the earlier post; this one records the pitfalls I hit when moving the training onto the GPU.


GPU environment:

root@image-ubuntu:~# nvidia-smi
Fri Jul 14 01:21:46 2017
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 375.51                 Driver Version: 375.51                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla M60           Off  | 0000:00:02.0     Off |                  Off |
| N/A   52C    P0    46W / 150W |   3448MiB /  8123MiB |     21%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID  Type  Process name                               Usage      |
|=============================================================================|
|    0     53395    C   java                                          3438MiB |
+-----------------------------------------------------------------------------+



Training log:



01:17:19.580 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6942802230659779
01:17:24.036 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6964763564002254
01:17:28.767 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6825513419103831
01:17:33.190 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6779352336198492
01:17:37.119 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6845303732693183
01:17:41.313 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6806721112942656
01:17:44.184 [ParallelWrapper trainer 0] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 200 is 0.6814421662117872
01:17:45.113 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6701808819009931
01:17:49.053 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6650278905527942
01:17:53.505 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6565601670736454
01:17:58.273 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6697584503102176
01:18:02.572 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6510347362144552
01:18:06.458 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6611058336565505
01:18:10.226 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6541094549663357
01:18:13.547 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6428940961803716
01:18:17.627 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6382708927005436
01:18:18.026 [ParallelWrapper trainer 0] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 300 is 0.6395620244327883
01:18:22.073 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6311690317350285
01:18:25.798 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6441202013287363
01:18:29.440 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6297861390295019
01:18:33.594 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6392450155730185
01:18:38.271 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6228116943748379
01:18:42.867 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6184209858527969
01:18:46.715 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6133259157463684
01:18:50.564 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6016078467137244
01:18:52.438 [ParallelWrapper trainer 0] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 400 is 0.6253361305693586
01:18:54.943 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6150038531894072
01:18:58.971 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6179796273714999
01:19:03.511 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.5815389255352973
01:19:07.499 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6121308310206943
01:19:11.571 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.5908273267756271
01:19:16.073 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6028526854103197
01:19:20.606 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.5953763587640233
01:19:24.801 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.5793217319992398
01:19:27.736 [ParallelWrapper trainer 0] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 500


First error:

Exception in thread "main" java.lang.UnsupportedClassVersionError: org/deeplearning4j/parallelism/ParallelWrapper$Builder : Unsupported major.minor version 52.0
	at java.lang.ClassLoader.defineClass1(Native Method)
	at java.lang.ClassLoader.defineClass(ClassLoader.java:800)
	at java.security.SecureClassLoader.defineClass(SecureClassLoader.java:142)
	at java.net.URLClassLoader.defineClass(URLClassLoader.java:449)
	at java.net.URLClassLoader.access$100(URLClassLoader.java:71)
	at java.net.URLClassLoader$1.run(URLClassLoader.java:361)
	at java.net.URLClassLoader$1.run(URLClassLoader.java:355)
	at java.security.AccessController.doPrivileged(Native Method)
	at java.net.URLClassLoader.findClass(URLClassLoader.java:354)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:425)
	at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:308)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:358)
	at com.dianping.deeplearning.test.TestWithGPU.main(TestWithGPU.java:115)

This happens because the JDK is too old: major.minor version 52.0 is the Java 8 class-file format, so the Deeplearning4j jars (compiled for Java 8) cannot be loaded by an older runtime. The fix:

Upgrade the JVM to Java 1.8.
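To verify the mismatch before switching JDKs, here is a small diagnostic sketch (the .class file passed in args[0] is whatever compiled class you want to inspect); it prints the running JVM's version alongside the class file's major version (52 = Java 8, 51 = Java 7):

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

public class ClassVersionCheck {
	public static void main(String[] args) throws IOException {
		// Version of the JVM actually running this code
		System.out.println("Runtime: " + System.getProperty("java.version"));
		// Read the class-file header: magic, minor_version, major_version
		try (DataInputStream in = new DataInputStream(new FileInputStream(args[0]))) {
			in.readInt(); // magic 0xCAFEBABE
			int minor = in.readUnsignedShort();
			int major = in.readUnsignedShort();
			System.out.println("Class file version: " + major + "." + minor);
		}
	}
}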


Second error:

INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
03:57:22.643 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
03:57:28.993 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
03:57:35.097 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
This is a precision problem when training on the GPU; forcing ND4J to single precision fixes it:

DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
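Note that the call only affects buffers allocated after it runs, so it should be the first ND4J-related statement in main. A minimal placement sketch:

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.factory.Nd4j;

public class PrecisionSetup {
	public static void main(String[] args) {
		// Force single precision BEFORE any INDArray is created
		DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
		// Sanity check: newly allocated buffers should now report FLOAT
		System.out.println(Nd4j.zeros(2, 2).data().dataType());
	}
}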


Third error:

Exception in thread "main" java.lang.RuntimeException: Exception thrown in base iterator
	at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator.next(AsyncDataSetIterator.java:247)
	at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator.next(AsyncDataSetIterator.java:33)
	at org.deeplearning4j.parallelism.ParallelWrapper.fit(ParallelWrapper.java:379)
	at com.dianping.deeplearning.cnn.TrainAdxCnnModelWithGPU.main(TrainAdxCnnModelWithGPU.java:170)
Caused by: org.nd4j.linalg.exception.ND4JIllegalStateException: Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
	at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4776)
	at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3997)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.create(BaseNDArray.java:1906)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.subArray(BaseNDArray.java:2064)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.get(BaseNDArray.java:4015)
	at com.dianping.deeplearning.cnn.CnnSentenceDataSetIterator.next(CnnSentenceDataSetIterator.java:222)
	at com.dianping.deeplearning.cnn.CnnSentenceDataSetIterator.next(CnnSentenceDataSetIterator.java:155)
	at com.dianping.deeplearning.cnn.CnnSentenceDataSetIterator.next(CnnSentenceDataSetIterator.java:25)
	at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator$IteratorRunnable.run(AsyncDataSetIterator.java:322)

Setting featuresMask to null in the iterator fixes this.
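For reference, a minimal sketch of what the workaround looks like inside a custom iterator's next() (buildBatch and the shapes below are illustrative, not the actual CnnSentenceDataSetIterator code):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class MaskFixSketch {
	// Build the minibatch without a features mask: passing null avoids the
	// zero-width mask row that triggered the [1, 0] shape exception above.
	static DataSet buildBatch(INDArray features, INDArray labels) {
		return new DataSet(features, labels, null, null);
	}

	public static void main(String[] args) {
		INDArray features = Nd4j.zeros(2, 1, 256, 15); // [batch, channels, maxSentenceLength, vectorSize]
		INDArray labels = Nd4j.zeros(2, 2);            // [batch, numClasses]
		System.out.println(buildBatch(features, labels));
	}
}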





Finally, the full GPU training code:

package com.dianping.deeplearning.cnn;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UnsupportedEncodingException;
import java.util.List;
import java.util.Random;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TrainAdxCnnModelWithGPU {
	public static void main(String[] args) throws FileNotFoundException,
			UnsupportedEncodingException {

		/*
		 * GPU training setup
		 */

		System.out.println("GPU initialization starting...");
		// CUDA supports FP16, but we force FP32 here to avoid the NaN scores described above
		DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);

        // temp workaround for backend initialization

        CudaEnvironment.getInstance().getConfiguration()
            // key option enabled
            .allowMultiGPU(true)

            // we're allowing larger memory caches
            .setMaximumDeviceCache(2L * 1024L * 1024L * 1024L)

            // cross-device access is used for faster model averaging over pcie
            .allowCrossDeviceAccess(true);
		
		
		
		System.out.println("GPU initialization done.");

		String WORD_VECTORS_PATH = "/home/zhoumeixu/model/word2vec.model";
		// Basic configuration
		int batchSize = 128;
		int vectorSize = 15; // dimensionality of the word vectors; must match the word2vec model loaded below
		int nEpochs = 15000; // number of training epochs
		int iterator = 1; // iterations per minibatch
		int truncateReviewsToLength = 256; // sentences longer than 256 tokens are truncated
		int cnnLayerFeatureMaps = 100; // number of feature maps (channels) per convolutional layer
		PoolingType globalPoolingType = PoolingType.MAX;
		Random rng = new Random(100); // fixed seed for reproducible shuffling

		// Network configuration: parallel convolution layers with filter widths 3, 4 and 5

		ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
				.weightInit(WeightInit.RELU)
				.activation(Activation.LEAKYRELU)
				.updater(Updater.ADAM)
				.convolutionMode(ConvolutionMode.Same)
				// This is important so we can 'stack' the results later
				.regularization(true)
				.l2(0.0001)
				.iterations(iterator)
				.learningRate(0.01)
				.graphBuilder()
				.addInputs("input")
				.addLayer(
						"cnn3",
						new ConvolutionLayer.Builder()
								.kernelSize(3, vectorSize)
								.stride(1, vectorSize).nIn(1)
								.nOut(cnnLayerFeatureMaps).build(), "input")
				.addLayer(
						"cnn4",
						new ConvolutionLayer.Builder()
								.kernelSize(4, vectorSize)
								.stride(1, vectorSize).nIn(1)
								.nOut(cnnLayerFeatureMaps).build(), "input")
				.addLayer(
						"cnn5",
						new ConvolutionLayer.Builder()
								.kernelSize(5, vectorSize)
								.stride(1, vectorSize).nIn(1)
								.nOut(cnnLayerFeatureMaps).build(), "input")
				.addVertex("merge", new MergeVertex(), "cnn3", "cnn4", "cnn5")
				// Perform depth concatenation
				.addLayer(
						"globalPool",
						new GlobalPoolingLayer.Builder().poolingType(
								globalPoolingType).build(), "merge")
				.addLayer(
						"out",
						new OutputLayer.Builder()
								.lossFunction(LossFunctions.LossFunction.MCXENT)
								.activation(Activation.SOFTMAX)
								.nIn(3 * cnnLayerFeatureMaps).nOut(2).build(),
						"globalPool").setOutputs("out").build();

		ComputationGraph net = new ComputationGraph(config);
		net.init();
		
		

		 // ParallelWrapper will take care of load balancing between GPUs.
       ParallelWrapper wrapper = new ParallelWrapper.Builder(net)
           // DataSet prefetching options; set this with respect to the number of actual devices
           .prefetchBuffer(24)

           // set the number of workers equal to or higher than the number of available devices; 1-2x is a good starting point
           .workers(4)

           // infrequent averaging improves performance, but may reduce model accuracy
           .averagingFrequency(3)

           // if TRUE, the model score is reported after every averaging
           .reportScoreAfterAveraging(true)

           // optional parameter; set to false ONLY if your system supports P2P memory access across PCIe (hint: AWS does not)
           .useLegacyAveraging(true)

           .build();
       
		net.setListeners(new ScoreIterationListener(100));

	

		// Load the word vectors and create DataSetIterators for the training and test sets
		System.out.println("Loading word vectors and creating DataSetIterators");
		/*
		 * WordVectors wordVectors = WordVectorSerializer
		 * .fromPair(WordVectorSerializer.loadTxt(new File(
		 * WORD_VECTORS_PATH)));
		 */
		WordVectors wordVectors = WordVectorSerializer
				.readWord2VecModel(WORD_VECTORS_PATH);

		DataSetIterator trainIter = getDataSetIterator(true, wordVectors,
				batchSize, truncateReviewsToLength, rng);
		DataSetIterator testIter = getDataSetIterator(false, wordVectors,
				batchSize, truncateReviewsToLength, rng);

		System.out.println("Starting training");
		for (int i = 0; i < nEpochs; i++) {
			wrapper.fit(trainIter);
			trainIter.reset();

			// Evaluate the network on the test set
			Evaluation evaluation = net.evaluate(testIter);

			testIter.reset();
			System.out.println(evaluation.stats());
			
			System.out.println("Epoch " + i + " complete");
		}
		/*
		 * Save the model
		 */
		saveNet("/home/zhoumeixu/model/cnn.model", net);

		/*
		 * Load the model
		 */
		ComputationGraph netload = loadNet("/home/zhoumeixu/model/cnn.model");

		// After training: run a single, pre-tokenized (space-separated) sentence through the network
		String contentsFirstPas = "我的 手机 是 手机号码";

		INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator) testIter)
				.loadSingleSentence(contentsFirstPas);
		INDArray predictionsFirstNegative = netload
				.outputSingle(featuresFirstNegative);
		List<String> labels = testIter.getLabels();

		System.out.println("\n\nPredictions for first negative review:");
		for (int i = 0; i < labels.size(); i++) {
			System.out.println("P(" + labels.get(i) + ") = "
					+ predictionsFirstNegative.getDouble(i));
		}

	}

	private static DataSetIterator getDataSetIterator(boolean isTraining,
			WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
			Random rng) {
		String path = isTraining ? "/home/zhoumeixu/model/rnnsenec.txt" : "/home/zhoumeixu/model/rnnsenectest.txt";
		LabeledSentenceProvider sentenceProvider = new LabeledSentence(path,
				rng);

		return new CnnSentenceDataSetIterator.Builder()
				.sentenceProvider(sentenceProvider).wordVectors(wordVectors)
				.minibatchSize(minibatchSize)
				.maxSentenceLength(maxSentenceLength)
				.useNormalizedWordVectors(false).build();
	}

	public static void saveNet(String path, ComputationGraph net) {

		ObjectOutputStream objectOutputStream = null;
		try {
			objectOutputStream = new ObjectOutputStream(new FileOutputStream(
					path));

			objectOutputStream.writeObject(net);

			objectOutputStream.close();

		} catch (Exception e) {
			e.printStackTrace();
		}

	}

	public static ComputationGraph loadNet(String path) {
		ObjectInputStream objectInputStream = null;
		ComputationGraph net = null;
		try {
			objectInputStream = new ObjectInputStream(new FileInputStream(path));
			net = (ComputationGraph) objectInputStream.readObject();
			objectInputStream.close();

		} catch (Exception e) {
			e.printStackTrace();
		}
		return net;
	}

}
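One closing note on persistence: the plain Java serialization above works, but DL4J also ships ModelSerializer, which is usually more robust across library versions. A minimal sketch of the alternative (CnnModelIo is a hypothetical helper name):

import java.io.File;
import java.io.IOException;

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;

public class CnnModelIo {
	// saveUpdater = true also persists the optimizer state, so training can resume later
	public static void save(ComputationGraph net, String path) throws IOException {
		ModelSerializer.writeModel(net, new File(path), true);
	}

	public static ComputationGraph load(String path) throws IOException {
		return ModelSerializer.restoreComputationGraph(new File(path));
	}
}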


