pytorch中調用C進行擴展,使得某些功能在CPU上運行更快;
第一步:編寫頭文件
/* src/my_lib.h */ int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output); int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
第二步:編寫源文件
/* src/my_lib.c */ #include <TH/TH.h> int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output) { if (!THFloatTensor_isSameSizeAs(input1, input2)) return 0; THFloatTensor_resizeAs(output, input1); THFloatTensor_cadd(output, input1, 1.0, input2); return 1; } int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input) { THFloatTensor_resizeAs(grad_input, grad_output); THFloatTensor_fill(grad_input, 1); return 1; }
注意:頭文件TH就是pytorch底層代碼的接口頭文件,它是CPU模式,GPU下則為THC;
第三步:在同級目錄下創建一個.py文件(比如叫“build.py”)
該文件用於對該C擴展模塊進行編譯(使用torch.util.ffi模塊進行擴展編譯);
# build.py from torch.utils.ffi import create_extension ffi = create_extension( name='_ext.my_lib', # 輸出文件地址及名稱 headers='src/my_lib.h', # 編譯.h文件地址及名稱 sources=['src/my_lib.c'], # 編譯.c文件地址及名稱 with_cuda=False # 不使用cuda ) ffi.build()
第四步:編寫.py腳本調用編譯好的C擴展模塊
import torch from torch.autograd import Function from _ext import my_lib import torch.nn as nn class MyAddFunction(Function): def forward(self, input1, input2): output = torch.FloatTensor() my_lib.my_lib_add_forward(input1, input2, output) return output def backward(self, grad_output): grad_input = torch.FloatTensor() my_lib.my_lib_add_backward(grad_input, grad_output) return grad_input class MyAddModule(nn.Module): def forward(self, input1, input2): return MyAddFunction()(input1, input2) class MyNetWork(nn.Module): def __init__(self): super(MyNetWork, self).__init__() self.add = MyAddModule() def forward(self, input1, input2): return self.add(input1, input2) model = MyNetWork() input1, input2 = torch.randn(5, 5), torch.randn(5, 5) print(model(input1, input2)) print(input1 + input2)
至此,用這個簡單的例子拋磚引玉~