Java程式員學深度學習 DJL上手5 訓練自己的模型
- 一、準備環境
- 二、創建示例專案
- 三、準備資料集
- 四、創建模型
- 五、創建訓練器
- 1. 訓練器配置
- 2. 初始化訓練器
- 3. 訓練模型
- 4. 保存模型
- 六、源代碼
- 1. pom
- 2. java
一、準備環境
- windows
- idea
- maven
二、創建示例專案
三、準備資料集
int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());
這里對資料集進行了分批處理,每批大小32,合適的分批大小將在訓練時顯著提升性能,
四、創建模型
本節會根據之前文章創建模型,由于 MNIST 資料集中的影像為 28x28 灰度影像,這里我們創建一個具有 28 x 28 輸入的 MLP 塊,
輸出的圖輸出為 10,因為每個影像可能有 10 個可能的類(0 到 9),
對于隱藏的層,其大小是猜測的值new int[] {128, 64}
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
五、創建訓練器
1. 訓練器配置
- 損失函式,用來測量模型與測驗資料集的匹配程度,值越低越好;這里定義為
softmaxCrossEntropyLoss() - 評估函式,也用于測量模型與資料集的匹配程度,與損失不同,它們只供人們查看,不用于優化模型,
- 監聽器,用來監控訓練程序,
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging());
Trainer trainer = model.newTrainer(config);
2. 初始化訓練器
這里使用輸入的形狀來初始化訓練器,初始化函式里形狀的第一個引數是批次大小,這個不影響引數初始化,
第二個引數是輸入影像的像素數,即28*28,
trainer.initialize(new Shape(1, 28 * 28));
3. 訓練模型
這里使用了DJL的EasyTrain,
int epoch = 2;
EasyTrain.fit(trainer, epoch, mnist, null);
4. 保存模型
保存模型還可以添加一些元資料,如訓練迭代次數、訓練精度等,
Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir);
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "mlp");
System.out.println(model);
六、源代碼
1. pom
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.xundh</groupId>
<artifactId>djl-learning</artifactId>
<version>0.1-SNAPSHOT</version>
<properties>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<java.version>8</java.version>
<djl.version>0.13.0-SNAPSHOT</djl.version>
</properties>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>bom</artifactId>
<version>${djl.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
</dependency>
<!-- Pytorch -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-auto</artifactId>
<version>1.7.0</version>
</dependency>
</dependencies>
</project>
2. java
package com.xundh;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
public class NDArrayLearning {
public static void main(String[] args) throws IOException, TranslateException {
int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[]{128, 64}));
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging());
Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(1, 28 * 28));
int epoch = 2;
EasyTrain.fit(trainer, epoch, mnist, null);
Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir);
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "mlp");
System.out.println(model);
}
}
運行結果示例:

轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/301694.html
標籤:其他
上一篇:pmp專案管理程序組和知識領域
