Pytorch的conv2d实现图像边缘检测和均值模糊


Pytorch的conv2d实现图像边缘检测和均值模糊

代码如下:

# 图像处理
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianBlur(nn.Module):
    def __init__(self):
        super(GaussianBlur,self).__init__()
        # 三个不同的卷积核
        # kernel_1和kernel2均为边缘检测,kernel_3为均值模糊
        kernel_1 = [[-1, -1, -1],
                  [-1,  8, -1],
                  [-1, -1, -1]]

        kernel_2 = [[ 0, -1,  0],
                  [-1,  4, -1],
                  [ 0, -1,  0]]

        kernel_3 = [[ 1/9,  1/9,  1/9],
                  [ 1/9,  1/9,  1/9],
                  [ 1/9,  1/9,  1/9]]
        
        kernel = torch.Tensor(kernel_1).unsqueeze(0).unsqueeze(0)
        self.weight = nn.Parameter(data=kernel,requires_grad=False)
        
    def forward(self,x):
        x1 = x[:,0]
        x2 = x[:,1]
        x3 = x[:,2]
        x1 = F.conv2d(x1.unsqueeze(1), self.weight, padding=1)
        x2 = F.conv2d(x2.unsqueeze(1), self.weight, padding=1)
        x3 = F.conv2d(x3.unsqueeze(1), self.weight, padding=1)
        x = torch.cat([x1, x2, x3], dim=1)
        return x

if __name__ == "__main__":
    #图像读取及处理
    img = cv2.imread("D:\\code\\python\\deeplearning\\cnn\\img\\img2.jpg") #img1.jpg | img2.jpg | img3.jpg
    img = torch.Tensor(img)
    img = img.permute(2,0,1).reshape(1,3,img.size(0),img.size(1))

    #边缘检测|均值模糊
    gb = GaussianBlur()
    res = gb.forward(img)

    #处理结果并显示
    res = res.permute(0,2,3,1).numpy().reshape(res.size(2),res.size(3),3)
    print(res.shape)
    cv2.imwrite('D:\\code\\python\\deeplearning\\cnn\\temp\\test.jpg',res)
    show = cv2.imread('D:\\code\\python\\deeplearning\\cnn\\temp\\test.jpg')
    cv2.imshow('img',show)
    cv2.waitKey(0)


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM