Faster RCNN PyTorch Code Walkthrough


Code source:

https://github.com/jwyang/faster-rcnn.pytorch

 

1. EasyDict and YAML

The author uses EasyDict to hold the configuration parameters:

# `pip install easydict` if you don't have it
from easydict import EasyDict as edict

__C = edict()
# Consumers can get config by:
#   from fast_rcnn_config import cfg
cfg = __C

  

EasyDict, as the name suggests, is an "easy dictionary": it lets you access values as attributes, so myDict['key'] can be written as myDict.key.

from easydict import EasyDict as edict

myedict = edict()

cfg = myedict


myedict.a = 100

myedict.TRAIN = edict()
myedict.TRAIN.LEARNING_RATE = 0.001
myedict.TRAIN.MOMENTUM = 0.9


print(cfg.a)
print(cfg.TRAIN.LEARNING_RATE)

  

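Note that cfg = myedict is a plain assignment, which binds a second name to the same object rather than copying it. cfg and myedict are therefore the same object, and any change made through myedict is visible through cfg as well.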

 

If all keys of a dictionary are strings, it can be converted directly to an EasyDict: myEasyDict = EasyDict(myDict)
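To make the attribute access and the shared-reference behaviour concrete, here is a minimal sketch using only the standard library. The class AttrDict is a hypothetical stand-in written for illustration; it is not EasyDict's actual implementation, and it only covers nested plain dictionaries.

```python
class AttrDict(dict):
    """Hypothetical minimal stand-in for EasyDict (not its real implementation)."""

    def __init__(self, mapping=None):
        super().__init__()
        for key, value in (mapping or {}).items():
            self[key] = value

    def __setitem__(self, key, value):
        # Wrap nested plain dicts so that chained attribute access works.
        if isinstance(value, dict) and not isinstance(value, AttrDict):
            value = AttrDict(value)
        super().__setitem__(key, value)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value


my_dict = {"a": 100, "TRAIN": {"LEARNING_RATE": 0.001, "MOMENTUM": 0.9}}
my_attr_dict = AttrDict(my_dict)
cfg = my_attr_dict  # plain assignment: both names refer to the same object

my_attr_dict.TRAIN.MOMENTUM = 0.8
print(cfg.TRAIN.MOMENTUM)   # 0.8
print(cfg is my_attr_dict)  # True
```

Because cfg and my_attr_dict are the same object, the update made through one name is immediately visible through the other.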

 

The author also stores the preset parameters in a YAML file:

EXP_DIR: vgg16
TRAIN:
  HAS_RPN: True
  BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
  RPN_POSITIVE_OVERLAP: 0.7
  RPN_BATCHSIZE: 256
  PROPOSAL_METHOD: gt
  BG_THRESH_LO: 0.0
  BATCH_SIZE: 256
  LEARNING_RATE: 0.01
TEST:
  HAS_RPN: True
POOLING_MODE: align
CROP_RESIZE_WITH_MAX_POOL: False

  

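YAML files must follow these rules:

  • Keys are case-sensitive;
  • Indentation expresses the hierarchy;
  • Indent with spaces, never with tabs;
  • The number of spaces is not fixed; elements at the same level only need to be left-aligned;
  • Strings do not need quotes, unless they contain special characters;
  • Comments start with #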

 

A YAML file can be loaded like this:

import yaml
from easydict import EasyDict as edict


with open("vgg16.yml", 'r') as f:
    yaml_cfg1 = yaml.load(f, Loader=yaml.FullLoader)

with open("vgg16.yml", 'r') as f:
    yaml_cfg2 = edict(yaml.load(f, Loader=yaml.FullLoader))


print(yaml_cfg1) # plain dict
print(yaml_cfg2) # EasyDict

print(yaml_cfg1["TRAIN"]["RPN_POSITIVE_OVERLAP"])
print(yaml_cfg2.TRAIN.RPN_POSITIVE_OVERLAP)
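Once a YAML file has been loaded, its values usually have to override a set of defaults. The repository's own merge routine is not shown in this article, so the following is only a sketch of one common approach, using plain dictionaries. The function name merge_config and the default values are assumptions made for illustration; only the override values are taken from the vgg16.yml listing above.

```python
def merge_config(overrides, defaults):
    """Recursively copy values from `overrides` into `defaults` (in place)."""
    for key, value in overrides.items():
        if key not in defaults:
            # Reject unknown keys so that typos in the YAML file are caught early.
            raise KeyError(f"{key} is not a known config key")
        if isinstance(value, dict) and isinstance(defaults[key], dict):
            merge_config(value, defaults[key])
        else:
            defaults[key] = value
    return defaults


# Illustrative defaults (assumed values, not the repository's).
defaults = {
    "EXP_DIR": "default",
    "TRAIN": {"LEARNING_RATE": 0.001, "MOMENTUM": 0.9},
    "POOLING_MODE": "crop",
}
# A subset of the values from the vgg16.yml listing.
overrides = {
    "EXP_DIR": "vgg16",
    "TRAIN": {"LEARNING_RATE": 0.01},
    "POOLING_MODE": "align",
}

merged = merge_config(overrides, defaults)
print(merged["TRAIN"])  # {'LEARNING_RATE': 0.01, 'MOMENTUM': 0.9}
```

Keys missing from the YAML file keep their default values, which is why the file only needs to list the settings that differ.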

 

 

2. Input processing

Initialize four tensors: im_data, im_info, num_boxes, and gt_boxes

  # initialize the tensor holder here.
  im_data = torch.FloatTensor(1)
  im_info = torch.FloatTensor(1)
  num_boxes = torch.LongTensor(1)
  gt_boxes = torch.FloatTensor(1)

  # ship to cuda
  if args.cuda > 0:
    im_data = im_data.cuda()
    im_info = im_info.cuda()
    num_boxes = num_boxes.cuda()
    gt_boxes = gt_boxes.cuda()

  # make variable
  im_data = Variable(im_data, volatile=True)
  im_info = Variable(im_info, volatile=True)
  num_boxes = Variable(num_boxes, volatile=True)
  gt_boxes = Variable(gt_boxes, volatile=True)
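The listing above targets an old PyTorch release. Since PyTorch 0.4, Variable has been merged into Tensor and the volatile flag no longer has any effect; inference code disables gradient tracking with torch.no_grad() instead. The following is a rough modern equivalent, not the repository's code. The device selection and the weight tensor are assumptions added for illustration, and torch.zeros is used in place of the uninitialized torch.FloatTensor(1).

```python
import torch

# Assumed device selection; the original script uses args.cuda instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder tensors, created directly on the target device.
im_data = torch.zeros(1, dtype=torch.float32, device=device)
im_info = torch.zeros(1, dtype=torch.float32, device=device)
num_boxes = torch.zeros(1, dtype=torch.int64, device=device)
gt_boxes = torch.zeros(1, dtype=torch.float32, device=device)

# Hypothetical stand-in for a model parameter that would normally track gradients.
weight = torch.ones(1, device=device, requires_grad=True)

# Inside no_grad, results do not record a computation graph.
with torch.no_grad():
    out = im_data * weight

print(out.requires_grad)  # False
```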

  

Collect the file names of all images in the folder into imglist (or open the webcam if one was requested)

# Set up webcam or get image directories
  if webcam_num >= 0 :
    cap = cv2.VideoCapture(webcam_num)
    num_images = 0
  else:
    imglist = os.listdir(args.image_dir)
    num_images = len(imglist)

  

Enter the loop, which runs once per input image

while (num_images >= 0):
      total_tic = time.time()
      if webcam_num == -1:
        num_images -= 1

  

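Read one image. If it is single-channel (grayscale), replicate it three times to get three channels, then convert it from RGB to BGR order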

else:
        im_file = os.path.join(args.image_dir, imglist[num_images])
        # im = cv2.imread(im_file)
        im_in = np.array(imread(im_file))
      if len(im_in.shape) == 2:
        im_in = im_in[:,:,np.newaxis]
        im_in = np.concatenate((im_in,im_in,im_in), axis=2)
      # rgb -> bgr
      im = im_in[:,:,::-1]
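The two array operations in this step can be checked on tiny synthetic arrays. This is a small sketch with made-up pixel values; it only assumes NumPy.

```python
import numpy as np

# A 2 x 3 grayscale image (H x W), made-up values.
gray = np.arange(6, dtype=np.uint8).reshape(2, 3)

# Add a channel axis, then repeat it three times: H x W -> H x W x 3.
im_in = gray[:, :, np.newaxis]
im_in = np.concatenate((im_in, im_in, im_in), axis=2)
print(im_in.shape)  # (2, 3, 3)

# An RGB image that is pure red: only channel 0 is set.
rgb = np.zeros((2, 3, 3), dtype=np.uint8)
rgb[:, :, 0] = 255

# Reversing the last axis swaps the R and B channels (RGB -> BGR).
bgr = rgb[:, :, ::-1]
print(bgr[0, 0])  # [  0   0 255]
```

Note that the reversed slice is a view on the original array, not a copy, so no pixel data is moved at this point.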

  

Call blobs, im_scales = _get_image_blob(im)

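The function first subtracts the pixel means, then resizes the image with its aspect ratio preserved so that the shorter side becomes 600 (the value in cfg.TEST.SCALES). If that would make the longer side exceed cfg.TEST.MAX_SIZE, the scale is reduced so that the longer side equals cfg.TEST.MAX_SIZE instead. It returns the resized image and the scale factor.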

def _get_image_blob(im):
  """Converts an image into a network input.
  Arguments:
    im (ndarray): a color image in BGR order
  Returns:
    blob (ndarray): a data blob holding an image pyramid
    im_scale_factors (list): list of image scales (relative to im) used
      in the image pyramid
  """
  im_orig = im.astype(np.float32, copy=True)
  im_orig -= cfg.PIXEL_MEANS

  im_shape = im_orig.shape
  im_size_min = np.min(im_shape[0:2])
  im_size_max = np.max(im_shape[0:2])

  processed_ims = []
  im_scale_factors = []

  for target_size in cfg.TEST.SCALES:
    im_scale = float(target_size) / float(im_size_min)
    # Prevent the biggest axis from being more than MAX_SIZE
    if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
      im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
    im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
            interpolation=cv2.INTER_LINEAR)
    im_scale_factors.append(im_scale)
    processed_ims.append(im)

  # Create a blob to hold the input images
  blob = im_list_to_blob(processed_ims)

  return blob, np.array(im_scale_factors)
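The scale selection inside the loop can be isolated into a small pure-Python function and tried on a few image sizes. This is a sketch of the same arithmetic, not the repository's code: the helper name compute_im_scale is made up, target_size=600 follows the text above, and max_size=1000 is an assumed value for cfg.TEST.MAX_SIZE.

```python
def compute_im_scale(height, width, target_size=600, max_size=1000):
    """Return the resize factor for an image of the given size."""
    im_size_min = min(height, width)
    im_size_max = max(height, width)

    # Scale so that the shorter side becomes target_size.
    im_scale = float(target_size) / float(im_size_min)

    # If the longer side would exceed max_size, scale by the longer side instead.
    if round(im_scale * im_size_max) > max_size:
        im_scale = float(max_size) / float(im_size_max)
    return im_scale


print(compute_im_scale(375, 500))   # 1.6  (shorter side 375 -> 600, longer side 800)
print(compute_im_scale(500, 2000))  # 0.5  (longer side capped: 2000 -> 1000)
```

In the second case the shorter side ends up at 250 rather than 600, because the cap on the longer side takes priority.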

  

Store the resized image's height and width, together with the scale factor, in im_info_np

im_info_np = np.array([[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], dtype=np.float32)

  

Convert the layout from NHWC to NCHW

im_data_pt = im_data_pt.permute(0, 3, 1, 2)
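The shapes involved in the last two steps can be illustrated with NumPy alone, where np.transpose plays the role of the tensor's permute. The blob size and scale below are made-up example values.

```python
import numpy as np

# A dummy blob in NHWC layout: 1 image, 600 x 800 pixels, 3 channels.
im_blob = np.zeros((1, 600, 800, 3), dtype=np.float32)
im_scale = 1.6  # made-up scale factor for illustration

# One row per image: [height, width, scale].
im_info_np = np.array([[im_blob.shape[1], im_blob.shape[2], im_scale]],
                      dtype=np.float32)
print(im_info_np.shape)  # (1, 3)

# Reorder the axes N, H, W, C -> N, C, H, W.
im_data_nchw = np.transpose(im_blob, (0, 3, 1, 2))
print(im_data_nchw.shape)  # (1, 3, 600, 800)
```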

  

Call the forward pass

rois, cls_prob, bbox_pred, \
      rpn_loss_cls, rpn_loss_box, \
      RCNN_loss_cls, RCNN_loss_bbox, \
      rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)

  

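Here RCNN_base is the convolutional part of VGG16 with its last pooling layer removed. Its output has 1/16 the spatial size of the input image and 512 channels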

base_feat = self.RCNN_base(im_data)
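The factor of 16 follows from the pooling layers: VGG16 has five 2x2 max-pooling layers with stride 2, so four remain once the last one is dropped, and the 3x3 convolutions with padding 1 keep the spatial size unchanged. A minimal sketch of the resulting size calculation, assuming the pooling rounds down; the helper name is made up for illustration:

```python
def vgg16_base_output_size(height, width, num_pools=4):
    """Spatial size after `num_pools` stride-2 poolings (rounding down)."""
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return height, width


print(vgg16_base_output_size(600, 800))  # (37, 50)
```

For a 600 x 800 input, base_feat would therefore have 512 channels and a 37 x 50 spatial grid.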

  

RPN module

rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes, num_boxes)

  

 

