tensorflow學習筆記2:c++程序靜態鏈接tensorflow庫加載模型文件


首先需要搞定tensorflow c++庫,搜了一遍沒有找到現成的包,於是下載tensorflow的源碼開始編譯;

tensorflow的contrib中有一個makefile項目,極大的簡化的接下來的工作;

按照tensorflow makefile的說明文檔,開始做c++庫的編譯:

 

1. 下載依賴

在tensorflow的項目頂層運行:

tensorflow/contrib/makefile/download_dependencies.sh

東西會下載到tensorflow/contrib/makefile/downloads/目錄里;

 

2. 在linux下進行編譯

首先確保編譯工具都已經裝好了:

sudo apt-get install autoconf automake libtool curl make g++ unzip zlib1g-dev git python

然后運行編譯腳本;

注意:運行之前打開看一眼,第一步竟然是把tensorflow/contrib/makefile/downloads/目錄里的東西清空然后重新下載。。。注掉注掉

tensorflow/contrib/makefile/build_all_linux.sh

然后在tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a就看到靜態庫了;

 

3. 准備好加載模型的c++代碼

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"

using namespace tensorflow;

int main(int argc, char* argv[]) {
  // Initialize a tensorflow session
  Session* session;
  Status status = NewSession(SessionOptions(), &session);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Read in the protobuf graph we exported
  // (The path seems to be relative to the cwd. Keep this in mind
  // when using `bazel run` since the cwd isn't where you call
  // `bazel run` but from inside a temp folder.)
  GraphDef graph_def;
  status = ReadBinaryProto(Env::Default(), "models/test_graph.pb", &graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Add the graph to the session
  status = session->Create(graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Setup inputs and outputs:

  // Our graph doesn't require any inputs, since it specifies default values,
  // but we'll change an input to demonstrate.
  Tensor a(DT_FLOAT, TensorShape());
  a.scalar<float>()() = 3.0;

  Tensor b(DT_FLOAT, TensorShape());
  b.scalar<float>()() = 2.0;

  Tensor x(DT_FLOAT,TensorShape());
  x.scalar<float>()() = 10.0;

  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
    { "a", a },
    { "b", b },
    { "x", x },
  };

  // The session will initialize the outputs
  std::vector<tensorflow::Tensor> outputs;

  // Run the session, evaluating our "y" operation from the graph
  status = session->Run(inputs, {"y"}, {}, &outputs);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

 // Grab the first output (we only evaluated one graph node: "c")
  // and convert the node to a scalar representation.
  auto output_y = outputs[0].scalar<float>();

  // (There are similar methods for vectors and matrices here:
  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)

  // Print the results
  std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 32>
  std::cout << output_y() << "\n"; // 32

  // Free any resources used by the session
  session->Close();
  return 0;
}

保存成load_graph.cc;

 

寫Makefile:

TARGET_NAME := load_graph

TENSORFLOW_MAKEFILE_DIR := /mnt/data/tensorflow/tensorflow/contrib/makefile

INCLUDES := \
-I /usr/local/lib/python3.6/dist-packages/tensorflow/include

NSYNC_LIB := \
$(TENSORFLOW_MAKEFILE_DIR)/downloads/nsync/builds/default.linux.c++11/nsync.a

PROTOBUF_LIB := \
$(TENSORFLOW_MAKEFILE_DIR)/gen/protobuf/lib/libprotobuf.a

TENSORFLOW_CORE_LIB := \
-Wl,--whole-archive $(TENSORFLOW_MAKEFILE_DIR)/gen/lib/libtensorflow-core.a -Wl,--no-whole-archive

LIBS := \
$(TENSORFLOW_CORE_LIB) \
$(NSYNC_LIB) \
$(PROTOBUF_LIB) \
-lpthread \
-ldl

SOURCES := \
load_graph.cc

$(TARGET_NAME):
	g++ -std=c++11 $(SOURCES) $(INCLUDES) -o $(TARGET_NAME) $(LIBS)

clean:
	rm $(TARGET_NAME)

這里的tensorflow-core、nsync和protobuf全都用靜態鏈接了,這些靜態庫以后考慮都放一份到系統目錄下;

 

有幾個點需要注意:

1) INCLUDE使用了python3.6的帶的tensorflow頭文件,只是覺得反正python都已經帶頭文件了,就不需要再另外拷一份頭文件進系統目錄了;

2) nsync庫是多平台的,因而可能需要仔細分析一下nsync的編譯結果所在位置,尤其如果是交叉編譯的話;

3) 鏈接順序不能錯,tensorflow-core肯定要在其它兩個前面;

4) tensorflow_core庫需要全鏈接進來,否則會出現這個錯:tensorflow/core/common_runtime/session.cc:69] Not found: No session factory registered for the given session options: {target: "" config: } Registered factories are {}.

    想想也大概能知道為什么,肯定是在靜態代碼層面只依賴父類,然后再在運行時通過名字找子類,所以在符號層面是不直接依賴子類的,不強制whole-archive的話,子類一個都帶不進來;

 

4. 運行程序

運行前先看看事先准備好的graph在不在預定位置,生成graph的方法見上一篇;

運行一下,沒啥好說的,結果正確。

 

參考:

http://blog.163.com/wujiaxing009@126/blog/static/7198839920174125748893/

https://blog.csdn.net/xinchen1234/article/details/78750079


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM