tensorflow源碼解析之framework-tensor


目錄

  1. 什么是tensor
  2. tensor繼承體系
  3. 與Eigen3庫的關系
  4. 什么是tensor_reference
  5. tensor_shape
  6. tensor_slice
  7. 其它結構
  8. 關系圖
  9. 涉及的文件
  10. 迭代記錄

1. 什么是tensor

TF全稱叫做TensorFlow,可見tensor的重要性。它本質上是一個對高維數據的封裝,提供了豐富的API。在線性代數中,我們常用向量、矩陣來表示數據,而在深度學習應用中,有對更高維數據的需求。比如在對圖像進行處理時,彩色圖像本身就帶有三維的信息(長、寬、顏色通道),通常還需要對彩色圖像進行批處理,這樣待處理的數據變為四維,在一些特殊的情形下,往往還需要更高維度的數據。如果針對每種多維數據定義一種結構,必然給計算帶來不便。TF的做法是,為高維數據定義統一的類型Tensor。

但高維數據的概念有點抽象,為了讓大家能對Tensor內部的數據結構有個直觀的印像,我們先看一下Tensor類的私有數據成員:

class Tensor {
    //...
  private:
    TensorShape shape_;
    TensorBuffer* buffer_;
}

這兩個結構都沒有見過,不過沒關系,只把它們當做張量的形狀和底層數據指針就好了。Tensor作為一個核心數據類,必然提供了很多API,比如常規的構造、析構、賦值、復制、數值屬性獲取等。除此之外,還提供了兩類比較特殊的接口,我們舉例說明:

class Tensor {
  public:
    //...
    //與proto數據的相互轉化
    bool FromProto(const TensorProto& other);
    void AsProtoField(TensorProto* proto);
    //為底層數據創建新視圖
    template <typename T> typename TTypes<T>::Vec vec();
    template <typename T> typename TTypes<T>::Matrix matrix();
    template <typename T> typename TTypes<T, NDIMS>::Tensor tensor();
}

其中第一類將Tensor與序列化的proto之間相互轉化,便於在設備之間傳遞Tensor。第二類是為當前的Tensor的底層數據提供另外一種視圖,我們重點來說一下視圖的概念。

回顧Tensor包含的私有數據,TensorBuffer* buffer_是一個指向底層數據的指針,關於它的結構在下文中會詳細說明。這意味着,Tensor並不包含實際的底層數據,它實際上只是對底層數據的一種視圖。同樣一份底層數據,可以提供多種視圖。比如對於一個長度為12的數組,如果把它看做向量,它是一個1x12的向量,如果把它看作矩陣,可以認為是3x4或者2x6的矩陣,如果把它當作張量,可以認為是3x2x2的張量。通過這種方法,我們可以對同一份底層數據進行復用,避免了重復申請內存空間,提升了效率。

graph TB A("Tensor A, shape=[3,4]")-->D(底層數據TensorBuffer) B("Tensor B, shape=[2,6]")-->D(底層數據TensorBuffer) C("Tensor C, shape=[3,2,2]")-->D(底層數據TensorBuffer)

順便提一句,numpy中對多維數組的實現,也是同樣的原理。

細心的讀者可能發現了,在對底層數據創建新視圖時,返回了一種奇怪的數據類型typename TTypes<T>::Vec,這涉及TF中的Tensor與Eigen3庫的關系,我們將在下文中詳細說明。

2. tensor繼承體系

接下來我們看一下TensorBuffer到底是什么樣的結構。它只是一個繼承自引用計數類的虛擬接口,不包含任何實現:

class TensorBuffer : public core::RefCounted {
    //...
}

因此懷疑,TensorBuffer只是一個提供接口的基類,實際上能用的只是它的子類。我們看下它的繼承結構:

class BufferBase : public TensorBuffer {
    //...
}
class Buffer : public BufferBase {
    //...
  private:
    T* data_;
    int64 elem_;
}

結構已經非常清晰了,BufferBase類繼承自TensorBuffer,它除了包含一個內存分配器指針外,還對基類中的部分API進行了實現。而Buffer類是實際可用的,它包含了指向實際數據的指針data_以及元素數量elem_。

另外還要說明一點,Buffer除了申請內存之外,還能調用目標類的構造和析構函數,初始化Buffer的內容,TF為此設計了很多輔助類和函數,這里就不一一贅述了。

Tensor的繼承體系圖如下:

graph TB A(core::RefCounted)-->|派生|B(TensorBuffer) B(TensorBuffer)-->|派生|C(BufferBase) C(BufferBase)-->|派生|D(Buffer) C(BufferBase)-.包含.->E(Allocator* alloc_) D(Buffer)-.包含.->F(T* data_) D(Buffer)-.包含.->G(int64 elem_) H(Tensor)-.包含.->B(TensorBuffer)

3. 與Eigen3庫的關系

剛才提到了,當為Tensor的數據提供不同視圖的時候,返回了一種奇怪的數據TTypes<T>::Vec,這種數據為TF中的Tensor和Eigen3庫中的Tensor建立了聯系。我們在tensor_types.h文件中,找到了這種類型的定義:

struct TTypes {
    typedef Eigen::TensorMap<Eigen::Tensor<T,NDIMS,Eigen::RowMajor,IndexType>,Eigen::Aligned> Tensor;
    typedef Eigen::TensorMap<Eigen::Tensor<T,1,Eigen::RowMajor,IndexType>,Eigen::Aligned> Vec;
    //...
}

原來,對Eigen3庫中Tensor的使用在這里。由於這種定義被包裹在TTypes結構體中,所以不會與外部TF自定義的Tensor造成沖突。

重新回到Tensor的定義,我們發現,原來在對Tensor底層數據提供多種視圖的時候,返回的已經不是Tensor結構,而是TTypes::TensorMap,這是否意味着,TF中定義的Tensor只是對Eigen::Tensor的一種封裝呢?我們追根溯源,找到vec函數的實現:

template <typename T>
typename TTypes<T>::Vec vec() {
    return tensor<T,1>();
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
    CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
    return typename TTypes<T, NDIMS>::Tensor(base<T>(), shape().AsEigenDSizes<NDIMS>());
}

跟我們預想的完全一樣,在對vec函數的調用中,調用了tensor函數,而這個函數的作用,就是將TF中定義的Tensor轉變為TTypes::Tensor,而后者就是Eigen::TensorMap,也就是說,tensor返回的本質上是一個Eigen::TensorMap。另外,我們知道base()和shape()兩個函數,分別返回了TensorBuffer指針和TensorShape,因此實際上就是使用TF中Tensor存儲的數據,作為了Eigen::TensorMap的構造函數的參數。

可以說,TF中的Tensor實際上是對Eigen::TensorMap的一種高級封裝,它不是簡單的在私有數據成員包含后者,而是包含了構造后者所需要的數據,在需要后者的時候,構造並返回。這種方式,使得TF中的Tensor既能利用Eigen高效的張量計算方法,也能為Tensor定制一些API。

4. 什么是tensor_reference

Tensor類的對象除了包含指向底層數據的指針外,還包含了對數據形狀和類型的描述(通過TensorShape),如果我們並不關心這些,直接使用Tensor會增加構建或者移動的負擔。因此TF推出了tensor_reference這個類,它僅包含了一個指向TensorBuffer的指針,並且每增加一個TensorReference對象,就會增加一個針對底層TensorBuffer的引用計數。因此針對TensorReference來說,我們唯一能做的就是在用完之后Unref掉,否則會造成內存泄漏。

class TensorReference {
  public:
    //...
  private:
    TensorBuffer* buf_;
}

5. tensor_shape

TensorShape顯然包含的是張量形狀相關的信息,但其實不僅如此,它還包含了對張量數據類型的描述。TensorShape相關的核心類繼承體系如下:

graph LR I(TensorShapeRep)-->|派生|J(TensorShapeBase) J(TensorShapeBase)-->|派生|K(TensorShape) J(TensorShapeBase)-->|派生|L(PartialTensorShape)

首先來看一下,最底層的TensorShapeRep的私有數據成員:

class TensorShapeRep {
    //...
  private:
    union {
        uint8 buf[16];
        Rep64* unused_aligner;//除了強制u_與指針對齊外,沒有任何作用
    } u_;
    int64 num_elements_;
}

buf這個數組很有意思,它的前12個元素用來存儲形狀,雖然Tensor最高能支持到256維的張量,但最常用的不超過3維,為了效率,TF提供了三種利用這12個字節的方式,如下:

struct Rep16 {
    uint16 dims_[6];//最多可表示6維的張量,每一維的長度不超過2^16-1
};
struct Rep32 {
    uint32 dims_[3];//最多可表示3維的張量,每一維的長度不超過2^32-1
};
struct Rep64 {
    gtl::InlinedVector<int64, 4>* dims_;//支持任意維度的張量
};

剩下的4個字節也不能浪費,在第14-16個字節中,分別存儲了張量中的數據類型編號、張量的維度數目、張量維度的表示類型(Rep16, Rep32, Rep64)。由於張量維度的數目是用一個字節存儲的,因此最多支持256維。可惜筆者目前仍沒有發現第13個字節的作用,有發現的讀者歡迎告知我。

TensorShapeBase類並沒有添加額外的數據成員,它只是添加了一些允許我們修改張量維度的API接口。而TensorShape類也只是添加了一些對形狀進行檢查和比較的接口,沒有新增數據成員。

最后再來看下PartialTensorShape類,在構造一個張量的形狀時,如果對於某些維度我們還不知道具體的維度值,可以把這個維度設為未知,因此就會用到PartialTensorShape類,這個類中也包含了一些未知維度操作的API,這里就不詳述了。

6. tensor_slice

TensorSlice類表示一個張量的索引,它的數據結構非常簡單:

class TensorSlice {
    //...
  private:
    gtl::InlinedVector<int64,4> starts_;
    gtl::InlinedVector<int64,4> lengths_;
}

分別是每一個維度索引的開始位置和索引長度,由此我們也知道,TF對Tensor只支持連續索引,不支持間隔索引。
由於TensorSlice用途廣泛,對其進行初始化的方法也多種多樣,包括:

  • 創建空索引
  • 從單個維度創建(當創建全索引時)
  • 從一個整數對數組創建
  • 從一個TensorSliceProto創建
  • 從一個字符串描述中創建

7. 其它結構

為了方便對張量和與之相關的數據結構進行序列化,TF設計了很多protos,理解起來相對簡單,現只說明下它們的用途,感興趣的讀者可以去看源代碼。

message TensorDescription;//張量的描述,包括數據類型、形狀、內存分配信息
message TensorProto;//張量的數據類型,版本,原始數據等
message VariantTensorDataProto;//對DT_VARIANT類型的序列化表示
message TensorShapeProto;//張量形狀
message TensorSliceProto;//張量索引

8. 關系圖

graph TB A(core::RefCounted)-->|派生|B(TensorBuffer) B(TensorBuffer)-->|派生|C(BufferBase) C(BufferBase)-->|派生|D(Buffer) C(BufferBase)-.包含.->E(Allocator* alloc_) D(Buffer)-.包含.->F(T* data_) D(Buffer)-.包含.->G(int64 elem_) H(Tensor)-.包含.->B(TensorBuffer) I(TensorShapeRep)-->|派生|J(TensorShapeBase) J(TensorShapeBase)-->|派生|K(TensorShape) J(TensorShapeBase)-->|派生|L(PartialTensorShape) H(Tensor)-.索引結構.->M(TensorSlice) H(Tensor)-.形狀和數據類型描述.->K(TensorShape) H(Tensor)-.轉換為.->N(TTypes::Tensor) N(TTypes::Tensor)-.轉換為.->O(Eigen::TensorMap)

9. 涉及的文件

  • tensor
  • tensor_reference
  • tensor_types
  • tensor_shape
  • tensor_slice
  • tensor_description

10. 迭代記錄

  • v1.0 2018-08-26 文檔創建
  • v2.0 2018-09-09 文檔重構


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM