Code source:
https://github.com/jwyang/faster-rcnn.pytorch
1. EasyDict and yaml
The author uses EasyDict to hold the configuration parameters:
```python
# `pip install easydict` if you don't have it
from easydict import EasyDict as edict

__C = edict()
# Consumers can get config by:
#   from fast_rcnn_config import cfg
cfg = __C
```
EasyDict ("easy dictionary") simply lets you access dictionary entries as attributes, i.e. myDict.key instead of myDict['key']:
```python
from easydict import EasyDict as edict

myedict = edict()
cfg = myedict
myedict.a = 100
myedict.TRAIN = edict()
myedict.TRAIN.LEARNING_RATE = 0.001
myedict.TRAIN.MOMENTUM = 0.9
print(cfg.a)
print(cfg.TRAIN.LEARNING_RATE)
```
Note that cfg = myedict assigns a reference: cfg and myedict are the same object, so any change made through myedict is also visible through cfg.
If every key of an existing dictionary is a string, it can be converted to an EasyDict directly: myEasyDict = EasyDict(myDict).
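A minimal sketch illustrating both points (the variable names myedict and myDict are only for illustration):

```python
from easydict import EasyDict as edict

myedict = edict()
cfg = myedict          # cfg references the same object, not a copy
myedict.a = 100
print(cfg.a)           # 100 -- the change made through myedict is visible via cfg

# A plain dict whose keys are all strings converts directly;
# nested dicts become nested EasyDicts as well.
myDict = {"TRAIN": {"LEARNING_RATE": 0.001, "MOMENTUM": 0.9}}
myEasyDict = edict(myDict)
print(myEasyDict.TRAIN.LEARNING_RATE)  # 0.001
```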
The author also stores the preset parameters in a yaml file:
```yaml
EXP_DIR: vgg16
TRAIN:
  HAS_RPN: True
  BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
  RPN_POSITIVE_OVERLAP: 0.7
  RPN_BATCHSIZE: 256
  PROPOSAL_METHOD: gt
  BG_THRESH_LO: 0.0
  BATCH_SIZE: 256
  LEARNING_RATE: 0.01
TEST:
  HAS_RPN: True
POOLING_MODE: align
CROP_RESIZE_WITH_MAX_POOL: False
```
The yaml format has the following requirements:
- case-sensitive;
- indentation expresses hierarchy;
- indent with spaces, not with the Tab key;
- the number of spaces per indent is not fixed; elements at the same level only need to be left-aligned;
- strings do not need to be quoted, unless they contain special characters;
- comments start with #.
A yaml file can be loaded like this:
```python
import yaml
from easydict import EasyDict as edict

with open("vgg16.yml", 'r') as f:
    yaml_cfg1 = yaml.load(f, Loader=yaml.FullLoader)
with open("vgg16.yml", 'r') as f:
    yaml_cfg2 = edict(yaml.load(f, Loader=yaml.FullLoader))

print(yaml_cfg1)  # dict
print(yaml_cfg2)  # EasyDict
print(yaml_cfg1["TRAIN"]["RPN_POSITIVE_OVERLAP"])
print(yaml_cfg2.TRAIN.RPN_POSITIVE_OVERLAP)
```
2. Input processing
First, four tensors are initialized: im_data, im_info, num_boxes, and gt_boxes.
```python
# initilize the tensor holder here.
im_data = torch.FloatTensor(1)
im_info = torch.FloatTensor(1)
num_boxes = torch.LongTensor(1)
gt_boxes = torch.FloatTensor(1)

# ship to cuda
if args.cuda > 0:
    im_data = im_data.cuda()
    im_info = im_info.cuda()
    num_boxes = num_boxes.cuda()
    gt_boxes = gt_boxes.cuda()

# make variable
im_data = Variable(im_data, volatile=True)
im_info = Variable(im_info, volatile=True)
num_boxes = Variable(num_boxes, volatile=True)
gt_boxes = Variable(gt_boxes, volatile=True)
```
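Variable and volatile=True come from older PyTorch (pre-0.4) and mark these holders as inference-only. In current PyTorch this pattern is deprecated; a rough modern equivalent would be a sketch like the following (args.cuda is assumed to be the same flag as in the demo script):

```python
import torch

device = torch.device("cuda" if args.cuda > 0 else "cpu")  # args.cuda as in the demo script

# Plain tensors serve as reusable holders; dtype/device replace FloatTensor/.cuda()
im_data = torch.empty(1, device=device)
im_info = torch.empty(1, device=device)
num_boxes = torch.empty(1, dtype=torch.long, device=device)
gt_boxes = torch.empty(1, device=device)

# Instead of volatile=True, wrap the inference forward pass in no_grad():
# with torch.no_grad():
#     rois, cls_prob, bbox_pred, *_ = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
```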
All images in the input directory are collected into imglist:
```python
# Set up webcam or get image directories
if webcam_num >= 0:
    cap = cv2.VideoCapture(webcam_num)
    num_images = 0
else:
    imglist = os.listdir(args.image_dir)
    num_images = len(imglist)
```
Then the main loop starts; it iterates once per input image:
```python
while (num_images >= 0):
    total_tic = time.time()
    if webcam_num == -1:
        num_images -= 1
```
One image is read in; if it is single-channel, it is replicated three times to make three channels, and then converted from RGB to BGR:
```python
    else:
        im_file = os.path.join(args.image_dir, imglist[num_images])
        # im = cv2.imread(im_file)
        im_in = np.array(imread(im_file))
    if len(im_in.shape) == 2:
        im_in = im_in[:,:,np.newaxis]
        im_in = np.concatenate((im_in,im_in,im_in), axis=2)
    # rgb -> bgr
    im = im_in[:,:,::-1]
```
Next comes the call blobs, im_scales = _get_image_blob(im).
This function first subtracts the pixel means, then rescales the image (keeping its aspect ratio) so that the shorter side equals the target size (600 by default, from cfg.TEST.SCALES), and returns the rescaled image together with the scale factor.
```python
def _get_image_blob(im):
    """Converts an image into a network input.
    Arguments:
        im (ndarray): a color image in BGR order
    Returns:
        blob (ndarray): a data blob holding an image pyramid
        im_scale_factors (list): list of image scales (relative to im) used
            in the image pyramid
    """
    im_orig = im.astype(np.float32, copy=True)
    im_orig -= cfg.PIXEL_MEANS

    im_shape = im_orig.shape
    im_size_min = np.min(im_shape[0:2])
    im_size_max = np.max(im_shape[0:2])

    processed_ims = []
    im_scale_factors = []

    for target_size in cfg.TEST.SCALES:
        im_scale = float(target_size) / float(im_size_min)
        # Prevent the biggest axis from being more than MAX_SIZE
        if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
            im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
        im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
                        interpolation=cv2.INTER_LINEAR)
        im_scale_factors.append(im_scale)
        processed_ims.append(im)

    # Create a blob to hold the input images
    blob = im_list_to_blob(processed_ims)

    return blob, np.array(im_scale_factors)
```
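im_list_to_blob is a helper from the repo that is not shown above. Roughly, it pads all images in the list to a common size and stacks them into one NHWC array; the following is a sketch of that behavior in the py-faster-rcnn style, not code copied from the source:

```python
import numpy as np

def im_list_to_blob(ims):
    """Sketch: pad a list of HxWx3 float32 images to the largest H and W
    (zero padding on the bottom/right) and stack them into an (N, H, W, 3) blob."""
    max_shape = np.array([im.shape for im in ims]).max(axis=0)
    blob = np.zeros((len(ims), max_shape[0], max_shape[1], 3), dtype=np.float32)
    for i, im in enumerate(ims):
        blob[i, :im.shape[0], :im.shape[1], :] = im
    return blob
```

For the demo there is only one image and one scale, so the blob simply has shape (1, H, W, 3).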
im_info_np stores the rescaled image's height, width, and the scale factor (im_blob has NHWC layout, so shape[1] is the height and shape[2] the width):
```python
im_info_np = np.array([[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], dtype=np.float32)
```
The data layout is converted from NHWC to NCHW:
```python
im_data_pt = im_data_pt.permute(0, 3, 1, 2)
```
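im_data_pt is the blob converted from numpy to a torch tensor; afterwards the blob and im_info_np are copied into the pre-allocated holders so they can be reused for every image. A sketch of these surrounding steps, reconstructed from context rather than quoted verbatim from demo.py:

```python
im_data_pt = torch.from_numpy(im_blob)        # (1, H, W, 3) float32 blob
im_data_pt = im_data_pt.permute(0, 3, 1, 2)   # -> (1, 3, H, W), NCHW
im_info_pt = torch.from_numpy(im_info_np)     # [[H, W, scale]]

# Resize the pre-allocated holders in place and copy the new data in.
im_data.data.resize_(im_data_pt.size()).copy_(im_data_pt)
im_info.data.resize_(im_info_pt.size()).copy_(im_info_pt)
gt_boxes.data.resize_(1, 1, 5).zero_()   # no ground-truth boxes at test time
num_boxes.data.resize_(1).zero_()
```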
Then the forward pass of the network is called:
```python
rois, cls_prob, bbox_pred, \
rpn_loss_cls, rpn_loss_box, \
RCNN_loss_cls, RCNN_loss_bbox, \
rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
```
RCNN_base here is the VGG16 backbone with its last pooling layer removed; its output feature map is 1/16 the size of the input image and has 512 channels:
```python
base_feat = self.RCNN_base(im_data)
```
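A sketch of how such a base can be built from torchvision's VGG16 by dropping the final max-pool (weight loading omitted; the shapes in the comments are assumptions you can verify by running it):

```python
import torch
import torch.nn as nn
import torchvision.models as models

vgg = models.vgg16()                              # features: 13 convs + 5 max-pools
# Keep everything except the last max-pool, so the total stride is 16 instead of 32.
RCNN_base = nn.Sequential(*list(vgg.features.children())[:-1])

x = torch.randn(1, 3, 600, 800)                   # an NCHW input after rescaling
base_feat = RCNN_base(x)
print(base_feat.shape)                            # torch.Size([1, 512, 37, 50]), i.e. ~1/16 of 600x800
```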
The RPN module:
```python
rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes, num_boxes)
```
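Internally the RPN is a small convolutional head on top of base_feat: a 3x3 conv followed by two 1x1 convs that predict, for every anchor at every spatial location, an objectness score and four box-regression offsets. A minimal structural sketch, not the repo's actual class (which additionally contains the proposal layer and anchor-target layer):

```python
import torch
import torch.nn as nn

class RPNHeadSketch(nn.Module):
    def __init__(self, in_channels=512, num_anchors=9):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.cls_score = nn.Conv2d(512, num_anchors * 2, kernel_size=1)  # bg/fg score per anchor
        self.bbox_pred = nn.Conv2d(512, num_anchors * 4, kernel_size=1)  # dx, dy, dw, dh per anchor

    def forward(self, base_feat):
        x = torch.relu(self.conv(base_feat))
        return self.cls_score(x), self.bbox_pred(x)

# base_feat from VGG16 is (1, 512, H/16, W/16); scores and deltas keep that spatial size.
scores, deltas = RPNHeadSketch()(torch.randn(1, 512, 37, 50))
print(scores.shape, deltas.shape)  # torch.Size([1, 18, 37, 50]) torch.Size([1, 36, 37, 50])
```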