Encog provides many training methods.
EncogUtility is a helper class that provides many convenient static functions:
Modifier and Type | Method | Description
---|---|---
static double | calculateClassificationError(MLClassification method, MLDataSet data) | Calculate the classification error.
static double | calculateRegressionError(MLRegression method, MLDataSet data) |
static void | convertCSV2Binary(File csvFile, CSVFormat format, File binFile, int[] input, int[] ideal, boolean headers) |
static void | convertCSV2Binary(File csvFile, File binFile, int inputCount, int outputCount, boolean headers) | Convert a CSV file to a binary training file.
static void | convertCSV2Binary(String csvFile, String binFile, int inputCount, int outputCount, boolean headers) | Convert a CSV file to a binary training file.
static void | evaluate(MLRegression network, MLDataSet training) | Evaluate the network and display (to the console) the output for every value in the training set.
static void | explainErrorMSE(MLRegression method, MatrixMLDataSet training) |
static void | explainErrorRMS(MLRegression method, MatrixMLDataSet training) |
static String | formatNeuralData(MLData data) | Format neural data as a list of numbers.
static MLDataSet | loadCSV2Memory(String filename, int input, int ideal, boolean headers, CSVFormat format, boolean significance) | Load a CSV file into memory.
static MLDataSet | loadEGB2Memory(File filename) |
static void | saveCSV(File targetFile, CSVFormat format, MLDataSet set) |
static void | saveEGB(File f, MLDataSet data) | Save a training set to an EGB file.
static BasicNetwork | simpleFeedForward(int input, int hidden1, int hidden2, int output, boolean tanh) | Create a simple feedforward neural network.
static void | trainConsole(BasicNetwork network, MLDataSet trainingSet, int minutes) | Train the neural network using SCG training and output the status to the console.
static void | trainConsole(MLTrain train, BasicNetwork network, MLDataSet trainingSet, int minutes) | Train the network using the specified training algorithm and send the output to the console.
static void | trainToError(MLMethod method, MLDataSet dataSet, double error) | Train the method to a specific error, sending the output to the console.
static void | trainToError(MLTrain train, double error) | Train to a specific error using the specified training method, sending the output to the console.
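Several of these helpers (calculateRegressionError, explainErrorMSE, explainErrorRMS, calculateClassificationError) report error figures. The plain-Java sketch below shows the textbook definitions those names suggest. It does not call Encog; the class name `ErrorMetricsSketch` and its methods are hypothetical, and the assumption that Encog averages over every output value of every sample (and picks the predicted class as the largest output) is ours, not something stated in the table.

```java
/**
 * Hypothetical sketch of common error measures. These are textbook
 * definitions, assumed to correspond to the EncogUtility error helpers;
 * this is not Encog source code.
 */
public class ErrorMetricsSketch {

    // Mean squared error over every output value of every sample.
    static double mse(double[][] actual, double[][] ideal) {
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < actual.length; i++) {
            for (int j = 0; j < actual[i].length; j++) {
                double diff = ideal[i][j] - actual[i][j];
                sum += diff * diff;
                count++;
            }
        }
        return sum / count;
    }

    // Root mean square error: the square root of the MSE.
    static double rms(double[][] actual, double[][] ideal) {
        return Math.sqrt(mse(actual, ideal));
    }

    // Index of the largest output, used here as the predicted class.
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of samples whose predicted class differs from the ideal class.
    static double classificationError(double[][] actual, int[] idealClass) {
        int wrong = 0;
        for (int i = 0; i < actual.length; i++) {
            if (argMax(actual[i]) != idealClass[i]) {
                wrong++;
            }
        }
        return (double) wrong / actual.length;
    }

    public static void main(String[] args) {
        double[][] actual = {{0.9, 0.1}, {0.4, 0.6}, {0.7, 0.3}};
        double[][] ideal = {{1.0, 0.0}, {0.0, 1.0}, {0.0, 1.0}};
        int[] idealClass = {0, 1, 1};
        System.out.println("MSE   = " + mse(actual, ideal));
        System.out.println("RMS   = " + rms(actual, ideal));
        System.out.println("Class = " + classificationError(actual, idealClass));
    }
}
```

For the sample data in `main`, the squared differences sum to 0.02 + 0.32 + 0.98 = 1.32 over six output values, so the MSE is 0.22 (up to floating-point rounding), and the third sample is misclassified, giving a classification error of 1/3.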
BasicTraining is the parent class of all the training-method classes.
Constructor | Description
---|---
BasicTraining() | Used for serialization.
BasicTraining(TrainingImplementationType implementationType) |
Return type | Method | Description
---|---|---
void | addStrategy(Strategy strategy) | Training strategies can be added to improve the training results.
void | finishTraining() | Should be called after training has completed and the iteration method will not be called any further.
double | getError() |
TrainingImplementationType | getImplementationType() |
int | getIteration() |
List<Strategy> | getStrategies() |
MLDataSet | getTraining() |
boolean | isTrainingDone() |
void | iteration(int count) | Perform the specified number of training iterations.
void | postIteration() | Call the strategies after an iteration.
void | preIteration() | Call the strategies before an iteration.
void | setError(double error) |
void | setIteration(int iteration) | Set the current training iteration.
void | setTraining(MLDataSet training) | Set the training object that this strategy is working with.
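The method table implies a fixed lifecycle: strategies registered with addStrategy run in preIteration() before each training step and in postIteration() after it, and iteration(int count) repeats the step. The self-contained sketch below mimics that pattern, plus a loop in the spirit of EncogUtility.trainToError. `MiniStrategy`, `MiniTraining`, `HalvingTraining` and `CountingStrategy` are hypothetical stand-ins, not Encog types, and the detail that preIteration() also advances the iteration counter is an assumption.

```java
import java.util.ArrayList;
import java.util.List;

public class TrainingLifecycleSketch {

    // Hypothetical stand-in for Encog's Strategy: hooks run around each iteration.
    interface MiniStrategy {
        void preIteration();
        void postIteration();
    }

    // Hypothetical, simplified counterpart of BasicTraining.
    abstract static class MiniTraining {
        private final List<MiniStrategy> strategies = new ArrayList<>();
        private int iteration;
        protected double error;

        void addStrategy(MiniStrategy strategy) {
            strategies.add(strategy);
        }

        // Assumed behaviour: advance the counter, then call every strategy.
        void preIteration() {
            iteration++;
            for (MiniStrategy s : strategies) {
                s.preIteration();
            }
        }

        void postIteration() {
            for (MiniStrategy s : strategies) {
                s.postIteration();
            }
        }

        // One training iteration: hooks wrapped around the algorithm-specific step.
        void iteration() {
            preIteration();
            step();
            postIteration();
        }

        // Perform the specified number of training iterations.
        void iteration(int count) {
            for (int i = 0; i < count; i++) {
                iteration();
            }
        }

        int getIteration() { return iteration; }

        double getError() { return error; }

        protected abstract void step();
    }

    // Toy algorithm: starts at error 1.0 and halves it on every step.
    static class HalvingTraining extends MiniTraining {
        HalvingTraining() { error = 1.0; }

        @Override
        protected void step() { error /= 2; }
    }

    // Strategy that only counts how often each hook fires.
    static class CountingStrategy implements MiniStrategy {
        int pre;
        int post;

        @Override
        public void preIteration() { pre++; }

        @Override
        public void postIteration() { post++; }
    }

    // Iterate until the error target is met, printing progress to the console.
    static void trainToError(MiniTraining train, double target) {
        do {
            train.iteration();
            System.out.println("Iteration #" + train.getIteration() + " error=" + train.getError());
        } while (train.getError() > target);
    }

    public static void main(String[] args) {
        HalvingTraining train = new HalvingTraining();
        CountingStrategy counter = new CountingStrategy();
        train.addStrategy(counter);
        trainToError(train, 0.01);
        System.out.println("pre=" + counter.pre + " post=" + counter.post);
    }
}
```

Because the hooks live in the base class, a concrete algorithm only implements its own step, and strategies (such as ones that adjust the learning rate) can be attached without touching that algorithm.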
Backpropagation is a subclass of Propagation.
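Constructor | Description
---|---
Backpropagation(ContainsFlat network, MLDataSet training) | Create a class to train using backpropagation.
Backpropagation(ContainsFlat network, MLDataSet training, double theLearnRate, double theMomentum) | Create a class to train using backpropagation with an explicit learning rate and momentum.

The four-argument constructor takes: (1) the network to be trained, (2) the training set, (3) the learning rate, and (4) the momentum. Momentum is a common technique for accelerating gradient descent: each weight change also carries over a fraction of the previous weight change, and the momentum value is that fraction. A momentum of 0 means no acceleration; the larger the value, the stronger the acceleration.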
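Written out, the standard gradient-descent-with-momentum rule that the learning-rate and momentum parameters describe is given here, with $\eta$ the learning rate, $\alpha$ the momentum and $g$ the gradient term taken in the descent direction. The sign convention is an assumption; Encog's internal convention is not stated in this table. A worked step with illustrative numbers follows the rule.

```latex
g^{(t)} = -\frac{\partial E}{\partial w}, \qquad
\Delta w^{(t)} = \eta\, g^{(t)} + \alpha\, \Delta w^{(t-1)}, \qquad
w^{(t+1)} = w^{(t)} + \Delta w^{(t)}

% Worked step with \eta = 0.7, \alpha = 0.3, g^{(t)} = 0.5, \Delta w^{(t-1)} = 0.2:
\Delta w^{(t)} = 0.7 \times 0.5 + 0.3 \times 0.2 = 0.35 + 0.06 = 0.41
```

With $\alpha = 0$ the same step would be 0.35, so the extra 0.06 comes entirely from the previous weight change.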
Return type | Method | Description
---|---|---
boolean | canContinue() |
double[] | getLastDelta() |
double | getLearningRate() |
double | getMomentum() |
void | initOthers() | Perform training-method-specific initialization.
boolean | isValidResume(TrainingContinuation state) | Determine if the specified continuation object is valid to resume with.
TrainingContinuation | pause() | Pause the training.
void | resume(TrainingContinuation state) | Resume training.
void | setLearningRate(double rate) | Set the learning rate; this value is essentially a percent.
void | setMomentum(double m) | Set the momentum for training.
double | updateWeight(double[] gradients, double[] lastGradient, int index) | Update a weight.
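updateWeight, getLastDelta, pause and resume fit together: momentum needs the previous weight change, so that state has to survive a pause. The plain-Java sketch below illustrates the idea under stated assumptions. `MomentumSketch` is hypothetical, a plain `double[]` snapshot stands in for TrainingContinuation, `gradients[index]` is assumed to already point in the descent direction (so the caller adds the returned delta to the weight), and `lastGradient` is accepted but unused on the assumption that plain momentum backpropagation does not need it.

```java
/**
 * Hypothetical sketch of a momentum weight update with pause/resume of the
 * momentum state. Modelled on the Backpropagation method table; not Encog code.
 */
public class MomentumSketch {
    private final double learningRate;
    private final double momentum;
    private double[] lastDelta;

    MomentumSketch(int weightCount, double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.lastDelta = new double[weightCount];
    }

    // delta = learningRate * gradient + momentum * previous delta.
    // The caller adds the returned delta to the weight.
    // lastGradient is unused here; it only mirrors the table's signature.
    double updateWeight(double[] gradients, double[] lastGradient, int index) {
        double delta = learningRate * gradients[index] + momentum * lastDelta[index];
        lastDelta[index] = delta;
        return delta;
    }

    // Snapshot of the momentum state (stand-in for a TrainingContinuation).
    double[] pause() {
        return lastDelta.clone();
    }

    // Counterpart of isValidResume: the snapshot must match the weight count.
    boolean isValidResume(double[] saved) {
        return saved != null && saved.length == lastDelta.length;
    }

    // Restore the momentum state from a snapshot taken by pause().
    void resume(double[] saved) {
        if (!isValidResume(saved)) {
            throw new IllegalArgumentException("snapshot does not match this trainer");
        }
        lastDelta = saved.clone();
    }

    public static void main(String[] args) {
        MomentumSketch trainer = new MomentumSketch(1, 0.7, 0.3);
        double[] gradients = {0.5};
        double[] lastGradients = new double[1];

        double first = trainer.updateWeight(gradients, lastGradients, 0);    // 0.35: no previous delta yet
        double[] snapshot = trainer.pause();                                 // remembers lastDelta
        double second = trainer.updateWeight(gradients, lastGradients, 0);   // about 0.455 = 0.35 + 0.3 * 0.35
        trainer.resume(snapshot);                                            // roll back to the paused state
        double replayed = trainer.updateWeight(gradients, lastGradients, 0); // same as second
        System.out.println(first + " " + second + " " + replayed);
    }
}
```

Restoring the snapshot makes the next step identical to the one taken right after the pause, which is the property a continuation object exists to guarantee.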