之前參見了中國軟件杯大賽,在大賽中用到了深度學習的相關算法,也訓練了一些簡單的模型。項目線上平台是用java編寫的web應用程序,而深度學習使用的是python語言,這就涉及到了在java代碼中調用python語言的方法。
為了能在java應用中使用python語言訓練的算法模型,我在網上找了很久。我大概找到了三種方法
1. java代碼可以直接調用python代碼,只需要下載相應的jar包就行。這種方式我沒有嘗試,只是覺得這樣做使得java應用太過於依賴python的環境了。還有網上也有將python代碼打包成jar的方法,然后可以讓java代碼調用,但是很多第三方庫不能打包成jar包。
2. 將python訓練的模型參數保存到文本中,用java代碼重現模型的預測算法。我之前就這樣做過。這么做顯然工作量太大,而且出現的bug幾率大大增加。最重要的是很多深度學習的框架就沒辦法用了。
3. 使用python進程運行深度學習中訓練的模型,在java應用程序中調用python進程提供的服務。這種方法我認為是最好的。python語言寫得程序畢竟還是在python環境中執行最有效率。而且python應用和java應用可以運行在不同的服務器上,通過進程的遠程訪問調用。
以下是我實現java應用程序訪問python進程的python代碼部分。進程之間只能是通過socket進行通信。我本來想過用python編寫一個web應用,對java提供HTTP服務,后來覺得這樣還需要web服務器,對環境依賴太大,而且兩個進程間的通信也很簡單,所以干脆直接用socket進行調用得了
import socket import sys import threading import json import numpy as np from tag import train2 # nn=network.getNetWork() # cnn = conv.main(False) # 深度學習訓練的神經網絡,使用TensorFlow訓練的神經網絡模型,保存在文件中 nnservice = train2.NNService(model='model/20180731.ckpt-1000') def main(): # 創建服務器套接字 serversocket = socket.socket(socket.AF_INET,socket.SOCK_STREAM) # 獲取本地主機名稱 host = socket.gethostname() # 設置一個端口 port = 12345 # 將套接字與本地主機和端口綁定 serversocket.bind((host,port)) # 設置監聽最大連接數 serversocket.listen(5) # 獲取本地服務器的連接信息 myaddr = serversocket.getsockname() print("服務器地址:%s"%str(myaddr)) # 循環等待接受客戶端信息 while True: # 獲取一個客戶端連接 clientsocket,addr = serversocket.accept() print("連接地址:%s" % str(addr)) try: t = ServerThreading(clientsocket)#為每一個請求開啟一個處理線程 t.start() pass except Exception as identifier: print(identifier) pass pass serversocket.close() pass class ServerThreading(threading.Thread): # words = text2vec.load_lexicon() def __init__(self,clientsocket,recvsize=1024*1024,encoding="utf-8"): threading.Thread.__init__(self) self._socket = clientsocket self._recvsize = recvsize self._encoding = encoding pass def run(self): print("開啟線程.....") try: #接受數據 msg = '' while True: # 讀取recvsize個字節 rec = self._socket.recv(self._recvsize) # 解碼 msg += rec.decode(self._encoding) # 文本接受是否完畢,因為python socket不能自己判斷接收數據是否完畢, # 所以需要自定義協議標志數據接受完畢 if msg.strip().endswith('over'): msg=msg[:-4] break # 解析json格式的數據 re = json.loads(msg) # 調用神經網絡模型處理請求 res = nnservice.hand(re['content']) sendmsg = json.dumps(res) # 發送數據 self._socket.send(("%s"%sendmsg).encode(self._encoding)) pass except Exception as identifier: self._socket.send("500".encode(self._encoding)) print(identifier) pass finally: self._socket.close() print("任務結束.....") pass def __del__(self): pass if __name__ == "__main__": main()
在java代碼中訪問python進程的代碼:
private Object remoteCall(String content){ JSONObject jsonObject = new JSONObject(); jsonObject.put("content", content); String str = jsonObject.toJSONString(); // 訪問服務進程的套接字 Socket socket = null; List<Question> questions = new ArrayList<>(); log.info("調用遠程接口:host=>"+HOST+",port=>"+PORT); try { // 初始化套接字,設置訪問服務的主機和進程端口號,HOST是訪問python進程的主機名稱,可以是IP地址或者域名,PORT是python進程綁定的端口號 socket = new Socket(HOST,PORT); // 獲取輸出流對象 OutputStream os = socket.getOutputStream(); PrintStream out = new PrintStream(os); // 發送內容 out.print(str); // 告訴服務進程,內容發送完畢,可以開始處理 out.print("over"); // 獲取服務進程的輸入流 InputStream is = socket.getInputStream(); BufferedReader br = new BufferedReader(new InputStreamReader(is,"utf-8")); String tmp = null; StringBuilder sb = new StringBuilder(); // 讀取內容 while((tmp=br.readLine())!=null) sb.append(tmp).append('\n'); // 解析結果 JSONArray res = JSON.parseArray(sb.toString()); return res; } catch (IOException e) { e.printStackTrace(); } finally { try {if(socket!=null) socket.close();} catch (IOException e) {} log.info("遠程接口調用結束."); } return null; }