最近工作里需要用到tensorflow的pretrained-model去做retrain. 記錄一下.
為什么可以用pretrained-model去做retrain
這個就要引出CNN的本質了.CNN的本質就是求出合適的卷積核,提取出合理的底層特征.進而為不同的特征賦以權重.從而表達圖像.
通俗點講,比如有一張貓的圖片,你怎么判斷是貓不是狗?你可能會看到圖里有貓的頭,貓的爪子,貓的尾巴. 頭/爪子/尾巴 就是CNN中比較靠前的層所提取出來的特征,我們稱之為高級特征,這時候的特征我們人類還是能理解的. 繼續對這些頭/爪子/尾巴繼續做特征提取,...,最終得到的特征已經非常細節非常抽象了,可能是一個點,一條線等等. 最終我們的image=這些低級特征乘以不同權重,求和.
假設現在你有一個基於公開數據集的trained-model.這個數據集里沒有你想識別的圖片,比如紅綠燈吧. 但是,沒關系!!,雖然你之前的模型不認識紅綠燈,但是它也抽象出來了很多底層的抽象的細節特征啊,點啊,線啊之類的. 我們依然可以使用這些特征去表示紅綠燈圖片,只是每個特征的權重要改變而已! 這就是所謂的增強學習.
tensorflow里存儲"很多底層的抽象的細節特征啊,點啊,線啊之類的"文件,稱之為module.更多詳細的見https://www.tensorflow.org/hub/tutorials/image_retraining
環境准備
- conda activate venv_python3.6
- pip install "tensorflow>=1.7.0"
- pip install tensorflow-hub
數據准備
- cd ~
- curl -LO http://download.tensorflow.org/example_images/flower_photos.tgz
- tar xzf flower_photos.tgz
示例代碼下載
- mkdir ~/example_code
- cd ~/example_code
- curl -LO https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py
重訓練
- python retrain.py --image_dir ~/flower_photos
訓練相關的文件模型等存儲於/tmp
- /tmp/bottleneck 可以理解為每一個圖片的feature map 存儲的是新的class的image的抽象特征
- /tmp/output_graph.pb 新的模型
- /tmp/output_labels.txt 新識別出的label
bottleneck可以理解為image feature vector.可以理解為各種抽象的特征,點啊直線啊折線啊,利用這些特征,模型可以去做分類.
The script can take thirty minutes or more to complete, depending on the speed of your machine. The first phase analyzes all the images on disk and calculates and caches the bottleneck values for each of them. 'Bottleneck' is an informal term we often use for the layer just before the final output layer that actually does the classification. (TensorFlow Hub calls this an "image feature vector".) This penultimate layer has been trained to output a set of values that's good enough for the classifier to use to distinguish between all the classes it's been asked to recognize. That means it has to be a meaningful and compact summary of the images, since it has to contain enough information for the classifier to make a good choice in a very small set of values. The reason our final layer retraining can work on new classes is that it turns out the kind of information needed to distinguish between all the 1,000 classes in ImageNet is often also useful to distinguish between new kinds of objects.
- training accuracy 訓練集精度
- validation accuracy 驗證集精度
- Cross entropy 交叉熵
Cross entropy is a loss function which gives a glimpse into how well the learning process is progressing
整體而言,cross entropy應該是不斷減小的,中間可能會有小的波動
train.py
python retrain.py \
--image_dir ~/flower_photos \
--tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2
- 會從url + '?tf-hub-format=compressed'下載module包.默認會下載到/tmp/tfhub_modules
tar -xvf ../module.tar ./
./
./saved_model.pb
./variables/
./variables/variables.index
./variables/variables.data-00000-of-00001
./assets/
./tfhub_module.pb
這里面就包含了抽象的底層特征.
ssd module下載
https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1
數據集結構
每個目錄下是相應類別的jpg文件
數據集的搜集應當注意的幾點問題
The first place to start is by looking at the images you've gathered, since the most common issues we see with training come from the data that's being fed in.
For training to work well, you should gather at least a hundred photos of each kind of object you want to recognize. The more you can gather, the better the accuracy of your trained model is likely to be. You also need to make sure that the photos are a good representation of what your application will actually encounter. For example, if you take all your photos indoors against a blank wall and your users are trying to recognize objects outdoors, you probably won't see good results when you deploy.
Another pitfall to avoid is that the learning process will pick up on anything that the labeled images have in common with each other, and if you're not careful that might be something that's not useful. For example if you photograph one kind of object in a blue room, and another in a green one, then the model will end up basing its prediction on the background color, not the features of the object you actually care about. To avoid this, try to take pictures in as wide a variety of situations as you can, at different times, and with different devices.
You may also want to think about the categories you use. It might be worth splitting big categories that cover a lot of different physical forms into smaller ones that are more visually distinct. For example instead of 'vehicle' you might use 'car', 'motorbike', and 'truck'. It's also worth thinking about whether you have a 'closed world' or an 'open world' problem. In a closed world, the only things you'll ever be asked to categorize are the classes of object you know about. This might apply to a plant recognition app where you know the user is likely to be taking a picture of a flower, so all you have to do is decide which species. By contrast a roaming robot might see all sorts of different things through its camera as it wanders around the world. In that case you'd want the classifier to report if it wasn't sure what it was seeing. This can be hard to do well, but often if you collect a large number of typical 'background' photos with no relevant objects in them, you can add them to an extra 'unknown' class in your image folders.
It's also worth checking to make sure that all of your images are labeled correctly. Often user-generated tags are unreliable for our purposes. For example: pictures tagged #daisy might also include people and characters named Daisy. If you go through your images and weed out any mistakes it can do wonders for your overall accuracy.
如何使用本地model做retrain
這一步還沒成功,因為我的需求比較特殊,我需要在jetson nano上跑模型,而tensorrt目前還是有Bug的,不是什么model都能推理,有的model里的算子不支持.而從tensorflow的官網download的ssd model的module,做retrain后得到的model無法在jetson nano上推理,
目前我需要ssd_inception_v2_coco_2017_11_17這個model對應的module,很不幸,並沒有,只能自己寫代碼去做轉換,使用了官方的create_module_spec_from_saved_model api還是有問題
與此問題相關的link
https://github.com/tensorflow/hub/issues/37
https://github.com/tensorflow/hub/blob/52d5066e925d345fbd54ddf98b7cadf027b69d99/examples/image_retraining/retrain.py 對應分支
https://www.tensorflow.org/hub/creating
python retrain.py
--image_dir ~/flower_photos
--tfhub_module ./ssd_inception_v2_coco_2017_11_17
tensorflow文件含義
- .pb文件 存儲了完整的模型的結構信息,變量信息等.
- checkpoint文件 記錄模型路徑信息
cat checkpoint
model_checkpoint_path: "/tmp/_retrain_checkpoint"
all_model_checkpoint_paths: "/tmp/_retrain_checkpoint"
- .meta文件存儲了運算圖的結構
- .index文件存儲了tensor結構的信息,ensorname<-->BundleEntryProto
- .data文件存儲所有變量的值
meta file: describes the saved graph structure, includes GraphDef, SaverDef, and so on; then apply tf.train.import_meta_graph('/tmp/model.ckpt.meta'), will restore Saver and Graph.
index file: it is a string-string immutable table(tensorflow::table::Table). Each key is a name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files contains the content of a tensor, the offset into that file, checksum, some auxiliary data, etc.
data file: it is TensorBundle collection, save the values of all variables.