前言
在模型預測過程中,如果將較大的待分類遙感影像直接輸入到網絡模型中會造成內存溢出,故一般將待分類圖像裁剪為一系列較小圖像分別輸入網絡進行預測,然后將預測結果按照裁剪順序拼接成一張最終結果圖像。
原理
如果采用常規的規則格網裁剪然后預測拼接的話效果不好。因為每張圖像塊的邊緣區域的上下文信息較少,所以預測結果精度較低,進而還會導致出現明顯的拼接痕跡。因此采用忽略邊緣預測,即有重疊地裁剪影像並在拼接時采取忽略邊緣策略。如圖1所示,實際裁剪圖像預測的結果為 $A$,進行拼接的結果為 $a$,$a$ 占 $A$ 的區域百分比為 $r$,相鄰裁剪圖像的重疊比例為 $1-\sqrt{r}$(即每張圖像塊四周各忽略邊長的 $(1-\sqrt{r})/2$)。這里借用知乎大佬的圖來說明一下。
代碼實現
我們先把大圖像裁剪成一系列與相鄰圖像塊有特定重復區域的圖像塊,並把它們存在列表(list)里,然后創建生成器,之后進行預測。最后對預測結果只取中間部分進行拼接。代碼注釋寫得相對比較詳細,直接看代碼:
import math
import numpy as np
import torch.nn.functional as F
import torch
from osgeo import gdal
from unet import UNet
import torchvision
# Read a TIFF raster (or a window of it) into an array
def readTif(fileName, xoff=0, yoff=0, data_width=0, data_height=0):
    """Read pixel data from a raster file through GDAL.

    Args:
        fileName: Path of the raster file handed to ``gdal.Open``.
        xoff: X offset of the window passed to ``ReadAsArray``.
        yoff: Y offset of the window passed to ``ReadAsArray``.
        data_width: Window width. Used as given unless both ``data_width``
            and ``data_height`` are 0.
        data_height: Window height. When both sizes are 0 the full raster
            size (``RasterXSize`` x ``RasterYSize``) is requested instead.

    Returns:
        The value returned by ``dataset.ReadAsArray`` for the window.

    Raises:
        RuntimeError: If ``gdal.Open`` returns ``None`` for ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Fail here with a clear message instead of printing and then
        # crashing on the attribute access below.
        raise RuntimeError(fileName + "文件無法打開")
    # Number of columns of the raster
    width = dataset.RasterXSize
    # Number of rows of the raster
    height = dataset.RasterYSize
    # No explicit window size given: read the whole raster extent
    if data_width == 0 and data_height == 0:
        data_width = width
        data_height = height
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)
    return data
# Save an array as a GeoTIFF file
def writeTiff(fileName, data, im_geotrans=(0, 0, 0, 0, 0, 0), im_proj=""):
    """Write a 2-D or 3-D array to a GeoTIFF file through GDAL.

    Args:
        fileName: Path of the file to create.
        data: NumPy array, either ``(height, width)`` or
            ``(bands, height, width)``. A 2-D array is written as one band.
        im_geotrans: Affine geotransform passed to ``SetGeoTransform``.
        im_proj: Projection string passed to ``SetProjection``.

    Raises:
        ValueError: If ``data`` is neither 2-D nor 3-D.
        RuntimeError: If the GTiff driver cannot create ``fileName``.
    """
    # Validate the layout first so a malformed array fails before any file
    # is created (previously this surfaced as a NameError on im_bands).
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    elif data.ndim != 3:
        raise ValueError(
            f"writeTiff expects a 2-D or 3-D array, got {data.ndim} dimension(s)"
        )
    im_bands, im_height, im_width = data.shape

    # Pick the GDAL pixel type from the array dtype. Signed and unsigned
    # 16-bit are told apart explicitly; a substring test on 'int16' would
    # store signed data in an unsigned band.
    dtype_name = data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(fileName, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        raise RuntimeError(f"GDAL could not create {fileName!r}")
    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(data[i])
    # Drop the reference to the dataset once all bands have been written.
    del dataset
# Crop a raster into overlapping tiles (pixel array, ignored border width)
def TifCroppingArray(img, SideLength, TileSize=256):
    """Cut an image into a grid of overlapping square tiles.

    Tiles are taken every ``TileSize - 2 * SideLength`` pixels along the
    first two axes of ``img``. Because the image size is rarely an exact
    multiple of that stride, one extra tile row and one extra tile column
    are appended, aligned with the bottom and right image edges.

    Args:
        img: Array indexed as ``img[row, col, ...]``.
        SideLength: Width of the border that is later ignored on each tile
            side; neighbouring tiles overlap by twice this amount.
        TileSize: Edge length of every tile. Defaults to 256.

    Returns:
        A tuple ``(tiles, RowOver, ColumnOver)``:
        ``tiles`` is a list of tile rows, each a list of array slices of
        ``img``; ``RowOver`` is the remainder along axis 1 plus
        ``SideLength``; ``ColumnOver`` is the same quantity along axis 0.

    Raises:
        ValueError: If the stride is not positive, ``SideLength`` is
            negative, or the image is smaller than one tile.
    """
    height, width = img.shape[0], img.shape[1]
    stride = TileSize - SideLength * 2
    if SideLength < 0 or stride <= 0:
        raise ValueError(
            f"SideLength={SideLength} leaves no usable tile centre for TileSize={TileSize}"
        )
    if height < TileSize or width < TileSize:
        # Smaller images would produce negative slice starts and tiles of
        # the wrong size without any error.
        raise ValueError(
            f"image of size {height}x{width} is smaller than one {TileSize}x{TileSize} tile"
        )

    usable_height = height - SideLength * 2
    usable_width = width - SideLength * 2
    # Number of regular tiles along axis 0 (historical name kept)
    ColumnNum = usable_height // stride
    # Number of regular tiles along axis 1 (historical name kept)
    RowNum = usable_width // stride

    # Regular grid origins followed by one origin flush with the far edge
    row_starts = [i * stride for i in range(ColumnNum)] + [height - TileSize]
    col_starts = [j * stride for j in range(RowNum)] + [width - TileSize]

    TifArrayReturn = [
        [img[top: top + TileSize, left: left + TileSize] for left in col_starts]
        for top in row_starts
    ]

    # Leftover along axis 0 that only the edge-aligned tile row can supply
    ColumnOver = usable_height % stride + SideLength
    # Leftover along axis 1 that only the edge-aligned tile column can supply
    RowOver = usable_width % stride + SideLength
    return TifArrayReturn, RowOver, ColumnOver
# Assemble the result matrix from per-tile predictions
def Result(shape, TifArray, npyfile, RepetitiveLength, RowOver, ColumnOver, TileSize=256):
    """Stitch per-tile predictions into one full-size ``uint8`` array.

    Each prediction contributes only its centre; a border of
    ``RepetitiveLength`` pixels is dropped on every side that touches a
    neighbouring tile. Sides lying on the image boundary are kept.

    Args:
        shape: ``(rows, cols)`` of the output array.
        TifArray: Tile grid (list of tile rows); only its dimensions
            ``len(TifArray)`` and ``len(TifArray[0])`` are used.
        npyfile: Iterable of 2-D predictions, one per tile, in row-major
            order of the grid.
        RepetitiveLength: Width of the ignored border on each tile side.
        RowOver: Number of trailing columns supplied by the last tile column.
        ColumnOver: Number of trailing rows supplied by the last tile row.
        TileSize: Edge length of every prediction. Defaults to 256.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the given ``shape``.
    """
    result = np.zeros(shape, np.uint8)
    n_rows = len(TifArray)
    n_cols = len(TifArray[0])
    inner_end = TileSize - RepetitiveLength
    stride = TileSize - 2 * RepetitiveLength

    for i, img in enumerate(npyfile):
        # j is the tile row, col the tile column within that row
        j, col = divmod(i, n_cols)

        # Vertical extent
        if j == 0:
            # First tile row: the upper border is part of the image edge
            dst_rows = slice(0, inner_end)
            src_rows = slice(0, inner_end)
        elif j == n_rows - 1:
            # Last tile row: the lower border is part of the image edge.
            # The left-most tile copies RepetitiveLength extra rows, as the
            # original implementation did; these rows overlap the area
            # already filled by the tile row above.
            extra = RepetitiveLength if col == 0 else 0
            dst_rows = slice(shape[0] - ColumnOver - extra, shape[0])
            src_rows = slice(TileSize - ColumnOver - extra, TileSize)
        else:
            dst_rows = slice(j * stride + RepetitiveLength,
                             (j + 1) * stride + RepetitiveLength)
            src_rows = slice(RepetitiveLength, inner_end)

        # Horizontal extent
        if col == 0:
            # Left-most column: the left border is part of the image edge
            dst_cols = slice(0, inner_end)
            src_cols = slice(0, inner_end)
        elif col == n_cols - 1:
            # Right-most column: the right border is part of the image edge
            dst_cols = slice(shape[1] - RowOver, shape[1])
            src_cols = slice(TileSize - RowOver, TileSize)
        else:
            dst_cols = slice(col * stride + RepetitiveLength,
                             (col + 1) * stride + RepetitiveLength)
            src_cols = slice(RepetitiveLength, inner_end)

        result[dst_rows, dst_cols] = img[src_rows, src_cols]
    return result
# Fraction r of each tile's area that is kept when stitching
area_perc = 0.5
TifPath = r"343.tif"
model_paths = [
    r"MODEl.pth"
]
ResultPath = r"predict_result1.tif"
# Border ignored on every tile side: (1 - sqrt(r)) / 2 of the 256-pixel edge
RepetitiveLength = int((1 - math.sqrt(area_perc)) * 256 / 2)
big_image = readTif(TifPath)
# Move axis 0 to the end; assumes readTif returned a 3-D band-first array
big_image = big_image.swapaxes(1, 0).swapaxes(1, 2)
TifArray, RowOver, ColumnOver = TifCroppingArray(big_image, RepetitiveLength)
# Replace with your own model
model = UNet(n_channels=3, n_classes=2, bilinear=False)
# Move the model to the selected device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
if not model_paths:
    raise ValueError("model_paths must contain at least one checkpoint path")
# Read every checkpoint from disk once instead of once per tile.
# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
state_dicts = [torch.load(model_path, map_location=device) for model_path in model_paths]
model.eval()
to_tensor = torchvision.transforms.ToTensor()
# Checkpoint currently held by `model`; avoids redundant reloads
active_state = None
predicts = []
for tile_row in TifArray:
    for image in tile_row:
        img = to_tensor(image)
        img = img.unsqueeze(0)
        img = img.to(device=device, dtype=torch.float32)
        # Resize takes (height, width); the tile is indexed (row, col, band)
        tf = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((image.shape[0], image.shape[1])),
            torchvision.transforms.ToTensor()
        ])
        # NOTE: with several checkpoints only the last one's prediction is
        # kept for each tile, exactly as before; no ensembling is done.
        for state_dict in state_dicts:
            if active_state is not state_dict:
                model.load_state_dict(state_dict)
                active_state = state_dict
            with torch.no_grad():
                output = model(img)
                if model.n_classes > 1:
                    probs = F.softmax(output, dim=1)[0]
                else:
                    probs = torch.sigmoid(output)[0]
                mask = tf(probs.cpu()).squeeze()
            if model.n_classes == 1:
                # Binary model: the thresholded map is the prediction itself
                mask = (mask > 0.5).numpy()
                pred = mask
            else:
                mask = F.one_hot(mask.argmax(dim=0), model.n_classes).permute(2, 0, 1).numpy()
                # Channel 1 of the one-hot mask is the class-1 map
                pred = mask[1]
        predicts.append(pred)
# Stitch the tile predictions and save the result
result_shape = (big_image.shape[0], big_image.shape[1])
result_data = Result(result_shape, TifArray, predicts, RepetitiveLength, RowOver, ColumnOver)
writeTiff(ResultPath, result_data)
參考文獻
王振慶,周藝,王世新,王福濤,徐知宇.2021.IEU-Net高分辨率遙感影像房屋建築物提取.遙感學報,25(11): 2245-2254 DOI: 10.11834/jrs.20210042. Wang Z Q,Zhou Y,Wang S X,Wang F T and Xu Z Y. 2021. House building extraction from high-resolution remote sensing images based on IEU-Net. National Remote Sensing Bulletin, 25(11):2245-2254 DOI: 10.11834/jrs.20210042.