本文適合有 Java 基礎的人群
作者:DJL-Lanking
HelloGitHub 推出的《講解開源項目》系列。有幸邀請到了亞馬遜 + Apache 的工程師:Lanking( https://github.com/lanking520 ),為我們講解 DJL —— 完全由 Java 構建的深度學習平台,本文為系列的第三篇。
一、前言
在 2018 年時,Google 推出了《猜畫小歌》應用:玩家可以直接與AI進行你畫我猜的游戲。通過畫出一個房子或者一個貓,AI 會推斷出各種物品被畫出的概率。它的實現得益於深度學習模型在其中的應用,通過深度神經網絡的歸納,曾經令人頭疼的繪畫識別也變得易如反掌。現如今,只要使用一個簡單的圖片分類模型,我們便可以輕松的實現繪畫識別。試試看這個在線塗鴉小游戲吧:
在當時,大部分機器學習計算任務仍舊需要依托網絡在雲端進行。隨着算力的不斷增進,機器學習任務已經可以直接在邊緣設備部署,包括各類運行安卓系統的智能手機。但是,由於安卓本身主要是用 Java ,部署基於 Python 的各類深度學習模型變成了一個難題。為了解決這個問題,AWS 開發並開源了 DeepJavaLibrary (DJL),一個為 Java 量身定制的深度學習框架。
在這個文章中,我們將嘗試通過 PyTorch 預訓練模型在在安卓平台構建一個塗鴉繪畫的應用。由於總代碼量會比較多,我們這次會挑重點把最關鍵的代碼完成。你可以后續參考我們完整的項目進行構建。
塗鴉應用完整代碼:https://github.com/aws-samples/djl-demo/tree/master/android
二、環境配置
為了兼容 DJL 需求的 Java 功能,這個項目需要 Android API 26 及以上的版本。你可以參考我們案例配置來節約一些時間,下面是這個項目需要的依賴項:
案例 gradle: https://github.com/aws-samples/djl-demo/blob/master/android/quickdraw_recognition/build.gradle
dependencies {
implementation 'androidx.appcompat:appcompat:1.2.0'
implementation 'ai.djl:api:0.7.0'
implementation 'ai.djl.android:core:0.7.0'
runtimeOnly 'ai.djl.pytorch:pytorch-engine:0.7.0'
runtimeOnly 'ai.djl.android:pytorch-native:0.7.0'
}
我們將使用 DJL 提供的 API 以及 PyTorch 包。
三、構建應用
3.1 第一步:創建 Layout
我們可以先創建一個 View class 以及 layout(如下圖)來構建安卓的前端顯示界面。
如上圖所示,你可以在主界面創建兩個 View
目標。PaintView
是用來讓用戶畫畫的,在右下角 ImageView
是用來展示用於深度學習推理的圖像。同時我們預留一個按鈕來進行畫板的清空操作。
3.2 第二步: 應對繪畫動作
在安卓設備上,你可以自定義安卓的觸摸事件響應來應對用戶的各種觸控操作。在我們的情況下,我們需要定義下面三種時間響應:
- touchStart:感應觸碰時觸發
- touchMove:當用戶在屏幕上移動手指時觸發
- touchUp:當用戶抬起手指時觸發
與此同時,我們用 paths 來存儲用戶在畫板所繪制的路徑。現在我們看一下實現代碼。
3.2.1 重寫 OnTouchEvent
和 OnDraw
方法
現在我們重寫 onTouchEvent
來應對各種響應:
@Override
public boolean onTouchEvent(MotionEvent event) {
float x = event.getX();
float y = event.getY();
switch (event.getAction()) {
case MotionEvent.ACTION_DOWN :
touchStart(x, y);
invalidate();
break;
case MotionEvent.ACTION_MOVE :
touchMove(x, y);
invalidate();
break;
case MotionEvent.ACTION_UP :
touchUp();
runInference();
invalidate();
break;
}
return true;
}
如上面代碼所示,你可以添加一個 runInference
方法在 MotionEvent.ACTION_UP
事件響應上。這個方法是用來在用戶繪制完后對結果進行推理。在之后的幾步中,我們會講解它的具體實現。
我們同樣需要重寫 onDraw
方法來展示用戶繪制的圖像:
@Override
protected void onDraw(Canvas canvas) {
canvas.save();
this.canvas.drawColor(DEFAULT_BG_COLOR);
for (Path path : paths) {
paint.setColor(DEFAULT_PAINT_COLOR);
paint.setStrokeWidth(BRUSH_SIZE);
this.canvas.drawPath(path, paint);
}
canvas.drawBitmap(bitmap, 0, 0, bitmapPaint);
canvas.restore();
}
真正的圖像會保存在一個 Bitmap
上。
3.2.2 操作開始(touchStart)
當用戶觸碰行為開始時,下面的代碼會建立一個新的路徑同時記錄路徑中每一個點在屏幕上的坐標。
private void touchStart(float x, float y) {
path = new Path();
paths.add(path);
path.reset();
path.moveTo(x, y);
this.x = x;
this.y = y;
}
3.2.3 手指移動(touchMove)
在手指移動中,我們會持續記錄坐標點然后將它們構成一個 quadratic bezier. 通過一定的誤差閥值來動態優化用戶的繪畫動作。只有差別超出誤差范圍內的動作才會被記錄下來。
quadratic bezier 文檔: https://developer.android.com/reference/android/graphics/Path
private void touchMove(float x, float y) {
if (x < 0 || x > getWidth() || y < 0 || y > getHeight()) {
return;
}
float dx = Math.abs(x - this.x);
float dy = Math.abs(y - this.y);
if (dx >= TOUCH_TOLERANCE || dy >= TOUCH_TOLERANCE) {
path.quadTo(this.x, this.y, (x + this.x) / 2, (y + this.y) / 2);
this.x = x;
this.y = y;
}
}
3.2.4 操作結束(touchUp)
當觸控操作結束后,下面的代碼會繪制一個路徑同時計算最小長方形目標框。
private void touchUp() {
path.lineTo(this.x, this.y);
maxBound.add(new Path(path));
}
3.3 第三步:開始推理
為了在安卓設備上進行推理任務,我們需要完成下面幾個任務:
- 從 URL 讀取模型
- 構建前處理和后處理過程
- 從 PaintView 進行推理任務
為了完成以下目標,我們嘗試構建一個 DoodleModel
class。在這一步,我們將介紹一些完成這些任務的關鍵步驟。
3.3.1 讀取模型
DJL 內建了一套模型管理系統。開發者可以自定義儲存模型的文件夾。
File dir = getFilesDir();
System.setProperty("DJL_CACHE_DIR", dir.getAbsolutePath());
通過更改 DJL_CACHE_DIR
屬性,模型會被存入相應路徑下。
下一步可以通過定義 Criteria 從指定 URL 處下載模型。下載的 zip 文件內包含:
doodle_mobilenet.pt
:PyTorch 模型synset.txt
:包含分類任務中所有類別的名稱
Criteria<Image, Classifications> criteria =
Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelUrls("https://djl-ai.s3.amazonaws.com/resources/demo/pytorch/doodle_mobilenet.zip")
.optTranslator(translator)
.build();
return ModelZoo.loadModel(criteria);
上述代碼同時定義了 translator,它會被用來做圖片的前處理和后處理。
最后,如下述代碼創建一個 Model
並用它創建一個 Predictor
:
@Override
protected Boolean doInBackground(Void... params) {
try {
model = DoodleModel.loadModel();
predictor = model.newPredictor();
return true;
} catch (IOException | ModelException e) {
Log.e("DoodleDraw", null, e);
}
return false;
}
更多關於模型加載的信息,請參閱如何加載模型。
DJL 模型加載文檔:http://docs.djl.ai/docs/load_model.html
3.3.2 用 Translator 定義前處理和后處理
在 DJL 中,我們定義了 Translator 接口進行前處理和后處理。在 DoodleModel 中我們定義了 ImageClassificationTranslator 來實現 Translator:
ImageClassificationTranslator.builder()
.addTransform(new ToTensor())
.optFlag(Image.Flag.GRAYSCALE)
.optApplySoftmax(true).build());
下面我們詳細闡述 translator 所定義的前處理和后處理如何被用在模型的推理步驟中。當你創建 translator 時,內部程序會自動加載 synset.txt
文件得到做分類任務時所有類別的名稱。當模型的 predict()
方法被調用時,內部程序會先執行所對應的 translator 的前處理步驟,而后執行實際推理步驟,最后執行 translator 的后處理步驟。對於前處理,我們會將 Image 轉化 NDArray,用於作為模型推理過程的輸入。對於后處理,我們對推理輸出的結果(NDArray)進行 softmax 操作。最終返回結果為 Classifications 的一個實例。
自定義 Translator 案例:http://docs.djl.ai/jupyter/pytorch/load_your_own_pytorch_bert.html
3.3.3 用 PaintView 進行推理任務
最后,我們來實現之前定義好的 runInference 方法。
public void runInference() {
// 拷貝圖像
Bitmap bmp = Bitmap.createBitmap(bitmap);
// 縮放圖像
bmp = Bitmap.createScaledBitmap(bmp, 64, 64, true);
// 執行推理任務
Classifications classifications = model.predict(bmp);
// 展示輸入的圖像
Bitmap present = Bitmap.createScaledBitmap(bmp, imageView.getWidth(), imageView.getHeight(), true);
imageView.setImageBitmap(present);
// 展示輸出的圖像
if (messageToast != null) {
messageToast.cancel();
}
messageToast = Toast.makeText(getContext(), classifications.toString(), Toast.LENGTH_SHORT);
messageToast.show();
}
這將會創建一個 Toast 彈出頁面用於展示結果,示例如下:
恭喜你!我們完成了一個塗鴉識別小程序!
3.4 可選優化:輸入裁剪
為了得到更高的模型推理准確度,你可以通過截取圖像來去除無意義的邊框部分。
上面右側的圖片會比左邊的圖片有更好的推理結果,因為它所包含的空白邊框更少。你可以通過 Bound 類來尋找圖片的有效邊界,即能把圖中所有白色像素點覆蓋的最小矩形。在得到 x 軸最左坐標,y 軸最上坐標,以及矩形高度和寬度后,就可以用這些信息截取出我們想要的圖形(如右圖所示)實現代碼如下:
RectF bound = maxBound.getBound();
int x = (int) bound.left;
int y = (int) bound.top;
int width = (int) Math.ceil(bound.width());
int height = (int) Math.ceil(bound.height());
// 截取部分圖像
Bitmap bmp = Bitmap.createBitmap(bitmap, x, y, width, height);
恭喜你!現在你就掌握了全部教程內容!期待看到你創建的第一個 DoodleDraw 安卓游戲!
最后,可以在GitHub找到本教程的完整案例代碼。
塗鴉應用完整代碼:https://github.com/aws-samples/djl-demo/tree/master/android
關於 DJL
Deep Java Library (DJL) 是一個基於 Java 的深度學習框架,同時支持訓練以及推理。 DJL 博取眾長,構建在多個深度學習框架之上 (TenserFlow、PyTorch、MXNet 等) 也同時具備多個框架的優良特性。你可以輕松使用 DJL 來進行訓練然后部署你的模型。
它同時擁有着強大的模型庫支持:只需一行便可以輕松讀取各種預訓練的模型。現在 DJL 的模型庫同時支持高達 70 個來自 GluonCV、 HuggingFace、TorchHub 以及 Keras 的模型。
關注 HelloGitHub 公眾號