tensorflow添加新操作(Op)


想要為點雲加一個尋找k近鄰的操作,好像只能通過寫新Op實現,看了半天博客半懂不懂的,改改試試(對A-CNN里的ordering操作)

為了加入一個定制操作,你需要:

  1. c++ 文件中注冊一個新opOp registration 定義了 op 的功能接口,它和 op 的實現是獨立的。例如:op registration 定義了 op 的名字和 op的輸出輸出。它同時也定義了 shape 方法,被用於 tensorshape 接口。
  2. c++ 中實現 opop 的實現稱之為 kernel(cuda核函數) ,它是op 的一個具體實現。對於不同的輸入輸出類型或者 架構(CPUs,GPUs)可以有不同的 kernel 實現 。
  3. 創建一個 python wrapper(可選的): 這個 wrapper 是一個 公開的 API,用來在 python中創建 opop registration 會生成一個默認的 wrapper,我們可以直接使用或者自己添加一個。
  4. 寫一個計算 op 梯度的方法(可選)。
  5. 測試 op:為了方便,我們通常在 python 中測試 op,但是你也可以在 c++ 中進行測試。如果你定義了 gradients,你可以 通過 Pythongradient checker 驗證他們。 這里有個例子relu_op_test.py ,測試 ReLU-likeop 的 前向和梯度過程。

1. 定義接口(.cpp里):

在注冊 op 的時候,你需要指定:

  • op 的名字
  • op 的輸入(名字,類型),op 的輸出(名字,類型)
  • docstrings
  • op 可能需要的 一些 attrs
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include <cuda_runtime.h>

using namespace tensorflow;

REGISTER_OP("RingPoint")
    .Attr("radius_in: float")
    .Attr("radius_out: float")
    .Attr("nsample: int")
    .Input("xyz1: float32")
    .Input("xyz2: float32")
    .Input("idx2: int32")
    .Input("kernel: float32")
    .Output("idx: int32")
    .Output("pts_cnt: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
        ::tensorflow::shape_inference::ShapeHandle dims2;
        c->WithRank(c->input(1), 3, &dims2);       //把輸入維度取出來放到dim2里吧
        int nsample;
        TF_RETURN_IF_ERROR(c->GetAttr("nsample", &nsample));
        ::tensorflow::shape_inference::ShapeHandle output1 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), nsample}); //和輸入維度不一樣就要自己設一下
       c->set_output(0, output1);                 //然后把自己設的輸出維度和第0個輸出對應
        ::tensorflow::shape_inference::ShapeHandle output2 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1)});
     //c->set_output(0, c->input(0));           //輸出維度和輸入維度一樣就用輸入的維度
        c->set_output(1, output2);         
        return Status::OK();
    });
View Code

 

關於命名的備注:操作名稱必須首字母大寫,而且不能和庫中已經注冊的其它操作重名。

2. 實現操作的內核(.cpp中)

定義接口后,接下來就需要為此操作提供一個或多個內核實現了。
為了實現這些內核,創建一個繼承自 OpKernel 的類,並重載 Compute 方法。
Compute 方法有一個類型為 OpKernelContext* 的參數 context,從中可以訪問輸入和輸出張量等有用的信息。

  • CPU版本:
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // 得到輸入張量
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // 創建輸出張量
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // 除第一個元素外,輸出張量的其它所有元素都設置為 0 
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // 如果可能的話,保留第一個輸入值
    if (N > 0) output_flat(0) = input(0);
  }
};
View Code

 

ZeroOut 操作加上約束條件:

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

這里注冊的操作名是ZeroOut,通過上面的語句和ZeroOutOp對應吧

  • GPU版本:重寫的compute里一般都只是數據的處理:大小推斷、flat成一維、分配內存、設初值等等,具體的計算用cuda的核函數實現
//先聲明
void ringPointLauncher(int b, int n, int m, float radius_in, float radius_out, int nsample, const float *xyz1, const float *xyz2, const int * idx2, int *idx, int *pts_cnt);
//再繼承OpKernel,OP_REQUIRES是對輸入做一些限定,OP_REQUIRES_OK還不清楚、
class RingPointGpuOp : public OpKernel {
    public:
        //這里是構造函數吧
        explicit RingPointGpuOp(OpKernelConstruction* context) : OpKernel(context) {
            OP_REQUIRES_OK(context, context->GetAttr("radius_in", &radius_in_));
            OP_REQUIRES(context, radius_in_ >= 0, errors::InvalidArgument("RingPoint expects positive inner radius"));

            OP_REQUIRES_OK(context, context->GetAttr("radius_out", &radius_out_));
            OP_REQUIRES(context, radius_out_ > 0, errors::InvalidArgument("RingPoint expects positive outter radius"));

            OP_REQUIRES_OK(context, context->GetAttr("nsample", &nsample_));
            OP_REQUIRES(context, nsample_ > 0, errors::InvalidArgument("RingPoint expects positive nsample"));
        }
        //重寫Compute方法
        void Compute(OpKernelContext* context) override {
            //獲取第0個輸入,輸入要獲取
            const Tensor& xyz1_tensor = context->input(0);
            OP_REQUIRES(context, xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3, errors::InvalidArgument("RingPoint expects (batch_size, ndataset, 3) xyz1 shape."));
       //獲得一些參數作為后續所調用的核函數(ringPointLauncher)的輸入
            int b = xyz1_tensor.shape().dim_size(0);
            int n = xyz1_tensor.shape().dim_size(1);

            const Tensor& xyz2_tensor = context->input(1);
            OP_REQUIRES(context, xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3, errors::InvalidArgument("RingPoint expects (batch_size, npoint, 3) xyz2 shape."));
            int m = xyz2_tensor.shape().dim_size(1);

            const Tensor& idx2_tensor = context->input(2);
            OP_REQUIRES(context, idx2_tensor.dims()==2, errors::InvalidArgument("RingPoint expects (batch_size, npoint) idx2 shape."));

       //給輸出分配內存,輸出要分配內存
            Tensor *idx_tensor = nullptr;
            OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m,nsample_}, &idx_tensor));
            Tensor *pts_cnt_tensor = nullptr;
            OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m}, &pts_cnt_tensor));

            //要傳給后續的函數,所以輸入先flat(方便取值),再賦給一個指針,最后傳給核函數
       auto xyz1_flat = xyz1_tensor.flat<float>();
            const float *xyz1 = &(xyz1_flat(0));
            auto xyz2_flat = xyz2_tensor.flat<float>();
            const float *xyz2 = &(xyz2_flat(0));
            auto idx2_flat = idx2_tensor.flat<int>();
            const int *idx2 = &(idx2_flat(0));
            auto idx_flat = idx_tensor->flat<int>();
            int *idx = &(idx_flat(0));
       //再flat和賦給指針
            auto pts_cnt_flat = pts_cnt_tensor->flat<int>();
            int *pts_cnt = &(pts_cnt_flat(0));
       //auto angles_flat = angles_tensor->flat<float>();
       //float *angles = &(angles_flat(0));
       //cudaMemset(angles, 0.0, sizeof(float)*b*m_q*k); //要給輸出設初值用cudaMenset,前面的一樣,先分配內存,再賦值給指針
            ringPointLauncher(b,n,m,radius_in_,radius_out_,nsample_,xyz1,xyz2,idx2,idx,pts_cnt);
        }
    private:
        float radius_in_;
        float radius_out_;
        int nsample_;
};
//指定設備
REGISTER_KERNEL_BUILDER(Name("RingPoint").Device(DEVICE_GPU), RingPointGpuOp);
View Code
  • cuda的核函數(具體實現部分,.cu文件里)
 
        
#include <cstdio>
#include <float.h>

// input: points (b,n,c), idx (b,m,nsample)
// output: out (b,m,nsample,c)
//__global__: 聲明在device(GPU)上執行,
//__ device__:在device上執行,單僅可以從device中調用,不可以和__global__同時用。就是在__global__函數里調用
__global__ void group_point_gpu(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out) {
    int batch_index = blockIdx.x;//16個block,每個block處理一個batch
   //取數據先加上自己的大小,定位到對應block?
    points += n*c*batch_index;
    idx += m*nsample*batch_index;
    out += m*nsample*c*batch_index;

    int index = threadIdx.x;
    int stride = blockDim.x; //srtide是整的grid的線程數,好像是為了並行啥的,這里gridDim就是1吧

    for (int j=index;j<m;j+=stride) {
        for (int k=0;k<nsample;++k) {
            int ii = idx[j*nsample+k];
            for (int l=0;l<c;++l) {
                out[j*nsample*c+k*c+l] = points[ii*c+l];
            }
        }
    }
}
 

void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out){
   //kernel_fun<<< grid, block >>>(prams...); 來指定kernel要執行的線程數量
    group_point_gpu<<<b,256>>>(b,n,c,m,nsample,points,idx,out);
    //cudaDeviceSynchronize(); //同步device 保證結果能正確訪問
}
View Code
kernel是在device上線程中並行執行的函數,核函數用__global__符號聲明,在調用時需要用<<<grid, block>>>來指定kernel要執行的線程數量,在CUDA中,每一個線程都要執行核函數,並且每
個線程會分配一個唯一的線程號thread ID,這個ID值可以通過核函數的內置變量threadIdx來獲得。

kernel在device上執行時實際上是啟動很多線程,一個kernel所啟動的所有線程稱為一個網格(grid),同一個網格上的線程共享相同的全局內存空間,grid是線程結構的第一層次,而網格又可以
分為很多線程塊(block),一個線程塊里面包含很多線程,這是第二個層次。
  • 核函數只是分配不同的線程處理不同的數據,每個線程都會執行核函數,線程全局ID不同,通過ID值對數組進行索引,於是就能處理數組里不同的數了

一個線程需要兩個內置的坐標變量(blockIdx,threadIdx)來唯一標識,它們都是dim3類型變量,其中blockIdx指明線程所在grid中的位置,而threaIdx指明線程所在block中的位置。找到全局ID之后通過ID值對數組進行索引,處理數組的元素。

 

__global__ void order_neighbors_gpu(int b, int m, int n, int m_q, int k,int num_k, const float *input, const float *queries, const float *queries_norm, const int *idx,const float *kernel, float *proj, int *outi, float *angles,float *kernel_out){
  int batch_index = blockIdx.x;  //這個block里的線程就處理這一個batch的數據吧
  queries+=m_q*n*batch_index; //(512*3*16)//數組索引吧,blockIdx.x是0-15,定位到了某個batch的開頭,再取這個batch里的數據
  queries_norm+=m_q*n*batch_index;
  idx+=m_q*k*batch_index; //(16,512,16)
  angles+=m_q*k*batch_index;
  outi+=m_q*k*batch_index;
  input+=m*n*batch_index;
  proj+=m_q*k*n*batch_index;

  int index = threadIdx.x;
  int stride = blockDim.x; //block里含有的線程數,數據數超過線程數就用同一個線程處理

  // copy indecies from idx to outi
  for (int i=index; i<m_q; i+=stride) //把m_q個數據划分給stride個線程處理
      for (int j=0; j<k; ++j)
          outi[i*k + j] = idx[i*k + j];
...
}
View Code

數組的索引還是0-N,傳入的時候是flat過的,長度是所有元素個數。在元素數超過線程數的時候一個線程處理多個數據,用gride-stride loop方法,只處理一個數的話循環是不執行的。

 完整示例:

Attr:一個數,Input:輸入數組Output:輸出數組

//tf_ordering.cpp
REGISTER_OP("OrderNeighbors") .Attr("k: int") .Attr("sita: float") .Input("input_xyz: float32") .Input("query_xyz: float32") .Input("query_normals: float32") .Input("idx: int32") .Input("kernel: float32") .Output("outi: int32") .Output("kernel_fit: float32") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c){ int k; TF_RETURN_IF_ERROR(c->GetAttr("k", &k)); float sita; TF_RETURN_IF_ERROR(c->GetAttr("sita", &sita)); ::tensorflow::shape_inference::ShapeHandle dims2; c->WithRank(c->input(3), 3, &dims2); c->set_output(0, c->input(3)); ::tensorflow::shape_inference::ShapeHandle dims3; c->WithRank(c->input(4), 2, &dims3); ::tensorflow::shape_inference::ShapeHandle output1 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), c->Dim(dims3, 0),4}); c->set_output(1, output1); return Status::OK(); }); void orderNeighborsLauncher(int b, int m, int n, int m_q, int k,int num_k,float sita, const float *input, const float *queries, const float *queries_norm, const int *idx, const float *kernel, float *proj, int *outi,float *kernel_fit); class OrderNeighborsGpuOp : public OpKernel { public: explicit OrderNeighborsGpuOp(OpKernelConstruction * context):OpKernel(context){ OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); OP_REQUIRES(context, k_ > 0, errors::InvalidArgument("OrderNeighbors expects positive k")); OP_REQUIRES_OK(context, context->GetAttr("sita", &sita_)); OP_REQUIRES(context, sita_ > 0, errors::InvalidArgument("OrderNeighbors expects positive sita")); } void Compute(OpKernelContext* context) override { const Tensor& input_xyz_tensor = context->input(0); OP_REQUIRES(context, input_xyz_tensor.dims() == 3, errors::InvalidArgument("OrderNeighbors expects (b,m,n) input_xyz shape")); int b = input_xyz_tensor.shape().dim_size(0); int m = input_xyz_tensor.shape().dim_size(1); int n = input_xyz_tensor.shape().dim_size(2);//n是特征維度? //3 const Tensor& query_xyz_tensor = context->input(1); OP_REQUIRES(context, query_xyz_tensor.dims() == 3, errors::InvalidArgument("OrderNeighbors expects (b,m_q,n) query_xyz shape")); int m_q = query_xyz_tensor.shape().dim_size(1); const Tensor& query_normals_tensor = context->input(2); OP_REQUIRES(context, query_normals_tensor.dims() == 3, errors::InvalidArgument("OrderNeighbors expects (b,m_q,n) query_normals shape")); const Tensor& idx_tensor = context->input(3); OP_REQUIRES(context, idx_tensor.dims() == 3, errors::InvalidArgument("OrderNeighbors expects (b,m_q,k) idx shape")); int k = idx_tensor.shape().dim_size(2); //我加的 const Tensor& kernel_tensor = context->input(4); OP_REQUIRES(context, kernel_tensor.dims() == 2, errors::InvalidArgument("OrderNeighbors expects (num_k,2) kernel shape")); int num_k=kernel_tensor.shape().dim_size(0); Tensor *outi_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m_q,k}, &outi_tensor)); Tensor *kernel_fit_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m_q,num_k,4}, &kernel_fit_tensor)); auto input_flat = input_xyz_tensor.flat<float>(); const float *input = &(input_flat(0)); auto queries_flat = query_xyz_tensor.flat<float>(); const float *queries = &(queries_flat(0)); auto queries_norm_flat = query_normals_tensor.flat<float>(); const float *queries_norm = &(queries_norm_flat(0)); auto idx_flat = idx_tensor.flat<int>(); const int *idx = &(idx_flat(0)); auto outi_flat = outi_tensor->flat<int>(); int *outi = &(outi_flat(0));//我加的 auto kernel_flat = kernel_tensor.flat<float>(); const float *kernel = &(kernel_flat(0)); auto kernel_fit_flat = kernel_fit_tensor->flat<float>(); float *kernel_fit = &(kernel_fit_flat(0)); cudaMemset(kernel_fit, 0.0, sizeof(float)*b*m_q*num_k*4); orderNeighborsLauncher(b, m, n, m_q, k,num_k,sita_, input, queries, queries_norm, idx, kernel, proj, outi,kernel_fit); } private: int k_; float sita_; }; REGISTER_KERNEL_BUILDER(Name("OrderNeighbors").Device(DEVICE_GPU), OrderNeighborsGpuOp);

 

參考:tensorflow新op官方教程

https://zhuanlan.zhihu.com/p/34168765

tensorflow新op實例 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM