【pytorch->mindspore】1. Custom Operator Migration


The project to migrate is the image compression algorithm HESIC: https://github.com/ywz978020607/HESIC
1. Custom operator migration -- the LowerBoundFunction class
To migrate this low-level wrapper class accurately, the original class needs to be tested in detail and the migrated version verified against it.
In PyTorch, a custom operator is written by subclassing torch.autograd.Function and implementing static forward and backward methods:

import torch
import torch.nn as nn


class LowerBoundFunction(torch.autograd.Function):
    """Autograd function for the `LowerBound` operator.
    """
    @staticmethod
    def forward(ctx, input_, bound):
        ctx.save_for_backward(input_, bound)
        return torch.max(input_, bound)

    @staticmethod
    def backward(ctx, grad_output):
        input_, bound = ctx.saved_tensors
        # Pass the gradient through where the input is already above the bound,
        # or where a negative gradient would push the input up toward the bound.
        pass_through_if = (input_ >= bound) | (grad_output < 0)
        print(grad_output) #tensor([ 0.,  2., 15.], grad_fn=<MulBackward0>)
        print(pass_through_if)
        print(pass_through_if.type(grad_output.dtype) * grad_output)
        return pass_through_if.type(grad_output.dtype) * grad_output, None

if __name__=="__main__":
    a = torch.Tensor([1,2,3])
    b = torch.Tensor([0,1,5])
    a.requires_grad_(True)
    b.requires_grad_(True)
    c = a*b

    m = LowerBoundFunction.apply(a,b)
    m.backward(c)

The output is:

tensor([ 0.,  2., 15.], grad_fn=<MulBackward0>)
tensor([ True,  True, False])
tensor([0., 2., 0.])

The two print statements show that this class gates the gradient: here c = a*b = [0, 2, 15] seeds the backward pass, pass_through_if = (a >= b) | (c < 0) = [True, True, False], so only [0., 2., 0.] survives. The gradient is blocked wherever the input has fallen below the bound (unless the gradient is negative), which feels a bit like ReLU.
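To make the ReLU analogy concrete, here is a minimal sketch (values chosen for illustration) comparing torch.max's default gradient with the custom one: torch.max simply drops the gradient for the smaller input, while LowerBound also lets negative gradients through, so a clipped input can still be pushed back up toward the bound.

import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)
b = torch.tensor([0., 1., 5.])

out = torch.max(a, b)
out.backward(torch.tensor([1., 1., -1.]))
print(a.grad)  # tensor([1., 1., 0.]) -- the clipped element gets no gradient

a.grad = None
out = LowerBoundFunction.apply(a, b)
out.backward(torch.tensor([1., 1., -1.]))
print(a.grad)  # tensor([1., 1., -1.]) -- the negative gradient passes through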
MindSpore's native custom-operator mechanism, by contrast, is defined differently for Ascend, GPU, and CPU and is overly complex. After consulting Huawei engineers, the plan is to subclass nn.Cell and override its bprop method instead. The following test (adapted from the MindSpore test suite) probes how bprop back-propagates gradients:

# https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/pynative_mode/test_hook.py
import numpy as np

import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

grad_all = C.GradOperation(get_all=True)
bprop_debug = False

class MulAdd(nn.Cell):
    def __init__(self):
        super(MulAdd, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

    def bprop(self, x, y, out, dout):
        # bprop receives the inputs, the forward result (out), and the incoming
        # gradient (dout); it must return one gradient per input. The gradient
        # returned for x is deliberately dout rather than the analytic 4*x,
        # showing that bprop fully overrides the autodiff result; for y the
        # analytic 2*y is returned.
        global bprop_debug
        bprop_debug = True
        print(x)
        print(y)
        print(out)
        print(dout)
        # [1 2 3]
        # [2 3 5]
        # [ 6 17 43]
        # [1 1 1]
        return dout, 2 * y

def test_custom_bprop():
    mul_add = MulAdd()
    mul_add.bprop_debug = True
    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
    y = Tensor(np.array([2, 3, 5]).astype(np.int32))
    ret = grad_all(mul_add)(x, y)
    print(ret) #(Tensor(shape=[3], dtype=Int32, value= [1, 1, 1]), Tensor(shape=[3], dtype=Int32, value= [ 4,  6, 10]))
    assert bprop_debug

##############
#ywz
test_custom_bprop()
print(bprop_debug)
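One difference from the PyTorch test: there, m.backward(c) seeded back-propagation with c, while grad_all implicitly seeds with ones. GradOperation's sens_param flag exposes the seed ("sensitivity") as an extra trailing argument; a minimal sketch, reusing MulAdd and the imports above:

grad_with_sens = C.GradOperation(get_all=True, sens_param=True)
x = Tensor(np.array([1, 2, 3]).astype(np.int32))
y = Tensor(np.array([2, 3, 5]).astype(np.int32))
sens = Tensor(np.array([1, 1, 1]).astype(np.int32))  # explicit gradient seed
print(grad_with_sens(MulAdd())(x, y, sens))  # same result as grad_all, which seeds with ones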

After verifying how the bprop override works, implement the corresponding class:

import numpy as np

import mindspore as msp
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")

class LowerBoundFunction(nn.Cell):
    def __init__(self):
        super(LowerBoundFunction, self).__init__()

    def construct(self, input_, bound):
        return msp.ops.maximum(input_, bound)

    def bprop(self, input_, bound, out, dout):
        # out is the forward result of construct; dout is the incoming gradient.
        # In PyTorch, backward() receives the already-propagated grad_output;
        # in MindSpore, out is the forward value and dout plays that role.
        # Logical OR -- (input_ >= bound) | (dout < 0) -- is rewritten via
        # addition and a cast back to bool:
        pass_through_if = ((input_ >= bound).astype(input_.dtype) + (dout < 0).astype(input_.dtype)).astype('Bool')
        print(pass_through_if.astype(dout.dtype) * dout)
        # bprop must return one gradient per input; the gradient for bound is
        # never consumed downstream, so the same tensor is returned as a placeholder.
        return pass_through_if.astype(dout.dtype) * dout, pass_through_if.astype(dout.dtype) * dout


if __name__=="__main__":
    grad_all = C.GradOperation(get_all=True)

    lowerboundfunc = LowerBoundFunction()
    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
    y = Tensor(np.array([0, 1, 5]).astype(np.int32))
    test = lowerboundfunc(x, y)           # forward pass
    ret = grad_all(lowerboundfunc)(x, y)  # gradients w.r.t. both inputs
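With sens_param, the PyTorch test above (m.backward(c) with c = a*b = [0, 2, 15]) can be reproduced exactly; a minimal sketch, assuming the LowerBoundFunction cell above and float32 inputs so the seed's dtype matches the forward output:

grad_with_sens = C.GradOperation(get_all=True, sens_param=True)
xf = Tensor(np.array([1, 2, 3]).astype(np.float32))
yf = Tensor(np.array([0, 1, 5]).astype(np.float32))
sens = Tensor(np.array([0, 2, 15]).astype(np.float32))  # same seed as the PyTorch test
ret = grad_with_sens(LowerBoundFunction())(xf, yf, sens)
print(ret[0])  # expected [0. 2. 0.], matching the PyTorch output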

Summary: PyTorch's backward() behaves more like a black box: gradient propagation is handled internally and never surfaces explicitly. Overriding bprop in MindSpore has more pitfalls to watch for: out is the forward inference value, dout is the incoming gradient, and one gradient must be returned per input.
Reference: https://www.mindspore.cn/doc/api_python/zh-CN/r1.2/_modules/mindspore/nn/cell.html#Cell.cast_param

