Work Summary 3


Passing two tensors to torch.max()

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

While reading some source code recently, I ran into a slick trick that I could not follow:

import torch


def find_intersection(set_1, set_2):
    """
    Find the intersection of every box combination between two sets of boxes that are in boundary coordinates.

    :param set_1: set 1, a tensor of dimensions (n1, 4)
    :param set_2: set 2, a tensor of dimensions (n2, 4)
    :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
    """

    # PyTorch auto-broadcasts singleton dimensions
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)

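The idea there is that the intersection area is just the overlap along x (the smaller of the two x_max values minus the larger of the two x_min values) times the overlap along y (computed the same way), with negative overlaps clamped to zero.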
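To make the pairwise behaviour concrete, here is a small worked example. It repeats the function so that it runs on its own; the box coordinates and the names `boxes_a`, `boxes_b`, and `overlap` are made up for illustration and assume boxes in (x_min, y_min, x_max, y_max) form.

```python
import torch


def find_intersection(set_1, set_2):
    # (n1, 1, 2) against (1, n2, 2) broadcasts to (n1, n2, 2)
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))
    # Boxes that do not overlap give a negative width or height; clamp to 0
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]


# Illustrative boxes (assumed values), format: x_min, y_min, x_max, y_max
boxes_a = torch.tensor([[0., 0., 2., 2.],
                        [1., 1., 3., 3.]])   # (2, 4)
boxes_b = torch.tensor([[1., 1., 2., 2.],
                        [0., 0., 1., 1.],
                        [5., 5., 6., 6.]])   # (3, 4)

overlap = find_intersection(boxes_a, boxes_b)  # (2, 3)
print(overlap)
# tensor([[1., 1., 0.],
#         [1., 0., 0.]])
```

Row i, column j holds the intersection area of box i in the first set with box j in the second set; boxes that only touch at a corner, or that do not overlap at all, come out as 0.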

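It turned out that my confusion came from torch.max(). I will skip the simple usages and look only at the last one, which the official PyTorch documentation also mentions only in passing: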

torch.max(input, other, out=None) → Tensor
Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise maximum is taken.

The shapes of input and other don’t need to match, but they must be broadcastable.

\text{out}_i = \max(\text{tensor}_i, \text{other}_i)
NOTE

When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.

Parameters
input (Tensor) – the input tensor.

other (Tensor) – the second input tensor

out (Tensor, optional) – the output tensor.

Example:

>>> a = torch.randn(4)
>>> a
tensor([ 0.2942, -0.7416,  0.2653, -0.1584])
>>> b = torch.randn(4)
>>> b
tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
>>> torch.max(a, b)
tensor([ 0.8722, -0.7416,  0.2653, -0.1584])

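Normally, when the two input tensors have the same shape, they are simply compared element by element.

For two tensors with different shapes, the official text above gives no worked example: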

Here is an example, with aaa of shape 2 × 2 and bbb of shape 3 × 2:

aaa = torch.randn(2,2)
bbb = torch.randn(3,2)
ccc = torch.max(aaa,bbb)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

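This raises the error above. Let's analyze it first:

A 2 × 2 tensor and a 3 × 2 tensor cannot be compared element by element directly. What we actually want is every row of one compared against every row of the other, so the output should have shape 2 × 3 × 2. Let's test a bit further: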

aaa = torch.randn(1,2)
bbb = torch.randn(3,2)
ccc = torch.max(aaa,bbb)
tensor([[1.0350, 0.2532],
        [0.2203, 0.2532],
        [0.2912, 0.2532]])

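This produces output directly, with no error.

So PyTorch's rule is this: the two shapes are aligned from the trailing dimension, and each pair of sizes must either be equal or one of them must be 1. A dimension of size 1 is stretched to match the other tensor; any other mismatch is an error.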
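The shape check behind broadcasting can be sketched in plain Python. This is only an illustration of the rule, not PyTorch's actual implementation; `broadcast_shape` is a hypothetical helper name, and the sketch ignores edge cases such as zero-sized dimensions.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, or raise ValueError if incompatible."""
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            result.append(x)   # sizes agree, or b is stretched to x
        elif x == 1:
            result.append(y)   # a is stretched to y
        else:
            raise ValueError(f"size {x} vs {y} at non-singleton dimension {dim}")
    return tuple(result)


print(broadcast_shape((1, 2), (3, 2)))        # (3, 2)
print(broadcast_shape((2, 1, 2), (1, 3, 2)))  # (2, 3, 2)
try:
    broadcast_shape((2, 2), (3, 2))
except ValueError as err:
    print("incompatible:", err)
```

The last call fails for the same reason as the 2 × 2 versus 3 × 2 example: sizes 2 and 3 differ and neither is 1.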

So we can test one step further and turn the 2 × 2 and 3 × 2 inputs into 2 × 1 × 2 and 1 × 3 × 2:

aaa = torch.randn(2,2).unsqueeze(1)
bbb = torch.randn(3,2).unsqueeze(0)
ccc = torch.max(aaa,bbb)
ccc.shape
torch.Size([2, 3, 2])

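And with that, the problem is solved! When I have time I will read the source to see how this is implemented, and why it is not a bit smarter about it...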
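As a sanity check, the broadcast result can be compared against an explicit double loop over the rows. This is a minimal sketch; the names `expected` and `pairwise_ok` and the fixed seed are only there for illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed so the run is repeatable
aaa = torch.randn(2, 2)
bbb = torch.randn(3, 2)

# (2, 1, 2) against (1, 3, 2) broadcasts to (2, 3, 2)
ccc = torch.max(aaa.unsqueeze(1), bbb.unsqueeze(0))

# Reference: compare row i of aaa with row j of bbb, one pair at a time
expected = torch.stack([
    torch.stack([torch.max(aaa[i], bbb[j]) for j in range(bbb.size(0))])
    for i in range(aaa.size(0))
])

pairwise_ok = torch.equal(ccc, expected)
print(ccc.shape, pairwise_ok)
```

Here ccc[i, j] is the element-wise maximum of row i of aaa and row j of bbb, which is exactly the "every combination" behaviour that find_intersection relies on.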

