// A simple template for training a model with DJL (Deep Java Library).
import java.io.IOException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.FashionMnist;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.initializer.NormalInitializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
public class MulClassMain {
public static void main(String[] args) throws IOException, TranslateException {
// TODO 1. 定义模型结构
SequentialBlock net = new SequentialBlock();
net.add(Blocks.batchFlattenBlock(784));
net.add(Linear.builder().setUnits(256).build());
net.add(Activation::relu);
net.add(Linear.builder().setUnits(10).build());
net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
// 训练过程
int batchSize = 256;
int numEpochs = Integer.getInteger("MAX_EPOCH", 20);
FashionMnist trainIter = FashionMnist.builder()
.optUsage(Dataset.Usage.TRAIN)
.setSampling(batchSize, true)
.optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
.build();
FashionMnist testIter = FashionMnist.builder()
.optUsage(Dataset.Usage.TEST)
.setSampling(batchSize, true)
.optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
.build();
trainIter.prepare();
testIter.prepare();
// 定义优化算法
Tracker lrt = Tracker.fixed(0.5f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
// 定义损失函数
Loss loss = Loss.softmaxCrossEntropyLoss();
DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
.optOptimizer(sgd)
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging());
try(NDManager nm = NDManager.newBaseManager()){
try(Model model = Model.newInstance("mlp")){
model.setBlock(net);
try(Trainer trainer = model.newTrainer(config)){
trainer.initialize(new Shape(1, 784));
trainer.setMetrics(new Metrics());
for(int epoch=0;epoch<numEpochs;++epoch){
System.out.printf("Epoch %d \n", epoch);
for(Batch batch: trainIter.getData(nm)){
EasyTrain.trainBatch(trainer, batch);
// 更新参数
trainer.step();
batch.close();
}
trainer.notifyListeners(l->l.onEpoch(trainer));
}
}
}
}
}
public NDArray relu(NDArray X){
return X.maximum(0.0f);
}
}
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.