pytorch版本:1.6.0
pytorch-android版本:1.6.0
1 model.pt->model-script.pt
若模型上一次由GPU訓練得到,需要轉換成CPU形式
import torch
device = torch.device('cpu')
net=torch.load('model.pt', map_location = device)
torch.save(net,'model-cpu.pt')
然后把model.pt轉換成Pytorch-script,以便在安卓上運行
import torch
# 如果網絡使用的是class Net的定義法,而不是快速搭建法,需要在此處引入class Net的定義
model = torch.load("model-cpu.pt")
model.eval()
input_tensor = torch.rand(1,100) # 這里寫你的模型的輸入張量形狀
# 筆者的模型為輸入一維張量,100 features,后文的輸入尺寸也與此對應
script_model = torch.jit.trace(model,input_tensor)
script_model.save("model-script.pt")
2 Android Studio 配置
新建一個 C++ Native 項目,選擇c++11
在 build.gradle (Module) repositories 下,注釋掉 jcenter(),添加鏡像源:
maven{ url 'http://maven.aliyun.com/nexus/content/repositories/central/'}
maven{ url 'http://maven.aliyun.com/nexus/content/repositories/jcenter'}
在 build.gradle (App) dependencies 下,添加依賴:
implementation 'org.pytorch:pytorch_android:1.6.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.6.0'
添加后,Android Studio 提示同步,點擊 Sync Now,開始同步Gradle
在main目錄下新建assets目錄,把之前轉換好的model-script.pt放到該目錄下
3 調用模型
在 Activity 中,添加:
float[] data=new float[100];
//do something,為data賦值,可以是從文件加載、從用戶輸入、從相機獲取,等等
copyAssetAndWrite("model-script.pt");//把模型從assets寫入緩存,以便調用
Module module = Module.load(getCacheDir()+"/model-script.pt");//從緩存區加載模型
long shape[]={1,100};//模型輸入形狀
Tensor tensor=Tensor.fromBlob(data,shape);//tensor初始化方法
IValue input=IValue.from(tensor);
Tensor output=module.forward(input).toTensor();
float predict[]=output.getDataAsFloatArray();
這樣,predict[]就是從網絡得到的輸出了