torchnet package (2)


torchnet package (2)

Dataset Iterators

盡管是用for loop語句很容易處理Dataset,但有時希望以on-the-fly manner或者在線程中讀取數據,這時候Dataset Iterator就是個好的選擇
注意,iterators是用於特殊情況的,一般情況下還是使用Dataset比較好
Iteartor 的兩個主要方法:
* run() 返回一個Lua 迭代器,也可以使用()操作符,因為iterator源碼中定義了__call事件
* exec(funcname,...) 在指定的dataset上執行funcname方法,funcname是dataset自己的方法,比如size

  • tnt.DatasetIterator(self,dataset[,perm][,filter][,transform])
    The default dataset iterator
    perm(idx), 實現shuffle功能,即對idx進行變換,更復雜的變換可以使用ShuffleDataset
    filter(sample), 閉包函數,篩選樣本是否用於迭代,返回bool值
    transform(sample),閉包函數,實現對樣本的變換,更復雜的變換可以結合TransformDataset和transform.compose等實現

  1. ldata = tnt.ListData{list=torch.range(1,10):long(),load = function(x) return {x,x+1} end
  2. dIter = tnt.DatasetIterator{dataset = ldata,filter = function(x) if x[1]<2 then return false else return true end end
  3. for v in dIter:run() 
  4. print(v) 
  5. end 
  • tnt.ParallelDatasetIterator(self[,init],closure,nthread[,perm][,filter][,transform][,ordered])
    這個才是迭代器的重點,用於以多線程方式迭代數據。

The purpose of this class is to have a zero pre-processing cose. when reading datasets on the fly from disk(not loading thenm fully in memory), or performing complex pre-processing this canbe of interest.

nthreads 指定了線程的個數
init(threadid) 閉包函數,指定了線程threadid的初始化工作,如果啥都不做可以省略
closure(threadid) 每個線程的job,返回的必須時tnt.Dataset的一個實例
perm(idx) 用於shuffle
filter(sample) 閉包函數,指定哪些樣本不用於迭代
transform(sample) 對樣本進行變換,在filter之前執行
order 線程之間數據的處理是否有序,主要是為了程序的可重現性,當order=true時,多次執行程序,順序是相同的

  1. tnt=require'torchnet' 
  2. local list=torch.Tensor{{2,2},{2,2},{2,2},{2,2}}:long() 
  3. ldata = tnt.ListDataset{list=list,load=function(x) return torch.Tensor(x[1],x[2]) end
  4. local bdata = tnt.BatchDataset{batchsize=2,dataset = tnt.TransformDataset{dataset = ldata,transform=function(x) return 2*x end}} 
  5. Padata = tnt.ParallelDatasetIterator{ 
  6. nthread = 4
  7. init = function(tid) 
  8. print ('init thread id: '.. tid) 
  9. tnt=require'torchnet' 
  10. end
  11. closure = function(tid) 
  12. print('closure of threadid: '.. tid) 
  13. return bdata 
  14. end 
  15. }  

尤其需要注意的是,closure中的所有upvalues都必須是可序列化的,最好是避免使用upvalues,並保證closure中使用的package都在init中require

tnt.Engine

在網絡訓練的過程中,都是計算前向誤差,誤差反傳,更新權重這些過程,只是模型,數據和評價函數不同而已,所以Engine給訓練過程提供了一個模板,該模板建立了model,DatasetIterator,Criterion和Meter之間的聯系

engine=tnt.Engine()包含兩個主要方法
* engine:train() 在數據集上訓練數據
* engine:test() 評估模型,可選
Engine不僅實現了訓練和評估的一般模板,還提供了許多接口,用於控制訓練過程

  • tnt.SGDEngine
    SGDEngine 模塊在train過程中使用Stochastic Gradient Descent方法訓練,模塊包含數據采樣,前向傳遞,反向傳遞,參數更新等,還有一些鈎子函數
    hooks = {
    ['onStart'] = function() end, --用於訓練開始前的設置和初始化
    ['onStartEpoch'] = function() end, -- 每一個epoch前的操作
    ['onSample'] = function() end, -- 每次采樣一個樣本之后的操作
    ['onForward'] = function() end, -- 在model:forward()之后的操作
    ['onForwardCriterion'] = function() end, -- 前向計算損失函數之后的操作
    ['onBackwardCriterion'] = function() end, -- 反向計算損失誤差之后的操作
    ['onBackward'] = function() end, -- 反向傳遞誤差之后的操作
    ['onUpdate'] = function() end, -- 權重參數更新之后的操作
    ['onEndEpoch'] = function() end, -- 每一個epoch結束時的操作
    ['onEnd'] = function() end, -- 整個訓練過程結束后的收拾現場
    }
    可以發現Engine給的hook函數還是很全面的,幾乎訓練過程的每一個節點都允許用戶制定操作,使用hook函數

  1. local engine = SGDEngine() 
  2. local meter = tnt.AverageValueMeter() 
  3. engine.hooks.onStartEpoch = function(state) meter:reset() end 

一般而言,訓練過程最少應該知道訓練模型,損失函數,數據和學習率,這里學習方法已經知道了SGD,Engine用到的數據是tnt.DatasetIterator類型的。 評估過程只需要數據和模型就可以了

外部可以通過state變量與Engine訓練過程交互
state = {
['network'] = network, --設置了model
['criterion'] = criterion, -- 設置損失函數
['iterator'] = iterator, -- 數據迭代器
['lr'] = lr, -- 學習率
['lrcriterion'] = lrcriterion, --
['maxepoch'] = maxepoch, --最大epoch數
['sample'] = {}, -- 當前采集的樣本,可以在onSample中通過該閾值查看采樣樣本
['epoch'] = 0 , -- 當前的epoch
['t'] = 0, -- 已經訓練樣本的個數
['training'] = true -- 訓練過程
}

評估時需要指定:
state = {
['netwrok'] = network
['iterator'] = iterator
['criterion'] = criterion
}

  • tnt.OptimEngine
    這個方法和SGDEngine的最大的區別在於封裝了optim中的多種優化方法。在訓練開始的時候,engine會通過getParameters獲取model的參數
    train需要附加兩個量:

    • optimMethod 優化方法,比如optim.sgd

    • config 優化方法對應的參數
      Example:

  1. local engine = tnt.OptimEngine{ 
  2. network = network, 
  3. criterion=criterion, 
  4. iterator = iterator, 
  5. optimMethod = optim.sgd, 
  6. config = { 
  7. learningRate = 0.1
  8. momentum = 0.9
  9. }, 

tnt.Meter

和Engine配合使用,用於measure the model.
幾乎所有的meters都會有3個方法:
* add() 給待統計的meter添加一個觀測值,其輸入參數一般形式為(output,value),output為model的輸出,target為真實值
* value() 獲得待統計的meter的當前值
* reset() 重新計數
Meter的使用示例:

  1. local meter = tnt.<Measure>Meter() -- <Measure> 可以選擇具體的度量 
  2. for state,event in tnt.<Optimization>Engine:train{ --定義Engine 
  3. network = network, 
  4. criterion=criterion, 
  5. iterator=iterator, 
  6. } do 
  7. if state == 'start-epoch' then  
  8. meter:reset() -- reset meter 
  9. elseif state == 'forward-criterion' then 
  10. meter:add(state.network.output,sample.target) 
  11. elseif state == 'end-epoch' then 
  12. print('value of meter:) .. meter:value()) 
  13. end 
  14. end 
  • tnt.APMeter(self)
    評估每一類的平均正確率
    APMeter的操作對象是一個的Tensor,表示N個樣本對應在K類中的值,另外可選的一個的 Tensor表示每個樣本的權重

  1. target = torch.Tensor{ 
  2. {0,0,0,1},{0,0,1,0},{0,1,0,0},{1,0,0,0},{1,0,0,0}} 
  3. apm = tnt.APMeter() 
  4. for i=1,5 do 
  5. apm:add{output=torch.rand(1,4),target=target[i]:size(1,4)} -- 注意N*K的Tensor 
  6. end 
  7. print(apm:value()) 
  • tnt.AverageValueMeter(self)
    用於統計任意添加的變量的方差和均值,可以用來測量平均損失等
    add()的輸入必須時number類型,另外在add的時候可以有一個可選的參數n,表示對應值的權重

  1. avm = tnt.AverageValueMeter() 
  2. for i=1,10 do  
  3. avm:add(i,10-i) 
  4. end 
  5. print(avm:value()) -- 輸出 4 2.4720... 
  • tnt.AUCMeter(self)
    對於二分類問題計算Area Under Curve (AUC).
    AUCMeter操作的變量是1D的tensor

  • tnt.ConfusionMeter(self,k[,nirmalized])
    多類之間的混淆矩陣,注意不是多類多標簽問題,多標簽是指一個類的實例可能分配多個標簽,這類問題參見tnt.MultiLabelConfusionMeter
    初始化的時候,需要指定類別數k,normalized指定是否將confuse matrix 歸一化,歸一化之后輸出的是百分比,否則是數值
    add(output,target) 輸入都是的tensor,這里為什么每次都是N個樣本一起輸入呢?這是因為往往訓練模型都是Batch模式處理的,target可以是N的tensor,每個值表示對應類別標號,也可以時NK的tensor表示類別的one-hot vector
    value()返回K
    K的混淆矩陣行表示groundtruth,列表示predicted targets

  • tnt.mAPMeter(self)
    統計所有類別之間的平均正確率,和APMeter參數完全一致,不同的時value()返回的是多個類別總的正確率

  • tnt.MovingAverageValueMeter(self,windowsize)
    該meter和AverageValueMeter非常類似,輸入的也是number,不同在於他統計的不是所有的number的均值和方差,而是往前windowsize時間窗內的numbers的均值和方差,windowsize在初始化時需要指定

  • tnt.MultiLabelConfusionMeter(self,k[,normalized])
    多類多標簽混淆矩陣,這個沒接觸過,不知道理解對不對,先放這吧,需要的時候再看

The tnt.MultiLabelConfusionMeter constructs a confusion matrix for multi- label, multi-class classification problems. In constructing the confusion matrix, the number of positive predictions is assumed to be equal to the number of positive labels in the ground-truth. Correct predictions (that is, labels in the prediction set that are also in the ground-truth set) are added to the diagonal of the confusion matrix. Incorrect predictions (that is, labels in the prediction set that are not in the ground-truth set) are equally divided over all non-predicted labels in the ground-truth set.

At initialization time, the k parameter that indicates the number of classes in the classification problem under consideration must be specified. Additionally, an optional parameter normalized (default = false) may be specified that determines whether or not the confusion matrix is normalized (that is, it contains percentages) or not (that is, it contains counts).

The add(output, target) method takes as input an NxK tensor output that contains the output scores obtained from the model for N examples and K classes, and a corresponding NxK-tensor target that provides the targets for the N examples using one-hot vectors (that is, vectors that contain only zeros and a single one at the location of the target value to be encoded).

  • tnt.ClassErrorMeter(self[,topk][,accuracy])
    參數: topk = table
    accuracy = boolean
    該meter用於統計分類誤差,topk是一個table指定分別統計前k類預測誤差,如ImageNet Competition中的Top5類誤差,accuracy表示返回的是正確了還是錯誤率,accuracy=true,返回的就是1-error
    add(output,target),output是一個的tensor,target可以使一個N的tensor也可以是一個的tensor,參考之前的AUCMeter
    value()返回的時topk誤差,value(k)返回的是第topk類誤差

  • tnt.TimeMeter(self[,unit])
    這個Meter用於統計events之間的時間,也可以用來統計batch數據的平均處理數據。她很特別!
    unit在初始的時候給定,是一個布爾值,默認false,當設置為true時,返回值將會被incUnit()值平均,計算平均時間消耗。
    tnt.TimeMeter提供的方法有:

    • reset() 重置timer,unit counter

    • stop() stop the timer

    • resume() 喚醒timer

    • incUnit() uint+1

    • value() 返回從reset()到現在的時間消耗

  • tnt.PrecisionAtKMeter(self[,topk][,dim][,online])

待補充
  • tnt.RecallMeter(self[,threshold][,preclass])
    統計threshold下的召回率,threshold是一個table類型,每個元素是一個閾值,默認值為0.5. perclass是一個布爾值,表示是單獨統計每一類的召回率還是統計整個召回率,默認值是false
    add(output,target) output是N*K的概率矩陣,行和為1;target是NK的二值矩陣,不一定行和為1,如{0,1,0,1}
    value()返回的是table值,對應的是threshold table中指定閾值下的召回率,如果perclass = true,那么table的每個元素就是一個table

  • tnt.PrecisionMeter(self[,threshold][,perclass])
    參考RecallMeter,這里計算的是正確率

  • tnt.NDCGMeter(self[,K])
    計算normalized discounted cumulative gain,沒使用過。。。。

tnt.Log

Log是一個由sting key索引的table,這些keys必須在構造函數中指定,有一個特殊的鍵 __status__可以在log:status()函數中設置用於記錄一些基本的messages

Log中提供的一些closures以及對應attached events
* onSet(log,key,value) 對應着給鍵賦值 log:set{}
* onGet(log,key) 對應着讀取key對應的值 log:get()
* onFlush(log) 對應着清空log log:flush()
* onClose(log) 對應log:close() 關閉log

示例:

  1. tnt = require'torchnet' 
  2. logtext = require 'torchnet.log.view.text' 
  3. logstatus = require 'torchnet.log.view.status' 
  4. log = tnt.log{ 
  5. keys = {'loss','accuracy'
  6. onFlush = { 
  7. -- write out all keys in "log" file 
  8. logtext{filename='log.txt', keys={"loss", "accuracy"}, format={"%10.5f", "%3.2f"}}, 
  9. -- write out loss in a standalone file 
  10. logtext{filename='loss.txt', keys={"loss"}}, 
  11. -- print on screen too 
  12. logtext{keys={"loss", "accuracy"}}, 
  13. }, 
  14. onSet = { 
  15. -- add status to log 
  16. logstatus{filename='log.txt'}, 
  17. -- print status to screen 
  18. logstatus{}, 


  19.  
  20. -- set values 
  21. log:set{ 
  22. loss = 0.1
  23. accuracy = 97 

  24.  
  25. -- write some info 
  26. log:status("hello world"
  27.  
  28. -- flush out log 
  29. log:flush() 
  30.  

后面我們來看一個具體的例子,以VGG16為例實現一個Siamese CNN網絡計算patch之間的相似度



免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM