tensorflow添加新操作(Op)


想要为点云加一个寻找k近邻的操作,好像只能通过写新Op实现,看了半天博客半懂不懂的,改改试试(对A-CNN里的ordering操作)

为了加入一个定制操作,你需要:

  1. c++ 文件中注册一个新opOp registration 定义了 op 的功能接口,它和 op 的实现是独立的。例如:op registration 定义了 op 的名字和 op的输出输出。它同时也定义了 shape 方法,被用于 tensorshape 接口。
  2. c++ 中实现 opop 的实现称之为 kernel(cuda核函数) ,它是op 的一个具体实现。对于不同的输入输出类型或者 架构(CPUs,GPUs)可以有不同的 kernel 实现 。
  3. 创建一个 python wrapper(可选的): 这个 wrapper 是一个 公开的 API,用来在 python中创建 opop registration 会生成一个默认的 wrapper,我们可以直接使用或者自己添加一个。
  4. 写一个计算 op 梯度的方法(可选)。
  5. 测试 op:为了方便,我们通常在 python 中测试 op,但是你也可以在 c++ 中进行测试。如果你定义了 gradients,你可以 通过 Pythongradient checker 验证他们。 这里有个例子relu_op_test.py ,测试 ReLU-likeop 的 前向和梯度过程。

1. 定义接口(.cpp里):

在注册 op 的时候,你需要指定:

  • op 的名字
  • op 的输入(名字,类型),op 的输出(名字,类型)
  • docstrings
  • op 可能需要的 一些 attrs
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include <cuda_runtime.h>

using namespace tensorflow;

REGISTER_OP("RingPoint")
    .Attr("radius_in: float")
    .Attr("radius_out: float")
    .Attr("nsample: int")
    .Input("xyz1: float32")
    .Input("xyz2: float32")
    .Input("idx2: int32")
    .Input("kernel: float32")
    .Output("idx: int32")
    .Output("pts_cnt: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
        ::tensorflow::shape_inference::ShapeHandle dims2;
        c->WithRank(c->input(1), 3, &dims2);       //把输入维度取出来放到dim2里吧
        int nsample;
        TF_RETURN_IF_ERROR(c->GetAttr("nsample", &nsample));
        ::tensorflow::shape_inference::ShapeHandle output1 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), nsample}); //和输入维度不一样就要自己设一下
       c->set_output(0, output1);                 //然后把自己设的输出维度和第0个输出对应
        ::tensorflow::shape_inference::ShapeHandle output2 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1)});
     //c->set_output(0, c->input(0));           //输出维度和输入维度一样就用输入的维度
        c->set_output(1, output2);         
        return Status::OK();
    });
View Code

 

关于命名的备注:操作名称必须首字母大写,而且不能和库中已经注册的其它操作重名。

2. 实现操作的内核(.cpp中)

定义接口后,接下来就需要为此操作提供一个或多个内核实现了。
为了实现这些内核,创建一个继承自 OpKernel 的类,并重载 Compute 方法。
Compute 方法有一个类型为 OpKernelContext* 的参数 context,从中可以访问输入和输出张量等有用的信息。

  • CPU版本:
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // 得到输入张量
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // 创建输出张量
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // 除第一个元素外,输出张量的其它所有元素都设置为 0 
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // 如果可能的话,保留第一个输入值
    if (N > 0) output_flat(0) = input(0);
  }
};
View Code

 

ZeroOut 操作加上约束条件:

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

这里注册的操作名是ZeroOut,通过上面的语句和ZeroOutOp对应吧

  • GPU版本:重写的compute里一般都只是数据的处理:大小推断、flat成一维、分配内存、设初值等等,具体的计算用cuda的核函数实现
//先声明
void ringPointLauncher(int b, int n, int m, float radius_in, float radius_out, int nsample, const float *xyz1, const float *xyz2, const int * idx2, int *idx, int *pts_cnt);
//再继承OpKernel,OP_REQUIRES是对输入做一些限定,OP_REQUIRES_OK还不清楚、
class RingPointGpuOp : public OpKernel {
    public:
        //这里是构造函数吧
        explicit RingPointGpuOp(OpKernelConstruction* context) : OpKernel(context) {
            OP_REQUIRES_OK(context, context->GetAttr("radius_in", &radius_in_));
            OP_REQUIRES(context, radius_in_ >= 0, errors::InvalidArgument("RingPoint expects positive inner radius"));

            OP_REQUIRES_OK(context, context->GetAttr("radius_out", &radius_out_));
            OP_REQUIRES(context, radius_out_ > 0, errors::InvalidArgument("RingPoint expects positive outter radius"));

            OP_REQUIRES_OK(context, context->GetAttr("nsample", &nsample_));
            OP_REQUIRES(context, nsample_ > 0, errors::InvalidArgument("RingPoint expects positive nsample"));
        }
        //重写Compute方法
        void Compute(OpKernelContext* context) override {
            //获取第0个输入,输入要获取
            const Tensor& xyz1_tensor = context->input(0);
            OP_REQUIRES(context, xyz1_tensor.dims()==3 && xyz1_tensor.shape().dim_size(2)==3, errors::InvalidArgument("RingPoint expects (batch_size, ndataset, 3) xyz1 shape."));
       //获得一些参数作为后续所调用的核函数(ringPointLauncher)的输入
            int b = xyz1_tensor.shape().dim_size(0);
            int n = xyz1_tensor.shape().dim_size(1);

            const Tensor& xyz2_tensor = context->input(1);
            OP_REQUIRES(context, xyz2_tensor.dims()==3 && xyz2_tensor.shape().dim_size(2)==3, errors::InvalidArgument("RingPoint expects (batch_size, npoint, 3) xyz2 shape."));
            int m = xyz2_tensor.shape().dim_size(1);

            const Tensor& idx2_tensor = context->input(2);
            OP_REQUIRES(context, idx2_tensor.dims()==2, errors::InvalidArgument("RingPoint expects (batch_size, npoint) idx2 shape."));

       //给输出分配内存,输出要分配内存
            Tensor *idx_tensor = nullptr;
            OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m,nsample_}, &idx_tensor));
            Tensor *pts_cnt_tensor = nullptr;
            OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m}, &pts_cnt_tensor));

            //要传给后续的函数,所以输入先flat(方便取值),再赋给一个指针,最后传给核函数
       auto xyz1_flat = xyz1_tensor.flat<float>();
            const float *xyz1 = &(xyz1_flat(0));
            auto xyz2_flat = xyz2_tensor.flat<float>();
            const float *xyz2 = &(xyz2_flat(0));
            auto idx2_flat = idx2_tensor.flat<int>();
            const int *idx2 = &(idx2_flat(0));
            auto idx_flat = idx_tensor->flat<int>();
            int *idx = &(idx_flat(0));
       //再flat和赋给指针
            auto pts_cnt_flat = pts_cnt_tensor->flat<int>();
            int *pts_cnt = &(pts_cnt_flat(0));
       //auto angles_flat = angles_tensor->flat<float>();
       //float *angles = &(angles_flat(0));
       //cudaMemset(angles, 0.0, sizeof(float)*b*m_q*k); //要给输出设初值用cudaMenset,前面的一样,先分配内存,再赋值给指针
            ringPointLauncher(b,n,m,radius_in_,radius_out_,nsample_,xyz1,xyz2,idx2,idx,pts_cnt);
        }
    private:
        float radius_in_;
        float radius_out_;
        int nsample_;
};
//指定设备
REGISTER_KERNEL_BUILDER(Name("RingPoint").Device(DEVICE_GPU), RingPointGpuOp);
View Code
  • cuda的核函数(具体实现部分,.cu文件里)
 
#include <cstdio>
#include <float.h>

// input: points (b,n,c), idx (b,m,nsample)
// output: out (b,m,nsample,c)
//__global__: 声明在device(GPU)上执行,
//__ device__:在device上执行,单仅可以从device中调用,不可以和__global__同时用。就是在__global__函数里调用
__global__ void group_point_gpu(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out) {
    int batch_index = blockIdx.x;//16个block,每个block处理一个batch
   //取数据先加上自己的大小,定位到对应block?
    points += n*c*batch_index;
    idx += m*nsample*batch_index;
    out += m*nsample*c*batch_index;

    int index = threadIdx.x;
    int stride = blockDim.x; //srtide是整的grid的线程数,好像是为了并行啥的,这里gridDim就是1吧

    for (int j=index;j<m;j+=stride) {
        for (int k=0;k<nsample;++k) {
            int ii = idx[j*nsample+k];
            for (int l=0;l<c;++l) {
                out[j*nsample*c+k*c+l] = points[ii*c+l];
            }
        }
    }
}
 

void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out){
   //kernel_fun<<< grid, block >>>(prams...); 来指定kernel要执行的线程数量
    group_point_gpu<<<b,256>>>(b,n,c,m,nsample,points,idx,out);
    //cudaDeviceSynchronize(); //同步device 保证结果能正确访问
}
View Code
kernel是在device上线程中并行执行的函数,核函数用__global__符号声明,在调用时需要用<<<grid, block>>>来指定kernel要执行的线程数量,在CUDA中,每一个线程都要执行核函数,并且每
个线程会分配一个唯一的线程号thread ID,这个ID值可以通过核函数的内置变量threadIdx来获得。

kernel在device上执行时实际上是启动很多线程,一个kernel所启动的所有线程称为一个网格(grid),同一个网格上的线程共享相同的全局内存空间,grid是线程结构的第一层次,而网格又可以
分为很多线程块(block),一个线程块里面包含很多线程,这是第二个层次。
  • 核函数只是分配不同的线程处理不同的数据,每个线程都会执行核函数,线程全局ID不同,通过ID值对数组进行索引,于是就能处理数组里不同的数了

一个线程需要两个内置的坐标变量(blockIdx,threadIdx)来唯一标识,它们都是dim3类型变量,其中blockIdx指明线程所在grid中的位置,而threaIdx指明线程所在block中的位置。找到全局ID之后通过ID值对数组进行索引,处理数组的元素。

 

__global__ void order_neighbors_gpu(int b, int m, int n, int m_q, int k,int num_k, const float *input, const float *queries, const float *queries_norm, const int *idx,const float *kernel, float *proj, int *outi, float *angles,float *kernel_out){
  int batch_index = blockIdx.x;  //这个block里的线程就处理这一个batch的数据吧
  queries+=m_q*n*batch_index; //(512*3*16)//数组索引吧,blockIdx.x是0-15,定位到了某个batch的开头,再取这个batch里的数据
  queries_norm+=m_q*n*batch_index;
  idx+=m_q*k*batch_index; //(16,512,16)
  angles+=m_q*k*batch_index;
  outi+=m_q*k*batch_index;
  input+=m*n*batch_index;
  proj+=m_q*k*n*batch_index;

  int index = threadIdx.x;
  int stride = blockDim.x; //block里含有的线程数,数据数超过线程数就用同一个线程处理

  // copy indecies from idx to outi
  for (int i=index; i<m_q; i+=stride) //把m_q个数据划分给stride个线程处理
      for (int j=0; j<k; ++j)
          outi[i*k + j] = idx[i*k + j];
...
}
View Code

数组的索引还是0-N,传入的时候是flat过的,长度是所有元素个数。在元素数超过线程数的时候一个线程处理多个数据,用gride-stride loop方法,只处理一个数的话循环是不执行的。

 完整示例:

Attr:一个数,Input:输入数组Output:输出数组

//tf_ordering.cpp
REGISTER_OP("OrderNeighbors") .Attr("k: int") .Attr("sita: float") .Input("input_xyz: float32") .Input("query_xyz: float32") .Input("query_normals: float32") .Input("idx: int32") .Input("kernel: float32") .Output("outi: int32") .Output("kernel_fit: float32") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c){ int k; TF_RETURN_IF_ERROR(c->GetAttr("k", &k)); float sita; TF_RETURN_IF_ERROR(c->GetAttr("sita", &sita)); ::tensorflow::shape_inference::ShapeHandle dims2; c->WithRank(c->input(3), 3, &dims2); c->set_output(0, c->input(3)); ::tensorflow::shape_inference::ShapeHandle dims3; c->WithRank(c->input(4), 2, &dims3); ::tensorflow::shape_inference::ShapeHandle output1 = c->MakeShape({c->Dim(dims2, 0), c->Dim(dims2, 1), c->Dim(dims3, 0),4}); c->set_output(1, output1); return Status::OK(); }); void orderNeighborsLauncher(int b, int m, int n, int m_q, int k,int num_k,float sita, const float *input, const float *queries, const float *queries_norm, const int *idx, const float *kernel, float *proj, int *outi,float *kernel_fit); class OrderNeighborsGpuOp : public OpKernel { public: explicit OrderNeighborsGpuOp(OpKernelConstruction * context):OpKernel(context){ OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); OP_REQUIRES(context, k_ > 0, errors::InvalidArgument("OrderNeighbors expects positive k")); OP_REQUIRES_OK(context, context->GetAttr("sita", &sita_)); OP_REQUIRES(context, sita_ > 0, errors::InvalidArgument("OrderNeighbors expects positive sita")); } void Compute(OpKernelContext* context) override { const Tensor& input_xyz_tensor = context->input(0); OP_REQUIRES(context, input_xyz_tensor.dims() == 3, errors::InvalidArgument("OrderNeighbors expects (b,m,n) input_xyz shape")); int b = input_xyz_tensor.shape().dim_size(0); int m = input_xyz_tensor.shape().dim_size(1); int n = input_xyz_tensor.shape().dim_size(2);//n是特征维度? //3 const Tensor& query_xyz_tensor = context->input(1); OP_REQUIRES(context, query_xyz_tensor.dims() == 3, errors::InvalidArgument("OrderNeighbors expects (b,m_q,n) query_xyz shape")); int m_q = query_xyz_tensor.shape().dim_size(1); const Tensor& query_normals_tensor = context->input(2); OP_REQUIRES(context, query_normals_tensor.dims() == 3, errors::InvalidArgument("OrderNeighbors expects (b,m_q,n) query_normals shape")); const Tensor& idx_tensor = context->input(3); OP_REQUIRES(context, idx_tensor.dims() == 3, errors::InvalidArgument("OrderNeighbors expects (b,m_q,k) idx shape")); int k = idx_tensor.shape().dim_size(2); //我加的 const Tensor& kernel_tensor = context->input(4); OP_REQUIRES(context, kernel_tensor.dims() == 2, errors::InvalidArgument("OrderNeighbors expects (num_k,2) kernel shape")); int num_k=kernel_tensor.shape().dim_size(0); Tensor *outi_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape{b,m_q,k}, &outi_tensor)); Tensor *kernel_fit_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape{b,m_q,num_k,4}, &kernel_fit_tensor)); auto input_flat = input_xyz_tensor.flat<float>(); const float *input = &(input_flat(0)); auto queries_flat = query_xyz_tensor.flat<float>(); const float *queries = &(queries_flat(0)); auto queries_norm_flat = query_normals_tensor.flat<float>(); const float *queries_norm = &(queries_norm_flat(0)); auto idx_flat = idx_tensor.flat<int>(); const int *idx = &(idx_flat(0)); auto outi_flat = outi_tensor->flat<int>(); int *outi = &(outi_flat(0));//我加的 auto kernel_flat = kernel_tensor.flat<float>(); const float *kernel = &(kernel_flat(0)); auto kernel_fit_flat = kernel_fit_tensor->flat<float>(); float *kernel_fit = &(kernel_fit_flat(0)); cudaMemset(kernel_fit, 0.0, sizeof(float)*b*m_q*num_k*4); orderNeighborsLauncher(b, m, n, m_q, k,num_k,sita_, input, queries, queries_norm, idx, kernel, proj, outi,kernel_fit); } private: int k_; float sita_; }; REGISTER_KERNEL_BUILDER(Name("OrderNeighbors").Device(DEVICE_GPU), OrderNeighborsGpuOp);

 

参考:tensorflow新op官方教程

https://zhuanlan.zhihu.com/p/34168765

tensorflow新op实例 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM