pytorch faster_rcnn


代碼地址:https://github.com/jwyang/faster-rcnn.pytorch

 

1.fasterRCNN.train():這個不是讓網絡進行訓練,而是讓module in training mode,有些module在traing model和testing model下不同,比如bn

即self.training這個成員變量為true(這個成員變量屬於nn.Module,fasterRCNN繼承了這個成員變量),以下是train成員函數的源碼

2.bn的train和test不同,train的時候應該是要學習參數的,test的時候關閉,pytorch的用法如下:

pytorch的batchnorm使用時需要小心,training和track_running_stats可以組合出三種behavior,很容易掉坑里(我剛發現我對track_running_stats的理解錯了)。

  1. training=True, track_running_stats=True, 這是常用的training時期待的行為,running_mean 和running_var會跟蹤不同batch數據的mean和variance。
  2. training=True, track_running_stats=False, 這時候batchnorm不跟蹤跨batch數據的statistics了,而是用每個batch的mean和variance做normalization。
  3. training=False, track_running_stats=True, 這是我們期待的test時候的行為,即使用training階段估計的running_mean 和running_var.
  4. training=False, track_running_stats=False,同2(!!!).
https://www.zhihu.com/question/282672547/answer/529154567李韶華的回答
3.class_agnostic == true就是所有類別回歸同一個坐標,也就是一個框回歸一個坐標
        == false是每個類別單獨回歸4個坐標
    if self.class_agnostic:
      self.RCNN_bbox_pred = nn.Linear(4096, 4)
    else:
      self.RCNN_bbox_pred = nn.Linear(4096, 4 * self.n_classes)
4.真正開始訓練的代碼不是fasterRCNN.train(),而是下面這段代碼:
      rois, cls_prob, bbox_pred, \
      rpn_loss_cls, rpn_loss_box, \
      RCNN_loss_cls, RCNN_loss_bbox, \
      rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)

fasterRCNN是一個實例,應該是沒辦法進行調用的,但實際上這段代碼執行的是forward函數。為什么?其實就是python的括號重載。fasterRCNN這個實例繼承於nn.Module類,這個類定義了forward成員函數,nn.Module類使用了__call__進行了重載,讓實例能夠調用,並且調用的函數是forward函數,具體代碼見下面的源碼:

python中__call__函數的作用是使實例能夠像函數一樣被調用https://blog.csdn.net/Yaokai_AssultMaster/article/details/70256621,也稱之為括號重載,即‘()’

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}'"
                    "didn't return None".format(hook))
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

 

nn.Module定義了一個forward的成員函數,這個函數在基類中沒有實現,而是在各個子類自己實現的,每個子類都必須實現forward函數:

    def forward(self, *input):
        r"""Defines the computation performed at every call.
        Should be overridden by all subclasses.
        .. note::
            Although the recipe for forward pass needs to be defined within
            this function, one should call the :class:`Module` instance afterwards
            instead of this since the former takes care of running the
            registered hooks while the latter silently ignores them.
        """
        raise NotImplementedError

 

子類調用forward函數不能直接用calss.forward(),而是用實例的函數調用,具體的原因好像是hook,這個在上面__call__函數中也看到調用forward使用了跟hook有關的input

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM