學習Faster R-CNN代碼demo(一)


注釋Yang Jianwei 的Faster R-CNN代碼(PyTorch)

jwyang’s github: https://github.com/jwyang/faster-rcnn.pytorch

文件demo.py 

這個文件是自己下載好訓練好的模型后可執行

下面是對代碼的詳細注釋(直接在代碼上注釋):

1.有關導入的庫

 1 # --------------------------------------------------------
 2 # Tensorflow Faster R-CNN
 3 # Licensed under The MIT License [see LICENSE for details]
 4 # Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick
 5 # --------------------------------------------------------
 6 
 7 #Python提供了__future__模塊,把下一個新版本的特性導入到當前版本
 8 from __future__ import absolute_import#加入絕對引入這個新特性  引入系統的標准
 9 
10 #導入python未來支持的語言特征division(精確除法),當我們沒有在程序中導入該特征時,"/"操作符執行的是截斷除法(Truncating Division),
11 #當我們導入精確除法之后,"/"執行的是精確除法
12 from __future__ import division
13 
14 #即使在python2.X,使用print就得像python3.X那樣加括號使用
15 from __future__ import print_function
16 
17 #_init_paths是指lib/model/_init_paths.py ?
18 import _init_paths 
19 import os #通過os模塊調用系統命令
20 import sys #sys 模塊包括了一組非常實用的服務,內含很多函數方法和變量
21 #numpy用來處理圖片數據(多維數組), 尤其是numpy的broadcasting特性, 使得不同維度的數組可以一起操作(加,減,乘, 除, 等).
22 import numpy as np
23 import argparse #為py文件封裝好可以選擇的參數
24 import pprint #提供了打印出任何python數據結構類和方法。
25 import pdb #使用 Pdb調試 Python程序
26 import time
27 import cv2
28 import torch
29 #介紹autograde  https://www.jianshu.com/p/cbce2dd60120
30 from torch.autograd import Variable#自動微分 vairable是tensor的一個外包裝
31 import torch.nn as nn
32 import torch.optim as optim
33 
34 #為了方便加載以上五種數據庫的數據,pytorch團隊幫我們寫了一個torchvision包。
35 #使用torchvision就可以輕松實現數據的加載和預處理。
36 import torchvision.transforms as transforms# transforms用於數據預處理
37 import torchvision.datasets as dset
38 
39 #scipy.misc 下的圖像處理
40 #imread():返回的是 numpy.ndarray 也即 numpy 下的多維數組對象;
41 from scipy.misc import imread
42 
43 from roi_data_layer.roidb import combined_roidb
44 from roi_data_layer.roibatchLoader import roibatchLoader
45 #demo.py運行過程中的配置基本上都在config.py了. 后續的代碼流程中會用到這些配置值. 
46 from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
47 from model.rpn.bbox_transform import clip_boxes
48 from model.nms.nms_wrapper import nms
49 from model.rpn.bbox_transform import bbox_transform_inv
50 from model.utils.net_utils import save_net, load_net, vis_detections
51 from model.utils.blob import im_list_to_blob
52 from model.faster_rcnn.vgg16 import vgg16
53 from model.faster_rcnn.resnet import resnet
54 import pdb
55 
56 try:
57     xrange          # Python 2
58 except NameError:
59     xrange = range  # Python 3

2.解析參數 parse_args()

 1 def parse_args():
 2   """
 3   Parse input arguments
 4   """
 5   parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
 6   parser.add_argument('--dataset', dest='dataset',#指代你跑得數據集名稱,例如pascal-voc
 7                       help='training dataset',
 8                       default='pascal_voc', type=str)
 9   parser.add_argument('--cfg', dest='cfg_file',#配置文件
10                       help='optional config file',
11                       default='cfgs/vgg16.yml', type=str)
12   parser.add_argument('--net', dest='net',#backbone網絡類型
13                       help='vgg16, res50, res101, res152',
14                       default='res101', type=str)
15   parser.add_argument('--set', dest='set_cfgs',#設置
16                       help='set config keys', default=None,
17                       nargs=argparse.REMAINDER)
18   parser.add_argument('--load_dir', dest='load_dir',#模型目錄
19                       help='directory to load models',
20                       default="/srv/share/jyang375/models")
21   parser.add_argument('--image_dir', dest='image_dir',#圖片目錄
22                       help='directory to load images for demo',
23                       default="images")
24   parser.add_argument('--cuda', dest='cuda',#是否用GPU
25                       help='whether use CUDA',
26                       action='store_true')
27   parser.add_argument('--mGPUs', dest='mGPUs',#是不是多GPU
28                       help='whether use multiple GPUs',
29                       action='store_true')
30     #class-agnostic 方式只回歸2類bounding box,即前景和背景,
31     #結合每個box在classification 網絡中對應着所有類別的得分,以及檢測閾值條件,就可以得到圖片中所有類別的檢測結果
32   parser.add_argument('--cag', dest='class_agnostic',#是否class_agnostic回歸
33                       help='whether perform class_agnostic bbox regression',
34                       action='store_true')
35   parser.add_argument('--parallel_type', dest='parallel_type',#模型的哪一部分並行
36                       help='which part of model to parallel, 0: all, 1: model before roi pooling',
37                       default=0, type=int)
38   parser.add_argument('--checksession', dest='checksession',
39                       help='checksession to load model',
40                       default=1, type=int)
41   parser.add_argument('--checkepoch', dest='checkepoch',
42                       help='checkepoch to load network',
43                       default=1, type=int)
44   #--checkpoint  a way to save the current state of your experiment so that you can pick up from where you left off.
45   parser.add_argument('--checkpoint', dest='checkpoint',#跟保存模型有關
46                       help='checkpoint to load network',
47                       default=10021, type=int)
48   parser.add_argument('--bs', dest='batch_size',#批大小
49                       help='batch_size',
50                       default=1, type=int)
51   parser.add_argument('--vis', dest='vis',
52                       help='visualization mode',#可視化模型
53                       action='store_true')
54   parser.add_argument('--webcam_num', dest='webcam_num',#好像就是網絡哦攝像機
55                       help='webcam ID number',
56                       default=-1, type=int)
57 
58   #parse_args()是將之前add_argument()定義的參數進行賦值,並返回相關的namespace。
59   args = parser.parse_args()
60   return args
61 
62 lr = cfg.TRAIN.LEARNING_RATE#學習率
63 momentum = cfg.TRAIN.MOMENTUM#動量
64 weight_decay = cfg.TRAIN.WEIGHT_DECAY#權重衰減

函數 _get_image_blob(im)

 1 def _get_image_blob(im):
 2 #這個函數其實就是讀取圖片,然后做尺寸變換,然后存儲成矩陣的形式
 3   """Converts an image into a network input.
 4   Arguments:
 5     im (ndarray): a color image in BGR order
 6   Returns:
 7     blob (ndarray): a data blob holding an image pyramid 
 8     im_scale_factors (list): list of image scales (relative to im) used
 9       in the image pyramid
10   """
11   #Numpy中 astype:轉換數組的數據類型。
12   im_orig = im.astype(np.float32, copy=True)
13   #而pixel mean的話,其實是把訓練集里面所有圖片的所有R通道像素,求了均值,G,B通道類似
14   im_orig -= cfg.PIXEL_MEANS
15 
16   im_shape = im_orig.shape
17   #所有元素中的min or max
18   im_size_min = np.min(im_shape[0:2])#后面有可能有其他維度,這里留兩維
19   im_size_max = np.max(im_shape[0:2])
20 
21   processed_ims = []
22   im_scale_factors = []
23 
24   for target_size in cfg.TEST.SCALES:#遍歷cfg.TEST.SCALES這個元組或列表中的值
25     im_scale = float(target_size) / float(im_size_min)#測試的尺度除以圖像最小長度(寬高的最小值)
26     # Prevent the biggest axis from being more than MAX_SIZE
27     #防止最大值超過MAX_SIZE,round函數四舍五入
28     if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
29       im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
30     #調整im_orig大小
31     im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
32             interpolation=cv2.INTER_LINEAR)
33     #保存尺度值
34     im_scale_factors.append(im_scale)
35     #保存調整后的圖像
36     processed_ims.append(im)
37 
38   # Create a blob to hold the input images
39   #創建一個blob來保存輸入圖像
40   #這個函數出自這里 from model.utils.blob import im_list_to_blob
41   blob = im_list_to_blob(processed_ims)#processed_ims是調整后的圖像值
42 
43   return blob, np.array(im_scale_factors)

4.主函數 if name == ‘main’:

  1 if __name__ == '__main__':
  2 
  3   args = parse_args()#這就是上面定義的那個函數
  4 
  5   print('Called with args:')
  6   print(args)
  7 
  8   if args.cfg_file is not None: #配置文件
  9     #model.utils.config 該文件中函數 """Load a config file and merge it into the default options."""
 10     cfg_from_file(args.cfg_file) #
 11   if args.set_cfgs is not None: #設置配置
 12     #model.utils.config文件中"""Set config keys via list (e.g., from command line)."""
 13     cfg_from_list(args.set_cfgs) #
 14     
 15   #Use GPU implementation of non-maximum suppression
 16   #解析參數是不是用GPU
 17   cfg.USE_GPU_NMS = args.cuda 
 18 
 19   print('Using config:')
 20   pprint.pprint(cfg)
 21   
 22   #設置隨機數種子
 23   #每次運行代碼時設置相同的seed,則每次生成的隨機數也相同,
 24   #如果不設置seed,則每次生成的隨機數都會不一樣
 25   np.random.seed(cfg.RNG_SEED)
 26 
 27   # train set
 28   # -- Note: Use validation set and disable the flipped to enable faster loading.
 29   
 30   #load_dir 模型目錄   args.net 網絡   args.dataset 數據集
 31   input_dir = args.load_dir + "/" + args.net + "/" + args.dataset
 32   if not os.path.exists(input_dir):
 33     #當程序出現錯誤,python會自動引發異常,也可以通過raise顯示地引發異常。一旦執行了raise語句,raise后面的語句將不能執行。
 34     raise Exception('There is no input directory for loading network from ' + input_dir)
 35   
 36   #這里的三個check參數,是定義了訓好的檢測模型名稱,例如訓好的名稱為faster_rcnn_1_20_10021,
 37   #代表了checksession = 1,checkepoch = 20, checkpoint = 10021,這樣才可以讀到模型“faster_rcnn_1_20_10021”
 38   load_name = os.path.join(input_dir,
 39     'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
 40 
 41   #PASCAL類別 1類背景 + 20類Object
 42   #array和asarray都可以將結構數據轉化為ndarray,但是主要區別就是當數據源是ndarray時,
 43   #array仍然會copy出一個副本,占用新的內存,但asarray不會。
 44   pascal_classes = np.asarray(['__background__',
 45                        'aeroplane', 'bicycle', 'bird', 'boat',
 46                        'bottle', 'bus', 'car', 'cat', 'chair',
 47                        'cow', 'diningtable', 'dog', 'horse',
 48                        'motorbike', 'person', 'pottedplant',
 49                        'sheep', 'sofa', 'train', 'tvmonitor'])
 50 
 51   # initilize the network here.
 52   #class-agnostic 方式只回歸2類bounding box,即前景和背景
 53   if args.net == 'vgg16':
 54     fasterRCNN = vgg16(pascal_classes, pretrained=False, class_agnostic=args.class_agnostic)
 55   elif args.net == 'res101':
 56     fasterRCNN = resnet(pascal_classes, 101, pretrained=False, class_agnostic=args.class_agnostic)
 57   elif args.net == 'res50':
 58     fasterRCNN = resnet(pascal_classes, 50, pretrained=False, class_agnostic=args.class_agnostic)
 59   elif args.net == 'res152':
 60     fasterRCNN = resnet(pascal_classes, 152, pretrained=False, class_agnostic=args.class_agnostic)
 61   else:
 62     print("network is not defined")
 63     #到了pdb.set_trace()那就會定下來,就可以看到調試的提示符(Pdb)了
 64     pdb.set_trace()
 65 
 66   fasterRCNN.create_architecture()#model.faster_rcnn.faster_rcnn.py 初始化模型 初始化權重
 67 
 68   print("load checkpoint %s" % (load_name))#模型路徑
 69   if args.cuda > 0:#GPU
 70     checkpoint = torch.load(load_name)
 71   else:#CPU?
 72     ################################################################
 73     #在cpu上加載預先訓練好的GPU模型,強制所有GPU張量在CPU中的方式:
 74     checkpoint = torch.load(load_name, map_location=(lambda storage, loc: storage))
 75   
 76   #the_model = TheModelClass(*args, **kwargs)
 77   #the_model.load_state_dict(torch.load(PATH))###恢復恢復
 78   fasterRCNN.load_state_dict(checkpoint['model'])#恢復模型
 79   if 'pooling_mode' in checkpoint.keys():
 80     cfg.POOLING_MODE = checkpoint['pooling_mode']#pooling方式
 81 
 82 
 83   print('load model successfully!')
 84 
 85   # pdb.set_trace()
 86 
 87   print("load checkpoint %s" % (load_name))
 88 
 89   # initilize the tensor holder here.
 90   #新建一些 一維Tensor
 91   im_data = torch.FloatTensor(1)
 92   im_info = torch.FloatTensor(1)
 93   num_boxes = torch.LongTensor(1)
 94   gt_boxes = torch.FloatTensor(1)
 95 
 96   # ship to cuda
 97   if args.cuda > 0:#如果用GPU,張量放到GPU上
 98     im_data = im_data.cuda()
 99     im_info = im_info.cuda()
100     num_boxes = num_boxes.cuda()
101     gt_boxes = gt_boxes.cuda()
102 
103   # make variable
104   #ariable的volatile屬性默認為False,如果某一個variable的volatile屬性被設為True,
105   #那么所有依賴它的節點volatile屬性都為True。
106   #volatile屬性為True的節點不會求導,volatile的優先級比requires_grad高。
107   im_data = Variable(im_data, volatile=True)
108   im_info = Variable(im_info, volatile=True)
109   num_boxes = Variable(num_boxes, volatile=True)
110   gt_boxes = Variable(gt_boxes, volatile=True)
111 
112   if args.cuda > 0:
113     cfg.CUDA = True
114 
115   if args.cuda > 0:
116     fasterRCNN.cuda()
117 
118   #model.eval(),讓model變成測試模式,
119   #對dropout和batch normalization的操作在訓練和測試的時候是不一樣的
120   #pytorch會自動把BN和DropOut固定住,不會取平均,而是用訓練好的值
121   fasterRCNN.eval()
122 
123   #通過time()函數可以獲取當前的時間
124   start = time.time()
125   max_per_image = 100
126   thresh = 0.05
127   vis = True
128 
129   webcam_num = args.webcam_num
130   # Set up webcam or get image directories
131   if webcam_num >= 0 :#應該就是判斷要不要自己用電腦錄視頻
132     #cap = cv2.VideoCapture(0) 打開筆記本的內置攝像頭。
133     #cap = cv2.VideoCapture('D:\output.avi') 打開視頻文件
134     cap = cv2.VideoCapture(webcam_num)
135     num_images = 0
136   else:#如果不用電腦錄視頻,那么就讀取image路徑下的圖片
137     #os.listdir() 方法用於返回指定的文件夾包含的文件或文件夾的名字的列表
138     #這個列表以字母順序
139     imglist = os.listdir(args.image_dir)
140     num_images = len(imglist)#有多少張圖片
141 
142   print('Loaded Photo: {} images.'.format(num_images))
143 
144 
145   while (num_images >= 0):
146       total_tic = time.time()#當前時間
147       if webcam_num == -1:#如果不用攝像頭
148         num_images -= 1
149 
150       # Get image from the webcam
151       #從電腦攝像頭讀取圖片
152       if webcam_num >= 0:
153         if not cap.isOpened():#攝像頭開啟失敗
154           raise RuntimeError("Webcam could not open. Please check connection.")
155         
156         #ret 為True 或者False,代表有沒有讀取到圖片
157         #frame表示截取到一幀的圖片
158         ret, frame = cap.read()
159         
160         #攝像頭截取到一幀的圖片 存儲為numpy數組
161         im_in = np.array(frame)
162       # Load the demo image
163       else:
164         #圖片路徑
165         im_file = os.path.join(args.image_dir, imglist[num_images])
166         # im = cv2.imread(im_file)
167         #讀取的的圖片 存儲為numpy數組
168         im_in = np.array(imread(im_file))
169       if len(im_in.shape) == 2:
170         #np.newaxis的作用就是在這一位置增加一個一維,
171         #這一位置指的是np.newaxis所在的位置,比較抽象,需要配合例子理解。
172         #####################example
173         #x1 = np.array([1, 2, 3, 4, 5])
174         # the shape of x1 is (5,)
175         #x1_new = x1[:, np.newaxis]
176         # now, the shape of x1_new is (5, 1)
177         # array([[1],
178         #        [2],
179         #        [3],
180         #        [4],
181         #        [5]])
182         #x1_new = x1[np.newaxis,:]
183         # now, the shape of x1_new is (1, 5)
184         # array([[1, 2, 3, 4, 5]])
185         #####################
186         im_in = im_in[:,:,np.newaxis]#變為二維?
187         
188         #數組拼接
189         #若axis=0,則要求除了a.shape[0]和b.shape[0]可以不等之外,其它維度必須相等
190         #若axis=0,則要求除了a.shape[0]和b.shape[0]可以不等之外,其它維度必須相等
191         #axis>=2 的情況以此類推,axis的值必須小於數組的維度
192         im_in = np.concatenate((im_in,im_in,im_in), axis=2)
193         
194       # rgb -> bgr
195       #line[:-1]其實就是去除了這行文本的最后一個字符(換行符)后剩下的部分。
196       #line[::-1]字符串反過來 line = "abcde" line[::-1] 結果為:'edcba'
197       im = im_in[:,:,::-1]#RGB->BGR
198 
199       blobs, im_scales = _get_image_blob(im)#圖片變換 該文件上面定義的函數,返回處理后的值 和尺度
200       assert len(im_scales) == 1, "Only single-image batch implemented"
201       im_blob = blobs#處理后的值
202       #圖像信息,長、寬、尺度
203       im_info_np = np.array([[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], dtype=np.float32)
204 
205       #從numpy變為Tensor
206       im_data_pt = torch.from_numpy(im_blob)
207       #permute 將tensor的維度換位。
208       #參數:參數是一系列的整數,代表原來張量的維度。比如三維就有0,1,2這些dimension。
209       #把索引為3的張量位置給提到前面了,例如128 128 3的圖片變為 3 128 128
210       im_data_pt = im_data_pt.permute(0, 3, 1, 2)
211       #圖像信息也變為tensor
212       im_info_pt = torch.from_numpy(im_info_np)
213 
214       #將tensor的大小調整為指定的大小。
215       #如果元素個數比當前的內存大小大,就將底層存儲大小調整為與新元素數目一致的大小。
216       im_data.data.resize_(im_data_pt.size()).copy_(im_data_pt)
217       im_info.data.resize_(im_info_pt.size()).copy_(im_info_pt)
218       gt_boxes.data.resize_(1, 1, 5).zero_()
219       num_boxes.data.resize_(1).zero_()
220 
221       # pdb.set_trace()
222       det_tic = time.time()#當前時間
223 
224       #參數帶入模型
225       #rois: 興趣區域,怎么表示???????????
226         # rois blob: holds R regions of interest, each is a 5-tuple
227         # (n, x1, y1, x2, y2) specifying an image batch index n and a
228         # rectangle (x1, y1, x2, y2)
229         # top[0].reshape(1, 5)
230       #cls_prob: softmax得到的概率值
231       #bbox_pred: 偏移
232       #rpn_loss_cls分類損失,計算softmax的損失,輸入labels和cls layer的18個輸出(中間reshape了一下),輸出損失函數的具體值
233       #rpn_loss_box 計算的框回歸損失函數具體的值
234       rois, cls_prob, bbox_pred, \
235       rpn_loss_cls, rpn_loss_box, \
236       RCNN_loss_cls, RCNN_loss_bbox, \
237       rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
238 
239       scores = cls_prob.data#分類概率值
240       ###################################################
241       #boxes包含框的坐標
242       #各維度表示什么??????????
243       boxes = rois.data[:, :, 1:5]#?????????????????????
244 
245       if cfg.TEST.BBOX_REG:#Train bounding-box regressors TRUE or FALSE
246           # Apply bounding-box regression deltas
247           box_deltas = bbox_pred.data#偏移值
248           if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
249           # Optionally normalize targets by a precomputed mean and stdev
250             if args.class_agnostic:
251                 if args.cuda > 0:
252                     #box_deltas.view改變維度
253                     box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
254                                + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
255                 else:
256                     box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS) \
257                                + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS)
258 
259                 box_deltas = box_deltas.view(1, -1, 4)
260             else:
261                 if args.cuda > 0:
262                     box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
263                                + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
264                 else:
265                     box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS) \
266                                + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS)
267                 box_deltas = box_deltas.view(1, -1, 4 * len(pascal_classes))
268 
269          #model.rpn.bbox_transform 根據anchor和偏移量計算proposals
270          #最后返回的是左上和右下頂點的坐標[x1,y1,x2,y2]。
271          pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
272          #model.rpn.bbox_transform 
273          #將改變坐標信息后超過圖像邊界的框的邊框裁剪一下,使之在圖像邊界之內
274           pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
275       else:
276           # Simply repeat the boxes, once for each class
277           #Numpy的 tile() 函數,就是將原矩陣橫向、縱向地復制,這里是橫向
278           pred_boxes = np.tile(boxes, (1, scores.shape[1]))
279 
280       pred_boxes /= im_scales[0]
281 
282       #squeeze 函數:從數組的形狀中刪除單維度條目,即把shape中為1的維度去掉
283       scores = scores.squeeze()
284       pred_boxes = pred_boxes.squeeze()
285       det_toc = time.time()#當前時間
286       detect_time = det_toc - det_tic#detect_time
287       misc_tic = time.time()
288       if vis:
289           im2show = np.copy(im)
290       for j in xrange(1, len(pascal_classes)):#所有類別
291           #torch.nonzero
292           #返回一個包含輸入input中非零元素索引的張量,輸出張量中的每行包含輸入中非零元素的索引
293           #若輸入input有n維,則輸出的索引張量output形狀為z * n, 這里z是輸入張量input中所有非零元素的個數
294           inds = torch.nonzero(scores[:,j]>thresh).view(-1)#參數中的-1就代表這個位置由其他位置的數字來推斷
295           # if there is det
296           #torch.numel() 返回一個tensor變量內所有元素個數,可以理解為矩陣內元素的個數
297           if inds.numel() > 0:
298             cls_scores = scores[:,j][inds]
299             #torch.sort(input, dim=None, descending=False, out=None)有true,則表示降序,默認升序
300             _, order = torch.sort(cls_scores, 0, True)#沿第0列降序
301             if args.class_agnostic:#兩類
302               cls_boxes = pred_boxes[inds, :]
303             else:
304               cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]#why???
305             
306             #按行連接起來,torch.unsqueeze()這個函數主要是對數據維度進行擴充
307             cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
308             # cls_dets = torch.cat((cls_boxes, cls_scores), 1)
309             cls_dets = cls_dets[order]
310             #model.nms.nms_wrapper
311             keep = nms(cls_dets, cfg.TEST.NMS, force_cpu=not cfg.USE_GPU_NMS)
312             cls_dets = cls_dets[keep.view(-1).long()]
313             if vis:
314               #model.utils.net_utils
315               im2show = vis_detections(im2show, pascal_classes[j], cls_dets.cpu().numpy(), 0.5)
316 
317       misc_toc = time.time()
318       nms_time = misc_toc - misc_tic
319 
320       if webcam_num == -1:
321           #當我們使用print(obj)在console上打印對象的時候,實質上調用的是sys.stdout.write(obj+'\n')
322           sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s   \r' \
323                            .format(num_images + 1, len(imglist), detect_time, nms_time))
324           sys.stdout.flush()
325 
326       if vis and webcam_num == -1:
327           # cv2.imshow('test', im2show)
328           # cv2.waitKey(0)
329           result_path = os.path.join(args.image_dir, imglist[num_images][:-4] + "_det.jpg")
330           cv2.imwrite(result_path, im2show)
331       else:
332           im2showRGB = cv2.cvtColor(im2show, cv2.COLOR_BGR2RGB)
333           cv2.imshow("frame", im2showRGB)
334           total_toc = time.time()
335           total_time = total_toc - total_tic
336           frame_rate = 1 / total_time
337           print('Frame rate:', frame_rate)
338           if cv2.waitKey(1) & 0xFF == ord('q'):
339               break
340   if webcam_num >= 0:
341       cap.release()
342       cv2.destroyAllWindows()

REF:YF-Zhang


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM