tensorflow添加自定義的auc計算operator


tensorflow可以很方便的添加用戶自定義的operator(如果不添加也可以采用sklearnauc計算函數或者自己寫一個 但是會在python執行,這里希望在graph中也就是c++端執行這個計算)

這里根據工作需要添加一個計算aucoperator,只給出最簡單實現,后續高級功能還是參考官方wiki

https://www.tensorflow.org/versions/r0.7/how_tos/adding_an_op/index.html

注意tensorflow現在和最初的官方wiki有變化,原wiki貌似是需要重新bazel編譯整個tensorflow,然后使用比如tf.user_op.auc這樣。

目前wiki給出的方式>=0.6.0版本,采用plug-in的方式,更加靈活可以直接用g++編譯一個so載入,解耦合,省去了編譯tensorflow過程,即插即用。

   

首先aucoperator計算的文件

   

tensorflow/core/user_ops/auc.cc

   

/* Copyright 2015 Google Inc. All Rights Reserved.

   

Licensed under the Apache License, Version 2.0 (the "License");

you may not use this file except in compliance with the License.

You may obtain a copy of the License at

   

http://www.apache.org/licenses/LICENSE-2.0

   

Unless required by applicable law or agreed to in writing, software

distributed under the License is distributed on an "AS IS" BASIS,

WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

See the License for the specific language governing permissions and

limitations under the License.

==============================================================================*/

   

// An auc Op.

   

#include "tensorflow/core/framework/op.h"

#include "tensorflow/core/framework/op_kernel.h"

   

using namespace tensorflow;

using std::vector;

//@TODO add weight as optional input

REGISTER_OP("Auc")

.Input("predicts: T1")

.Input("labels: T2")

.Output("z: float")

.Attr("T1: {float, double}")

.Attr("T2: {float, double}")

//.Attr("T1: {float, double}")

//.Attr("T2: {int32, int64}")

.SetIsCommutative()

.Doc(R"doc(

Given preidicts and labels output it's auc

)doc");

   

class AucOp : public OpKernel {

public:

explicit AucOp(OpKernelConstruction* context) : OpKernel(context) {}

   

template<typename ValueVec>

void index_sort(const ValueVec& valueVec, vector<int>& indexVec)

{

indexVec.resize(valueVec.size());

for (size_t i = 0; i < indexVec.size(); i++)

{

indexVec[i] = i;

}

std::sort(indexVec.begin(), indexVec.end(),

[&valueVec](const int l, const int r) { return valueVec(l) > valueVec(r); });

}

   

void Compute(OpKernelContext* context) override {

// Grab the input tensor

const Tensor& predicts_tensor = context->input(0);

const Tensor& labels_tensor = context->input(1);

auto predicts = predicts_tensor.flat<float>(); //輸入能接受float double那么這里如何都處理?

auto labels = labels_tensor.flat<float>();

   

vector<int> indexes;

index_sort(predicts, indexes);

typedef float Float;

   

Float oldFalsePos = 0;

Float oldTruePos = 0;

Float falsePos = 0;

Float truePos = 0;

Float oldOut = std::numeric_limits<Float>::infinity();

Float result = 0;

   

for (size_t i = 0; i < indexes.size(); i++)

{

int index = indexes[i];

Float label = labels(index);

Float prediction = predicts(index);

Float weight = 1.0;

//Pval3(label, output, weight);

if (prediction != oldOut) //存在相同值得情況是特殊處理的

{

result += 0.5 * (oldTruePos + truePos) * (falsePos - oldFalsePos);

oldOut = prediction;

oldFalsePos = falsePos;

oldTruePos = truePos;

}

if (label > 0)

truePos += weight;

else

falsePos += weight;

}

result += 0.5 * (oldTruePos + truePos) * (falsePos - oldFalsePos);

Float AUC = result / (truePos * falsePos);

   

// Create an output tensor

Tensor* output_tensor = NULL;

TensorShape output_shape;

   

OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));

output_tensor->scalar<float>()() = AUC;

}

};

   

REGISTER_KERNEL_BUILDER(Name("Auc").Device(DEVICE_CPU), AucOp);

   

   

編譯:

$cat gen-so.sh

   

TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')

TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')

i=$1

o=${i/.cc/.so}

g++ -std=c++11 -shared $i -o $o -I $TF_INC -l tensorflow_framework -L $TF_LIB -fPIC -Wl,-rpath $TF_LIB

   

$sh gen-so.sh auc.cc

會生成auc.so

   

使用的時候

auc_module = tf.load_op_library('auc.so')

#auc = tf.user_ops.auc #0.6.0之前的tensorflow 自定義op方式

auc = auc_module.auc

   

evaluate_op = auc(py_x, Y) #py_x is predicts, Y is labels

   

   

   

   

   

   


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM