pytorch中調用C進行擴展


pytorch中調用C進行擴展,使得某些功能在CPU上運行更快;

第一步:編寫頭文件

/* src/my_lib.h */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);

 

第二步:編寫源文件

/* src/my_lib.c */
#include <TH/TH.h>

int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
THFloatTensor *output)
{
    if (!THFloatTensor_isSameSizeAs(input1, input2))
        return 0;
    THFloatTensor_resizeAs(output, input1);
    THFloatTensor_cadd(output, input1, 1.0, input2);
    return 1;
}

int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
    THFloatTensor_resizeAs(grad_input, grad_output);
    THFloatTensor_fill(grad_input, 1);
    return 1;
}

 

注意:頭文件TH就是pytorch底層代碼的接口頭文件,它是CPU模式,GPU下則為THC;

 

 第三步:在同級目錄下創建一個.py文件(比如叫“build.py”)

該文件用於對該C擴展模塊進行編譯(使用torch.util.ffi模塊進行擴展編譯);

# build.py
from torch.utils.ffi import create_extension
ffi = create_extension(
name='_ext.my_lib',        # 輸出文件地址及名稱
headers='src/my_lib.h',    # 編譯.h文件地址及名稱
sources=['src/my_lib.c'],  # 編譯.c文件地址及名稱
with_cuda=False            # 不使用cuda
)
ffi.build()

 

第四步:編寫.py腳本調用編譯好的C擴展模塊

import torch
from torch.autograd import Function
from _ext import my_lib
import torch.nn as nn

class MyAddFunction(Function):
    def forward(self, input1, input2):
        output = torch.FloatTensor()
        my_lib.my_lib_add_forward(input1, input2, output)
        return output

    def backward(self, grad_output):
        grad_input = torch.FloatTensor()
        my_lib.my_lib_add_backward(grad_input, grad_output)
        return grad_input

class MyAddModule(nn.Module):
    def forward(self, input1, input2):
        return MyAddFunction()(input1, input2)

class MyNetWork(nn.Module):
    def __init__(self):
        super(MyNetWork, self).__init__()
        self.add = MyAddModule()

    def forward(self, input1, input2):
        return self.add(input1, input2)

model = MyNetWork()
input1, input2 = torch.randn(5, 5), torch.randn(5, 5)
print(model(input1, input2))
print(input1 + input2)

 

至此,用這個簡單的例子拋磚引玉~

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM