生產與學術
寫於 2019-01-08 的舊文, 當時是針對一個比賽的探索. 覺得可能對其他人有用, 就放出來分享一下
生產與學術, 真實的對立...
這是我這兩天對pytorch深度學習->android實際使用
的這個流程的一個切身感受.
說句實在的, 對於模型轉換的探索, 算是我這兩天最大的收獲了...
全部濃縮在了這里: https://github.com/lartpang/DHSNet-PyTorch/blob/master/converter.ipynb
鑒於github加載ipynb太慢, 這里可以使用這個鏈接 https://nbviewer.jupyter.org/github/lartpang/DHSNet-PyTorch/blob/master/converter.ipynb
這兩天
最近在研究將pytorch的模型轉換為獨立的app, 網上尋找, 找到了一個流程: pytorch->onnx->caffe2->android apk. 主要是基於這篇文章的啟發: caffe2&pytorch之在移動端部署深度學習模型(全過程!).
這兩天就在折騰這個工具鏈,為了導出onnx的模型, 不確定要基於怎樣的網絡, 是已經訓練好的, 還是原始搭建網絡后再訓練來作為基礎. 所以不斷地翻閱pytorch和onnx的官方示例, 想要研究出來點什么, 可是, 都是自己手動搭建的模型. 而且使用的是預訓練權重, 不是這樣:
def squeezenet1_1(pretrained=False, **kwargs):
r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
<https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
than SqueezeNet 1.0, without sacrificing accuracy.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = SqueezeNet(version=1.1, **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
return model
# Get pretrained squeezenet model
torch_model = squeezenet1_1(True)
from torch.autograd import Variable
batch_size = 1 # just a random number
# Input to the model
x = Variable(torch.randn(batch_size, 3, 224, 224), requires_grad=True)
# Export the model
torch_out = torch.onnx._export(
torch_model, # model being run
x, # model input (or a tuple for multiple inputs)
"squeezenet.onnx", # where to save the model (can be a file or file-like object)
export_params=True) # store the trained parameter weights inside the model file
就是這樣:
# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)
# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1 # just a random number
# Initialize model with the pretrained weights
torch_model.load_state_dict(model_zoo.load_url(model_url))
# set the train mode to false since we will only run the forward pass.
torch_model.train(False)
兩種都在載入預訓練權重, 直接加載到搭建好的網絡上. 對於我手頭有的已經訓練好的模型, 似乎並不符合這樣的條件.
導出整體模型
最后采用盡可能模仿上面的例子代碼的策略, 將整個網絡完整的導出(torch.save(model)
), 然后再仿照上面那樣, 將完整的網絡加載(torch.load()
)到轉換的代碼中, 照貓畫虎, 以進一步處理.
這里也很大程度上受到這里的啟發: https://github.com/akirasosa/mobile-semantic-segmentation
本來想嘗試使用之前找到的不論效果還是性能都很強的R3Net進行轉換, 可是, 出於作者搭建網絡使用的特殊手段, 加上pickle和onnx的限制, 這個嘗試沒有奏效, 只好轉回頭使用之前學習的DHS-Net的代碼, 因為它的實現是基於VGG的, 里面的搭建的網絡也是需要修改來符合onnx的要求, 主要是更改上采樣操作為轉置卷積(也就是分數步長卷積, 這里順帶溫習了下pytorch里的nn.ConvTranspose2d()
的計算方式), 因為pytorch的上采樣在onnx轉換過程中有很多的問題, 特別麻煩, 外加上修改最大池化的一個參數(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False)
的參數ceil_mode
改為ceil_mode=False
, 這里參考自前面的知乎專欄的那篇文章), 這樣終於可以轉換了, 為了方便和快速的測試, 我只是訓練了一個epoch, 就直接導出模型, 這次終於可以順利的torch.save()
了.
filename_opti = ('%s/model-best.pth' % check_root_model)
torch.save(model, filename_opti)
之后便利用類似的代碼進行了書寫.
IMG_SIZE = 224
TMP_ONNX = 'cache/onnx/DHSNet.onnx'
MODEL_PATH = 'cache/opti/total-opti-current.pth'
# Convert to ONNX once
model = torch.load(MODEL_PATH).cuda()
model.train(False)
x = Variable(torch.randn(1, 3, 224, 224), requires_grad=True).cuda()
torch_out = torch.onnx._export(model, x, TMP_ONNX, export_params=True)
caffe2模型轉換
載入模型后, 便可以開始轉換了, 這里需要安裝caffe2, 官方推薦直接conda安裝pytorch1每夜版即可, 會自動安裝好依賴.
說起來這個conda, 就讓我又愛又恨, 用它裝pytorch從這里可以看出來, 確實不錯, 對系統自身的環境沒有太多的破壞, 可是用它裝tensorflow-gpu的時候, 卻是要自動把conda源里的cuda, cudnn工具包都給帶上, 有時候似乎會破壞掉系統自身裝載的cuda環境(? 不太肯定, 反正現在我不這樣裝, 直接上pip裝, 干凈又快速).
之后的代碼中, 主要的問題也就是tensor的cpu/cuda, 或者numpy的轉換的問題了. 多嘗試一下, 輸出下類型就可以看到了.
# Let's also save the init_net and predict_net to a file that we will later use for running them on mobile
with open('./cache/model_mobile/init_net.pb', "wb") as fopen:
fopen.write(init_net.SerializeToString())
with open('./cache/model_mobile/predict_net.pb', "wb") as fopen:
fopen.write(predict_net.SerializeToString())
預處理的補充
這里記錄下, 查看pytorch的tensor的形狀使用tensor.size()
方法, 查看numpy數組的形狀則使用numpy數組的adarray.shape
方法, 而對於PIL(from PIL import Image
)讀取的Image對象而言, 使用Image.size
查看, 而且, 這里只會顯示寬和高的長度, 而且Image的對象, 是三維, 在於pytorch的tensor轉換的時候, 或者輸入網絡的時候, 要注意添加維度, 而且要調整通道位置(img = img.transpose(2, 0, 1)
).
由於網絡保存的部分中, 只涉及到了網絡的結構內的部分, 對於數據的預處理的部分並不涉及, 所以說要想真正的利用網絡, 還得調整真實的輸入, 來作為更適合網絡的數據輸入.
要注意, 這里針對導出的模型的相關測試, 程實際上是按照測試網絡的流程來的.
# load the resized image and convert it to Ybr format
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
img = Image.open("./data/ILSVRC2012_test_00000004_224x224.jpg")
img = np.array(img)
img = img.astype(np.float64) / 255
img -= mean
img /= std
img = img.transpose(2, 0, 1)
安卓的嘗試
首先安卓環境的配置就折騰了好久, 一堆破事, 真實的生產開發, 真心不易啊...
這里最終還是失敗了, 因為對於安卓的代碼是在是不熟悉, 最起碼的基礎認知都不足, 只有這先前學習Java的一點皮毛知識, 根本不足以二次開發. 也就跑了跑幾個完整的demo而已.
AiCamera
這個跑通了, 但是這是個分類網絡的例子, 對於我們要做的分割的任務而言, 有很多細節不一樣.
- 輸入有差異: 比賽要求的是若是提交apk, 那么要求可以從相冊讀取圖片, 而例子是從攝像頭讀取的視頻數據流. 雖然也處理的是視頻幀, 但是要我們再次補充的內容又多了起來, 還是那句話, android一竅不通.
- 輸出有差異: 自我猜測, 比賽為了測評, 輸出必然也要輸出到相冊里, 不然何來測評一說?
AICamera-Style-Transfer
這個例子我們參考了一下, 只是因為它的任務是對攝像頭視頻流數據風格遷移, 而且會直接回顯到手機屏幕上, 這里我們主要是想初步實現對於我們網絡模型安卓遷移的測試, 在第一個例子的基礎上能否實現初步的攝像頭視頻流的分割, 然后下一步再進一步滿足比賽要求.
可是, 嘗試失敗了. 雖然AS打包成了APK, 手機也安裝上了, 可是莫名的, 在"loading..."中便閃退了...
JejuNet
這個例子很給力, 但是使用的是tensorflowlite, 雖然可以用, 能夠實現下面的效果, 可是, 不會改.
而且是量化網絡, 准確率還是有待提升.
最后的思考
最后還是要思考一下的, 做個總結.
沒經驗
吃就吃在沒經驗的虧上了, 都是初次接觸, 之前沒怎么接觸過安卓, 主要是安卓的開發對於電腦的配置要求太高了, 自己的筆記本根本不夠玩的. 也就沒有接觸過了.
外加上之前的研究學習, 主要是在學術的環境下搞得, 和實際的生產還有很大的距離, 科研與生產的分離, 這對於深度學習這一實際上更偏重實踐的領域來說, 有些時候是尤為致命的. 關鍵時刻下不去手, 這多么無奈, 科學技術無法轉化為實實在在的生產力, 忽然有些如夢一般的縹緲.
當然, 最關鍵的還是, 沒有仔細分析賽方的需求, 沒有完全思考清楚, 直接就開干了, 這個魯莽的毛病, 還是沒有改掉, 浪費時間不說, 也無助於實際的進度. 賽方的說明含糊, 應該問清楚.
若是擔心時間, 那更應該看清楚要求, 切莫隨意下手. 比賽說明里只是說要提交一個打包好的應用, 把環境, 依賴什么都處理好, 但是不一定是安卓apk呀, 可以有很多的形式, 但是這也只是最后的一點額外的輔助而已, 重點是模型的性能和效率呢.
莫忘初心, 方得始終. 為什么我想到的是這句.
下一步
基本上就定了還是使用R3Net, 只能是進一步的細節修改了, 換換后面的循環結構了, 改改連接什么的.
我准備再開始看論文, 學姐的論文可以看看, 似乎提出了一種很不錯的后處理的方法, 效果提升很明顯, 需要研究下.
pickle和onnx的限制
pytorch的torch.save(model)
保存模型的時候, 模型架構的代碼里不能使用一些特殊的構建形式, R3Net的ResNeXt結構就用了, 主要是一些lambda結構, 雖然不是太清楚, 但是一般的搭建手段都是可以的.
onnx對於pytorch的支持的操作, 在我的轉化中, 主要是最大池化和上采樣的問題, 前者可以修改ceil_mode
為False
, 后者則建議修改為轉置卷積, 避免不必要的麻煩. 可見"導出整體模型"小節的描述.
打包apk安裝
這里主要是用release版本構建的apk.
未簽名的apk在我的mi 8se (android 8.1)上不能安裝, 會解析失敗, 需要簽名, AS的簽名的生成也很簡單, 和生成apk在同一級上, 有生成的選項.