本周老師給的任務:
一是將VOT15數據集(世華已傳到服務器上)上每個序列的第1,11,21,31,41幀分別運行Faster R-CNN檢測器並保存在圖片上顯示的檢測結果;
二是將這5幀的ground truth bounding box作為proposal得到其對應的檢測器分類結果(比如網絡要檢測20類物體,那包括背景就是得到21類對應的檢測分數值),並將每個序列的檢測結果分別存成一個文本文檔。
注意,使用代碼的時候,可能會有路徑錯誤,還可能是,我貼上的代碼,博客園的網站給在某些語句后加了 <br> ,調錯的時候細看!!我在后台竟然看不到<br>,但是瀏覽的時候卻有!!
第一個問題已經解決,現在整理一下思路。
先將py faster rcnn 裝好之后,測試運行dome.py能成功展示之后,再進行接下來的工作。
我的想法是,
(1)將vot2015數據集上的所有數據的分類統計出來(就是把vot2015下的子文件夾的名稱統計出來,方便之后操作),這里直接用了( http://www.cnblogs.com/flyhigh1860/p/3896111.html )的源碼進行修改
#!/usr/bin/python # -*- coding:utf8 -*- import os allFileNum = 0 def printPath(level, path): global allFileNum ''''' 打印一個目錄下的所有文件夾和文件 ''' # 所有文件夾,第一個字段是次目錄的級別 dirList = [] # 所有文件 fileList = [] # 返回一個列表,其中包含在目錄條目的名稱(google翻譯) files = os.listdir(path) # 先添加目錄級別 dirList.append(str(level)) for f in files: if (os.path.isdir(path + '/' + f)): # 排除隱藏文件夾。因為隱藏文件夾過多 if (f[0] == '.'): pass else: # 添加非隱藏文件夾 dirList.append(f) if (os.path.isfile(path + '/' + f)): # 添加文件 fileList.append(f) # 當一個標志使用,文件夾列表第一個級別不打印 i_dl = 0
#得到的文件夾名保存在 save_file.txt 中,使用python的追加操作 ‘a’ save_file = open('/home/user/Downloads/save_file.txt','a') for dl in dirList: if (i_dl == 0): i_dl = i_dl + 1 else: # 打印至控制台,不是第一個的目錄 print '-' * (int(dirList[0])), dl
#將文件名寫入save_file.txt中 save_file.write(dl) save_file.write('\n') # 打印目錄下的所有文件夾和文件,目錄級別+1 #printPath((int(dirList[0]) + 1), path + '/' + dl) for fl in fileList: # 打印文件 print '-' * (int(dirList[0])), fl # 隨便計算一下有多少個文件 allFileNum = allFileNum + 1 if __name__ == '__main__': printPath(1, '/home/user/Downloads/vot2015') print '總文件數 =', allFileNum
這里再給出save_file.txt 文件內容
soldier butterfly hand car2 sheep birds1 motocross1 marching book road graduate fish3 fernando bag wiper gymnastics2 leaves ball1 birds2 crossing soccer1 godfather nature racing traffic pedestrian2 handball2 ball2 gymnastics1 singer2 singer1 dinosaur gymnastics3 bolt1 gymnastics4 pedestrian1 helicopter singer3 matrix octopus iceskater1 fish4 sphere car1 motocross2 girl fish1 bolt2 basketball blanket bmx shaking tiger handball1 rabbit fish2 tunnel glove iceskater2 soccer2
(2)從save_file.txt 中將分來讀取出來,保存再一個list中,之后將這段代碼加到 demo.py 中使用(參考了 http://www.cnblogs.com/xuxn/archive/2011/07/27/read-a-file-with-python.html 和 http://www.cnblogs.com/mxh1099/p/5680001.html)
l = [] file = open('/home/user/Downloads/save_file.txt') while 1: line = file.readline() if line != '\n': print line.replace("\n", "")
#在list中 加入去掉換行符的文件名 l.append(line.replace("\n","")) if not line: break print l
(3)需要將文件名和要遍歷的每個文件夾下的文件名配合,同樣,這段代碼之后會用在demo.py 中
lfile = [] file = open('/home/user/Downloads/save_file.txt') while 1: line = file.readline() if line != '\n': lfile.append(line.replace("\n", "")) if not line: break im_names =['00000023.jpg','00000011.jpg','00000001.jpg'] # im_names = ['00000001.jpg', '000000011.jpg', '00000021.jpg', # '00000031.jpg', '00000041.jpg'] for litme in lfile : for im_name in im_names: im_path = str(litme) + '/' + str(im_name) print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' #print 'Demo for data/demo/{}'.format(im_name) print im_path
(4)可以對文件遍歷之后,需要將生成的圖片結果保存下來,參考了《演示如何實現Matplotlib繪圖並保存圖像但不顯示圖形的方法》(http://blog.csdn.net/rumswell/article/details/7342479) 和Python創建目錄文件夾 (http://www.cnblogs.com/monsteryang/p/6574550.html)
最后附上我修改之后的demo.py
#!/usr/bin/env python # -------------------------------------------------------- # Faster R-CNN # Copyright (c) 2015 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ross Girshick # -------------------------------------------------------- """ Demo script showing detections in sample images. See README.md for installation instructions before running. """ import _init_paths from fast_rcnn.config import cfg from fast_rcnn.test import im_detect from fast_rcnn.nms_wrapper import nms from utils.timer import Timer import matplotlib import matplotlib.pyplot as plt import numpy as np import scipy.io as sio import caffe, os, sys, cv2 import argparse #add matplotlib.use('Agg') CLASSES = ('__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') NETS = {'vgg16': ('VGG16', 'VGG16_faster_rcnn_final.caffemodel'), 'zf': ('ZF', 'ZF_faster_rcnn_final.caffemodel')} #add def mkdir(path): import os path = path.strip() path = path.rstrip("\\") isExists = os.path.exists(path) if not isExists: os.makedirs(path) print path + 'ok' return True else: print path + 'failed!' return False def vis_detections(image_name, im, class_name, dets, thresh=0.5): """Draw detected bounding boxes.""" inds = np.where(dets[:, -1] >= thresh)[0] if len(inds) == 0: return im = im[:, :, (2, 1, 0)] fig, ax = plt.subplots(figsize=(12, 12)) ax.imshow(im, aspect='equal') for i in inds: bbox = dets[i, :4] score = dets[i, -1] ax.add_patch( plt.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1], fill=False, edgecolor='red', linewidth=3.5) ) ax.text(bbox[0], bbox[1] - 2, '{:s} {:.3f}'.format(class_name, score), bbox=dict(facecolor='blue', alpha=0.5), fontsize=14, color='white') ax.set_title(('{} detections with ' 'p({} | box) >= {:.1f}').format(class_name, class_name, thresh), fontsize=14) plt.axis('off') plt.tight_layout() plt.draw() #add ll = [] ll = str(image_name).split('/') print ll[0] mkdir('/home/user/tmp/' + str(ll[0])) plt.savefig('/home/user/tmp/' + str(image_name)) def demo(net, image_name): """Detect object classes in an image using pre-computed object proposals.""" # Load the demo image im_file = os.path.join(cfg.DATA_DIR, 'demo','vot2015', image_name) print("%s", im_file) im = cv2.imread(im_file) # Detect all object classes and regress object bounds timer = Timer() timer.tic() #add try except try: scores, boxes = im_detect(net, im) timer.toc() print ('Detection took {:.3f}s for ' '{:d} object proposals').format(timer.total_time, boxes.shape[0]) # Visualize detections for each class CONF_THRESH = 0.8 NMS_THRESH = 0.3 for cls_ind, cls in enumerate(CLASSES[1:]): cls_ind += 1 # because we skipped background cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)] cls_scores = scores[:, cls_ind] dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32) keep = nms(dets, NMS_THRESH) dets = dets[keep, :] vis_detections(image_name,im, cls, dets, thresh=CONF_THRESH) except Exception: print 'Error' def parse_args(): """Parse input arguments.""" parser = argparse.ArgumentParser(description='Faster R-CNN demo') parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int) parser.add_argument('--cpu', dest='cpu_mode', help='Use CPU mode (overrides --gpu)', action='store_true') parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]', choices=NETS.keys(), default='vgg16') args = parser.parse_args() return args if __name__ == '__main__': cfg.TEST.HAS_RPN = True # Use RPN for proposals args = parse_args() prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0], 'faster_rcnn_alt_opt', 'faster_rcnn_test.pt') caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models', NETS[args.demo_net][1]) if not os.path.isfile(caffemodel): raise IOError(('{:s} not found.\nDid you run ./data/script/' 'fetch_faster_rcnn_models.sh?').format(caffemodel)) if args.cpu_mode: caffe.set_mode_cpu() else: caffe.set_mode_gpu() caffe.set_device(args.gpu_id) cfg.GPU_ID = args.gpu_id net = caffe.Net(prototxt, caffemodel, caffe.TEST) print '\n\nLoaded network {:s}'.format(caffemodel) # Warmup on a dummy image im = 128 * np.ones((300, 500, 3), dtype=np.uint8) for i in xrange(2): _, _= im_detect(net, im) # im_names = ['000456.jpg', '000542.jpg', '001150.jpg', # '001763.jpg', '004545.jpg','00000023.jpg','00000011.jpg','00000001.jpg'] # edit lfile = [] file = open('/home/user/Downloads/save_file.txt') while 1: line = file.readline() if line != '\n': lfile.append(line.replace("\n", "")) if not line: break print lfile im_names = ['00000001.jpg', '00000011.jpg', '00000021.jpg', '00000031.jpg', '00000041.jpg'] for litme in lfile : for im_name in im_names: im_path = str(litme) + '/' + str(im_name) print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' print 'Demo for data/demo/{}'.format(im_name) try: demo(net, im_path) except Exception: print 'ERROR' #plt.show()
第二個問題先看着,沒想法
在圖片上顯示每個IOU大於0.5的proposal對應的最高檢測值的類別、分數和回歸后的框,在文本文檔里則保存每個proposal對應的21個類別的檢測分數和回歸后的邊界框坐標。
對於每個類別,總會生成300個proposals,
所以,在每個proposal,都會有4個坐標
對於每個proposal,都會有一個類別值。
因為要生成每個proposal對應的21個類別的分數,就需要將分數先保存起來,再輸出
還要記錄回歸后的邊間框。
對於圖片,顯示每個IOU大於0.5的proposal對應的最高檢測值的類別、分數和回歸后的框。
也是先要將最高檢測分數對應的類別和回歸框記錄下來。