这篇主要介绍dl4j如何操作csv,虽然实战中比较少用,但是对熟悉基本数据操作及结构还是有好处的,代码如下:
public class BasicCSVClassifier {
/** Logger for this class, obtained from the logging factory; final because it is never meant to be reassigned. */
private static final Logger log = LoggerFactory.getLogger(BasicCSVClassifier.class);
// Lookup tables (Integer key -> String value), each filled by readEnumCSV from the CSV file at the given path.
// NOTE(review): the leading "/" suggests these are classpath resources - confirm against readEnumCSV.
private static Map<Integer,String> eats = readEnumCSV("/DataExamples/animals/eats.csv");
private static Map<Integer,String> sounds = readEnumCSV("/DataExamples/animals/sounds.csv");
private static Map<Integer,String> classifiers = readEnumCSV("/DataExamples/animals/classifiers.csv");
public static void main(String[] args){
try {
//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network//RecordReaderDataSetIterator把数据弄成DataSet对象,方便放入神经网络
int labelIndex = 4; //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row//iris每行5个值,4个特征后跟一个类别,4是类别索引
int numClasses = 3; //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2//3个类别,标记为0,1,2
int batchSizeTraining = 30; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)//150个数据一次载入dataset,数据量大的时候不推荐,训练批的数量是30
DataSet trainingData = readCSVDataset(
"/DataExamples/animals/animals_train.csv",
batchSizeTraining, labelIndex, numClasses);//readCSVDataset方法直接读取csv变成DataSet数据
// this is the data we want to classify
int batchSizeTest = 44;//测试批44,跟上面一样
DataSet testData = readCSVDataset("/DataExamples/animals/animals.csv",
batchSizeTest, labelIndex, numClasses);
// make the data model for records prior to normalization, because it
// changes the data.//在规范化之前先构建数据结构,因为规范化改变了数据
Map<Integer,Map<String,Object>> animals = makeAnimalsForTesting(testData);//animals是这样的结构{0={eats=Mice, sounds=Meow, weight=10.0, yearsLived=19}, 1={eats=Cats, sounds=Bark, weight=60.0, yearsLived=9}...}
//We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance)://规范化数据,0均值,单位方差
DataNormalization normalizer = new NormalizerStandardize();//规范化器
normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data//计算训练数据的均值方差,通过trainingData.getFeatures().mean(0)