[torch] torch.contiguous


torch.contiguous

作用

連續存儲,因為view的操作要求的是連續的內容。

詳細

考慮下面的操作,transpose操作只是改變了stride,而實際數組存儲的內容並沒有得到任何改變,即t是連續存儲的 0 1 2 3 4 5 6 7 8 9 10 11 ,t2的實際內容也是一致的,但是其索引的stride改變了,按照該索引去找地址則內存是不連續的。由於pytorch的底層實現是C,也就是行優先存儲.由最后輸出的faltten后的結果可以看出存儲的內容確實改變了,由此完全弄懂了為什么有的時候要contiguous。

>>>t = torch.arange(12).reshape(3,4)
>>>t
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>>t.stride()
(4, 1)
>>>t2 = t.transpose(0,1)
>>>t2
tensor([[ 0,  4,  8],
        [ 1,  5,  9],
        [ 2,  6, 10],
        [ 3,  7, 11]])
>>>t2.stride()
(1, 4)
>>>t.data_ptr() == t2.data_ptr() # 底層數據是同一個一維數組
True
>>>t.is_contiguous(),t2.is_contiguous() # t連續,t2不連續
(True, False)
>>>print(t1.flatten())
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
>>>t2 = t2.contiguous()
>>>print(t2.flatten())
tensor([ 0,  4,  8,  1,  5,  9,  2,  6, 10,  3,  7, 11])

應用

shuffleNet里打亂channel的操作

def shuffle_channels(x,groups):
    batch_size,channels,height,width = x.size()
    assert channels % groups == 0
    channels_per_group = channels // groups
    x = x.view(batch_size,groups,channels_per_group,height,width)
    x = x.transpose(1,2).contiguous()
    x = x.view(batch_size,channels,height,width)
    return x


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM