torchnet package (1)


torchnet package (1)

torchnet

torchnet是用於torch的代碼復用和模塊化編程的框架,主要包含四個類

  • Dataset 以不同的方式對數據進行預處理

  • Engine 訓練/測試機器學習方法

  • Meter 評估方法性能

  • Log 日志

Documentation

torchnet的調用
local tnt = require 'torchnet'

tnt.Dataset()

torchnet提供了多種即插即用的數據容器(data container),例如 concat,split,batch,resample,etc ... 操作。
tnt.Dataset()實例包含兩種主要方法

  • dataset:size() 返回數據集的大小

  • dataset:get(idx) 其中idx是1到size中的數字,返回數據集的第idx個樣本

盡管可以簡單的通過for loop循環實現數據集的迭代,為了用戶能夠以on-the-fly manner找出某些樣本或者並行的數據讀取,torchnet還提供了一些DatasetIterator類型的迭代器

在torchnet中,dataset:get()返回的可以是一個Lua table。table中的閾值可以是任意的,即使大多數的數據集都是tensor類型。

需要注意的是,並不能直接的使用tnt.Dataset()創建該類型,該類型類似於一個抽象類,其下面的具體類包括batchdataset,splitdataset等

-tnt.ListDataset(self,list,load,[,path])
參數:self = tnt.ListDataset
list = tds.Hash
load = function
[path = string]
其中list可以是tds.Hash,table或者torch.LongTensor類型,當訪問第i個樣本時返回的是load(list[i]),這里load() 是由user提供的閉包函數
當path參數非空的時候,list對應的應該是string隊列,這樣傳遞給load()函數的參數自動加上path前綴,比如訪問文件夾'd:/data/mot2015/'下的數據時,不同的子數據集存放在不同的文件里'1.txt','2.txt','3.txt',...這時候 list={'1.txt','2.txt',...},path='d:/data/mot2015/',那么load(x)內的x=path .. x

  1. a={{1,2,3},{2,3,4},{2,2,2}} 
  2. b=torch.Tensor(a) 
  3. f=tnt.ListDataset({list=b:long(),function(x) return x:sum() end}) 
  4. print(f:size()) -- 3 
  5. print(f:get(1)) -- 6 

注意list只能是hash,table或者longtensor這里容易出現錯誤的是習慣使用:long()將tensor類型轉換,但是對於元素含小數部分的tensor直接類型轉換會出現錯誤!

  • tnt.ListDataset(self,filename,load[,maxload][,path])
    參數: self = tnt.ListDataset
    filename = string 這里filename指定的文件的每一行都是list的一個元素,類似於io.lines(filename)
    load = function 閉包函數
    [maxload = number] 最大加載條目數
    [path = string] 同之前

  • tnt.TableDataset(self,data)
    參數: data = table 針對於小型數據集,data必須Hash索引,對data數據淺層拷貝

  1. a= tnt.TableDataset{data={1,2,3}} 
  2. print(a:get(1)) 

tnt.TableDataset假定table中key從1連續

  • tnt.TransformDataset(self,dataset,transform[,key])
    參數: self = tnt.TransformDataset
    dataset = tnt.Dataset
    transform = function -- 變換函數
    [key = string] -- 需要變換的key值,如果沒有則對dataset中所有數據操作
    當使用tnt.Dataset:get()查詢數據集中的數據時,tnt.TransformDataset()以on-the-fly 方式執行閉包函數transform並返回值。
    on-the-fly我的理解是不需要中斷過程去執行閉包函數,不涉及從內存中讀取數據,而是直接通過cache形式執行,速度很快

  1. a=torch.Tensor{{1,2,3,4},{2,3,4,4}}:long() 
  2. ldata=tnt.ListDataset({list=a,load=function(x) return x end}) 
  3. tdata=tnt.TransformDataset({dataset=ldata,transform=function(x) return x-10 end}) 
  4. print(tdata:get(1)) 

-- tnt.TransformDataset(self,dataset,transforms)
注意這個方法是transforms 是一個table,table中的鍵值對應着dataset[list[i]]的域,如果我們使用tnt.TableDataset{{a=1,b=2,c=3},{a=0,b=3,c=5}}創建Dataset,如下

  1. Tdata = tnt.TableDataset{data={{a=1,b=2,c=3},{a=0,b=3,c=5}}} 
  2. f=tnt.TransformDataset({dataset = Tdata,transforms={a=function(x) return 2*x end,b=function(x) return x-20 end}}) 
  3. f:get(1) -- 這時候輸出{a:2 b:-18 c:3},即Tdata[i]的域a執行了transforms.a函數,域b執行transforms.b函數 

-- tnt.BatchDataset(self,dataset,batchsize[,perm][,merge][,policy][,filter])
參數:self = tnt.BatchDataset
dataset = tnt.Dataset
batchsize = number
[perm = function]
[merge = function]
[policy = string]
[filter = function]
功能:將dataset中的batchsize個樣本組成一個樣本,方便batch處理
merge函數主要是將batchsize個樣本的不同域組合起來,比如數據集的第i個樣本寫作
{input = <input_i>,target = <target_i>}
那么merge()使數據組合為

  1. {<input_i_1>,<ingput_i_2>,... <input_i_n>} 和 {<target_i_1>,<target_i_2>,... <target_i_n>} 
  1. ldata = tnt.ListDataset({ 
  2. list = torch.range(1,40):long(), 
  3. load = function(x) return {input={torch.randn(2,2),torch.randn(3,3)},target =x,target_t = -x } end 
  4. }) 
  5. bdata = tnt.BatchDataset{ 
  6. dataset = ldata, 
  7. batchsize=10 

  8. print(bdata:size()) --輸出 4 
  9. print(bdata:get(1)) -- 輸出第一個batch,包含3個field:target,input,target_t 
  10. print(bdata:get(1).input[1]) -- 輸出一個input元素  

batch方式操作時,shuffle很重要,所以perm(idx,size)是一個閉包函數,該函數返回shuffle之后idx位置索引的樣本,size是dataset的大小。
dataset的size可能不能被batchsize整除,於是 policy指定了截取方式
* include-last 不能整除的最后一個batch大小非必要等於batchsize
* skip-last 最后余出的部分樣本舍掉,這並不意味着那些樣本就不用了,因為shuffle后的樣本排序不定
* divisible-only 不能整除則報錯

  • tnt.CoroutineBatchDataset(self,dataset,batchsize[,perm][,merge][,policy][,filter])
    該方法和BatchDataset方法參數完全一致,實現的功能也幾乎一致,唯一不同的地方是該方法可以用於協同程序,用到的時候再看吧。。。

  • tnt.ConcatDataset(self,datasets)
    參數: self= tnt.ConcatDataset
    dataset = table
    功能:將table中的數據集concate

  • tnt.ResampleDataset(self,dataset[,sampler][,size])
    給定一個數據集dataset,然后通過sampler(dataset,idx)閉包函數重采樣獲得新的數據集,size可以指定resample數據集的大小,若沒指定則與原來的dataset大小相同,通過源碼我們可以看到sampler這個函數其實是用來實現idx的改變

  1. ldata = tnt.ListDataset{list = torch.range(1,40):long(), 
  2. load = function(x) return {input={torch.randn(2,2),torch.randn(3,3)},target =x,target_t = -x } end
  3. iidx = tnt.transform.randperm(ldata:size()) 
  4. rdata = tnt.ResampleDataset{dataset = ldata,sampler = function(dataset,idx) return iidx(idx) end} --這其實實現了shuffle功能 
  5. print(rdata:get(1)) 
  • tnt.ShuffleDataset(self,dataset[,size][,replacement])
    實現dataset的shuffle,如果replacement=true,那么指定的size可以大於dataset:size(),大於的部分通過redraw獲得
    tnt.ShuffleDataset.resample(self)
    通過該函數在構建ShuffleDataset時就創建fixed的permutation,能夠保證多次index同一個值得到的結果相同

  • tnt.SplitDataset(self,dataset,partitions[,initialpartition])
    partitions = table
    [initialpartition = string]
    partitions是一個lua table,table中的元素<key,value>,key是對應partition的名,value是一個0-1的數表示取dataset:size()的比例,或者直接是個number表示對應partitions的大小,initialpartition指定了初始化時加載的partition
    注意 ,該方法在交叉驗證時,用起來很爽
    tnt.SplitDataset.select(self,partition) 改變當前選擇的partition

  1. sdata = tnt.SplitDataset{data=ldata,partitions={train=0.5,ver=0.25,test=0.25}} 
  2. sdata:select('train') --因為沒有指定initialpartition所以需要指定當前的partition才能訪問,當指定initialpartition后,該行可以不要,如 sdata = tnt.SplitDataset{data=ldata,partitions={train=0.5,ver=0.25,test=0.25},initialpartitial='train'} 
  3. print(sdata:get(1)) 
  4. print(sdata:size()) 

tnt.utils

torchnet提供了許多工具函數

  • tnt.utils.table.clone(table) 實現table的深度拷貝

  • tnt.utils.table.merge(dst,src) 將src合並到dst中,實現的是淺層拷貝,如果src中的key在dst中已經存在,則覆蓋dst中的key值

  1. src={{1,2,3},{4,5,6}} 
  2. dst1={} 
  3. dst2={{1,2,3}} 
  4. dst1=tnt.utils.table.clone(src) 
  5. dst1[1][2]=10 
  6. print(dst1) -- 此時dst1[1][2]=10 
  7. print(src) -- src[1][2]=2 
  8. tnt.utils.table.merge(dst2,src) 
  9. dst2[1][2]=10 
  10. print(dst2) -- dst2[1][2]=10 
  11. print(src) -- src[1][2]=10 
  12. src={a={1,2,3},b={2,3,4}} 
  13. dst1={c={2,2,2}} 
  14. dst2={a={2,3}} 
  15. tnt.utils.table.merge(dst1,src) 
  16. tnt.utils.table.merge(dst2,src) 
  17. print(dst1) -- dst1包含三個元素a,b,c 
  18. print(dst2) -- dst2僅包含2個元素a,b,其中dst2中原來的a被src中的a覆蓋 
  • tnt.utils.table.foreach(tbl,closure[,recursive])
    參數: tbl 是一個lua table; closure 閉包函數; [recursive = boolean] 默認值為false
    功能: 對tbl中的每一個元素執行closure函數,如果recursive=true那么tbl將被遞歸的采用closure函數
    示例:

  1. a={{1,2,3},{2,3,4},{{2,2,2},{1,1,1}}} 
  2. fun = function(v)print('------');print(v)end 
  3. tnt.utils.table.foreach(a,fun) 
  4. tnt.utils.table.foreach(a,fun,true

輸出:

  1. ------ 

  2. 1 : 1 
  3. 2 : 2 
  4. 3 : 3 

  5. ------ 

  6. 1 : 2 
  7. 2 : 3 
  8. 3 : 4 

  9. ------ 

  10. 1 : 

  11. 1 : 2 
  12. 2 : 2 
  13. 3 : 2 

  14. 2 : 

  15. 1 : 1 
  16. 2 : 1 
  17. 3 : 1 


  1. ------ 

  2. ------ 

  3. ------ 

  4. ------ 

  5. ------ 

  6. ------ 

  7. ------ 

  8. ------ 

  9. ------ 

  10. ------ 

  11. ------ 

  12. ------ 

可以發現,recursive = true下遞歸調用表中元素,直至最里層的單個元素,而在false下,最外層table中每個元素作為輸入參數輸入到closure函數中

  • tnt.utils.table.canmergetensor(tbl)
    tbl是否能夠merge成一個tensor,table中元素是相同規模的tensor則可以mergetensor

  • tnt.utils.table.mergetensor(tbl)
    將tbl中的元素合並成tensor

  1. a={torch.Tensor(3,2),torch.Tensor(3,3)} 
  2. b={torch.Tensor(3):float(),torch.Tensor(3):double()} 
  3. c={torch.Tensor(4),torch.Tensor(4)} 
  4. var={a,b,c} 
  5. for i=1,3 do 
  6. if tnt.utils.table.canmergetensor(var[i]) then 
  7. print(i) 
  8. tnt.utils.table.mergetensor(var[i]) 
  9. end 
  10. end 

此時顯示b,c可以mergetensor,說明只要tensor的規模相同就可以,與其type是否一致無關

tnt.transform

該package提供了數據的基本變換,這些變換有的直接作用在數據上,有的作用在數據結構上,使得操作tnt.Dataset非常方便
這些變換雖然都很簡單,但是這些邊還可以通過compose或者merge方式實現復雜的變換,compose就是將變換串起來,merge是將變換同時執行,返回每個變換的結果

  • transform.identity(...)
    該變換返回輸入本身,這個暫時沒想到使用的地方

  • transform.compose(transforms)
    其中參數transforms是一個函數列表,每個函數可以實現一種變換。注意該函數認為transforms中的函數是從1開始連續索引的,如果碰到不連續的了,那么只執行前面連續索引的變換

  1. transform=tnt.transform 
  2. f=transform.compose({ 
  3. function(x) return 2*x end
  4. function(x) return x+10 end
  5. foo = function(x) return x/2 end 
  6. }) 
  7. a={2,3,4
  8. _ =tnt.utils.table.foreach(a,function(x) print(f(x)) end) 輸出 141618,即只執行了f中前兩個變換 

注意這里函數列表寫成{[1]=function(x) return 2*x end,function(x) return x+10 end,foo = function(x) return x/2 end}則只執行第一個變換,因為key:[2]不存在

  • transform.merge(transforms)
    transforms是一個變換函數列表,對於一個輸入,該函數使該輸入經過所有變換函數得到的結果merge成table輸出

  1. f = transform.merge{ 
  2. [1] = function(x) return torch.Tensor{2*x} end
  3. [2] = function(x) return torch.Tensor{x + 10} end
  4. [3] = function(x) return torch.Tensor{x / 2} end
  5. [4] = function(x) return torch.Tensor{x} end 

  6. f(3

注意這個例子輸出的是一個tensor,並不是doc中說的輸出一個table,我覺得這個函數除了bug,transform.lua的第144行應該直接return newz就可以了,源代碼中使用utils.table.mergetensor(newz)反而會導致合並出錯,要想讓源代碼能執行就必須像上面我給的例子似的,函數返回的是同等規模的tensor,且函數列表中的index必須是從1開始連續索引,源代碼要是不改這個函數還是得特別注意

  • transform.tableapply(transform)
    這里的參數transform是一個變換函數,該變換作用於table變量

  1. a={1,2,3,4
  2. f=transform.tableapply(function(x) return x*2 end
  3. f(a) 
  • transform.tablemergekeys()
    得到的變換方法的輸入必須是一個table的table

  1. x={{input=1,target='a'},{input=2,target='b',flag='hard'}} 
  2. transform.tablemergekeys(x) 

注意這個源碼也有問題,源碼transform.lua中的243行中ipairs應該修改為pairs,否則給的例子運行不了,因為ipairs從1開始index到第一個非整數key就結束了

  • transform.randperm(size)
    randperm()函數,注意該函數返回的是一個函數句柄,想要獲得第i個值,應該用f=transform.randperm(10);f(i)

  • trandform.normalize([threshold])
    輸入必須是一個tensor,該函數能夠實現標准化,即中心化+歸一化,參數threshold是一個number,只有標准差大於threshold時,tensor才會normalize

  1. a=torch.rand(2,3)*10 
  2. print('the std of a is ' .. a:std()) 
  3. f=transform.normalize() 
  4. print('the std of normalized a is ' .. f(a):std() .. ' and the mean is ' .. f(a):sum()) 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM