(Original) The torch training process


Please credit the source when reposting:

http://www.cnblogs.com/darkknightzh/p/6221622.html

References:

http://ju.outofmemory.cn/entry/284587

https://github.com/torch/nn/blob/master/doc/criterion.md

1. Using updateParameters

Suppose we already have model = setupmodel (a model you have built yourself), training data input, the ground-truth output outReal, and a loss function criterion (see the second link above). The torch training procedure is then as follows:

1 -- given model, criterion, input, outReal
2 model:training()
3 model:zeroGradParameters()
4 outPredict = model:forward(input)
5 err = criterion:forward(outPredict, outReal)
6 grad_criterion = criterion:backward(outPredict, outReal)
7 model:backward(input, grad_criterion)
8 model:updateParameters(learningRate)

Line 1 assumes model, criterion, input, and outReal are already given.

Line 2 puts the model in training mode.

Line 3 zeroes the gradients stored in every module of the model (so that previous iterations cannot contaminate this one).

Line 4 passes the input through the model to obtain the predicted output outPredict.

Line 5 uses the loss function to compute the error err between the predicted output outPredict and the actual output outReal under the current parameters.

Line 6 computes the gradient grad_criterion of the loss function from the predicted output outPredict and the actual output outReal.

Line 7 backpropagates, computing the gradient of every module in the model.

Line 8 updates the parameters of every module in the model.

Lines 3 through 8 must be executed on every training iteration.
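
To make these steps concrete, here is a minimal self-contained sketch. The toy linear model, the random data, and nn.MSECriterion are stand-ins chosen for illustration only, not the setup assumed in the text above:

require 'nn'

-- hypothetical stand-ins for setupmodel, input, and outReal
local model = nn.Sequential():add(nn.Linear(10, 1))
local criterion = nn.MSECriterion()
local input = torch.randn(10)
local outReal = torch.randn(1)
local learningRate = 0.01

model:training()
for iter = 1, 100 do
    model:zeroGradParameters()
    local outPredict = model:forward(input)
    local err = criterion:forward(outPredict, outReal)
    local grad_criterion = criterion:backward(outPredict, outReal)
    model:backward(input, grad_criterion)
    model:updateParameters(learningRate)
end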

 

=========================================================

2. Using optim

Update 2017-03-01:

http://x-wei.github.io/learn-torch-6-optim.html

gives a more convenient way (whether it is really more convenient is arguable) to update the parameters: using torch's optim package. Calling model:updateParameters directly limits you to the plainest gradient descent, whereas optim wraps many algorithms (gradient descent, Adam, and so on).

params_new, fs, ... = optim._method_(feval, params[, config][, state])

Here, params is the current parameter vector (a 1D tensor); it is updated in place during optimization.

feval: a user-defined closure that behaves like f, df/dx = feval(x)

config: a table of algorithm parameters (such as the learning rate)

state: a table of state variables

params_new: the new parameter vector (a 1D tensor) that minimizes f

fs: a table of the f values evaluated during the optimization; fs[#fs] is the final optimized function value
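
As a toy illustration of this signature (the quadratic f and the learningRate value are made up for the example), minimizing f(x) = x^2 with optim.sgd looks like this:

require 'optim'

local x = torch.Tensor{2.0}   -- 1D parameter vector, starting at x = 2

-- closure returning f(x) and df/dx, as optim expects
local feval = function(x_new)
    local f = x_new[1] * x_new[1]
    local df_dx = torch.Tensor{2 * x_new[1]}
    return f, df_dx
end

local config = {learningRate = 0.1}
for i = 1, 100 do
    -- x is updated in place; fs[1] is the loss feval computed before the update
    local x_new, fs = optim.sgd(feval, x, config)
end
print(x[1])   -- close to 0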

Note: since optim requires its inputs to be 1D tensors, the model's parameters must first be flattened into one, which the following call accomplishes:

params, gradParams = model:getParameters()

Both params and gradParams are 1D tensors.
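
A quick sketch of what the flattening does (the nn.Linear toy module is an arbitrary choice): the returned flat vectors are views that share storage with the module's weights, so writing to params immediately changes the model:

require 'nn'

local model = nn.Linear(3, 2)
local params, gradParams = model:getParameters()

print(params:dim(), params:nElement())   -- 1  8  (3*2 weights + 2 biases)

params:fill(0.5)            -- writing to the flat vector...
print(model.weight[1][1])   -- ...changes the module weight: 0.5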

With this in place, the program from section 1 can be rewritten as:

-- given model, criterion, input, outReal, optimState
local params, gradParams = model:getParameters()

-- the closure does no work itself: by the time optim.sgd calls it, the
-- forward/backward passes below have already filled criterion.output
-- (the loss) and gradParams (the flattened gradient)
local function feval()
    return criterion.output, gradParams
end

for iter = 1, numIterations do   -- numIterations: however many iterations you need
    model:training()
    model:zeroGradParameters()
    outPredict = model:forward(input)
    err = criterion:forward(outPredict, outReal)
    grad_criterion = criterion:backward(outPredict, outReal)
    model:backward(input, grad_criterion)

    optim.sgd(feval, params, optimState)
end
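
Note that optimState is passed as the config table here. When no separate state table is given, optim.sgd keeps its running state (the evaluation counter and, when momentum is used, the momentum buffer) in that same table, so the same optimState table must be reused across iterations rather than recreated each time.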

End of the 2017-03-01 update.

=========================================================

3. A caveat when using model:backward

Update 2017-04-05:

Note that every model:backward call must be paired with a model:forward call on the same input.

In https://github.com/torch/nn/blob/master/doc/module.md, the entry for [gradInput] backward(input, gradOutput) says:

In general this method makes the assumption forward(input) has been called before, with the same input. This is necessary for optimization reasons. If you do not respect this rule, backward() will compute incorrect gradients.

This is presumably because backward reuses intermediate variables computed during forward. forward must therefore run immediately before backward; otherwise the stored intermediates do not match the backward call and the results are wrong.
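
The following small sketch (a made-up two-layer net and random inputs) makes the failure visible: running a second forward on different data before backward silently changes the computed gradients:

require 'nn'

local net = nn.Sequential():add(nn.Linear(4, 2)):add(nn.Sigmoid())
local a = torch.randn(4)
local b = torch.randn(4)
local gradOut = torch.ones(2)

-- correct: forward(a) immediately before backward(a, ...)
net:zeroGradParameters()
net:forward(a)
net:backward(a, gradOut)
local good = net:get(1).gradWeight:clone()

-- incorrect: a later forward(b) overwrites the intermediate buffers
net:zeroGradParameters()
net:forward(a)
net:forward(b)                -- stored intermediates now belong to b
net:backward(a, gradOut)      -- silently uses b's intermediates
local bad = net:get(1).gradWeight:clone()

print((good - bad):abs():max())   -- nonzero: the gradients differ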

In my earlier program, only the intermediate variables of the last forward pass were retained, so the backward results were always wrong (see the commented-out line in method5).

The only workable (if clumsy) fix was to run forward once more on the same mini-batch right before each backward, which guarantees correct results at the cost of one extra forward pass per mini-batch; see method5 in the code.

A few notes on the five methods in the program below:

Method 1 is the conventional batch approach, but it demands a lot of GPU memory.

Method 2 therefore splits the batch into mini-batches in the style of caffe's iter_size (not exactly the same), running forward/backward per mini-batch and rescaling the accumulated gradient.

When even more samples are needed, and the criterion should see as many samples at once as possible, both of the above run into trouble. Method 3 was written for that case, but in practice it is broken: the loss converges very slowly.

Method 4 refines and tests Method 3. With the two marked lines left commented out (as listed), each mini-batch's backward immediately follows its forward and training converges normally; uncommenting those two lines splits the loop so that all forwards run before any backward, and convergence degrades just as in Method 3.

Method 5 is the final fix, and it converges normally. To confirm that forward and backward must be paired, uncomment the commented line in method5 and comment out the line above it; convergence then becomes slow again (as in Methods 3 and 4). The first 10 epochs of each variant are listed after the code.

The program is as follows:

require 'torch'
require 'nn'
require 'optim'
require 'cunn'
require 'cutorch'
local mnist = require 'mnist'

local fullset = mnist.traindataset()
local testset = mnist.testdataset()

local trainset = {
    size = 50000,
    data = fullset.data[{{1,50000}}]:double(),
    label = fullset.label[{{1,50000}}]
}
trainset.data = trainset.data - trainset.data:mean()
trainset.data = trainset.data:cuda()
trainset.label = trainset.label:cuda()

local validationset = {
    size = 10000,
    data = fullset.data[{{50001,60000}}]:double(),
    label = fullset.label[{{50001,60000}}]
}
validationset.data = validationset.data - validationset.data:mean()
validationset.data = validationset.data:cuda()
validationset.label = validationset.label:cuda()

local model = nn.Sequential()
model:add(nn.Reshape(1, 28, 28))
model:add(nn.MulConstant(1/256.0*3.2))
model:add(nn.SpatialConvolutionMM(1, 20, 5, 5, 1, 1, 0, 0))
model:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model:add(nn.SpatialConvolutionMM(20, 50, 5, 5, 1, 1, 0, 0))
model:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model:add(nn.Reshape(4*4*50))
model:add(nn.Linear(4*4*50, 500))
model:add(nn.ReLU())
model:add(nn.Linear(500, 10))
model:add(nn.LogSoftMax())

model = require('weight-init')(model, 'xavier')
model = model:cuda()

x, dl_dx = model:getParameters()

local criterion = nn.ClassNLLCriterion():cuda()

local sgd_params = {
    learningRate = 1e-2,
    learningRateDecay = 1e-4,
    weightDecay = 1e-3,
    momentum = 1e-4
}

local training = function(batchSize)
    local current_loss = 0
    local count = 0
    local shuffle = torch.randperm(trainset.size)
    batchSize = batchSize or 200
    for t = 0, trainset.size-1, batchSize do
        -- set up inputs and targets for this batch
        local size = math.min(t + batchSize, trainset.size) - t
        local inputs = torch.Tensor(size, 28, 28):cuda()
        local targets = torch.Tensor(size):cuda()
        for i = 1, size do
            inputs[i] = trainset.data[shuffle[i+t]]
            targets[i] = trainset.label[shuffle[i+t]] + 1  -- ClassNLLCriterion expects 1-based labels
        end

        local feval = function(x_new)
            local miniBatchSize = 20
            if x ~= x_new then x:copy(x_new) end   -- sync if optim hands us a different vector
            dl_dx:zero()

            --[[ ------------------ method 1: plain batch
            local outputs = model:forward(inputs)
            local loss = criterion:forward(outputs, targets)
            local gradInput = criterion:backward(outputs, targets)
            model:backward(inputs, gradInput)
            --]]

            --[[ ------------------ method 2: iter_size-style mini-batches
            local loss = 0
            for idx = 1, batchSize, miniBatchSize do
                local outputs = model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])
                loss = loss + criterion:forward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                local gradInput = criterion:backward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput)
            end
            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            --[[ ------------------ method 3: mini-batch in batch (broken: each backward runs long after its forward)
            local outputs = torch.Tensor(batchSize, 10):zero():cuda()
            for idx = 1, batchSize, miniBatchSize do
                outputs[{{idx, idx + miniBatchSize - 1}}]:copy(model:forward(inputs[{{idx, idx + miniBatchSize - 1}}]))
            end
            local loss = 0
            for idx = 1, batchSize, miniBatchSize do
                loss = loss + criterion:forward(outputs[{{idx, idx + miniBatchSize - 1}}],
                    targets[{{idx, idx + miniBatchSize - 1}}])
            end
            local gradInput = torch.Tensor(batchSize, 10):zero():cuda()
            for idx = 1, batchSize, miniBatchSize do
                gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(
                    outputs[{{idx, idx + miniBatchSize - 1}}], targets[{{idx, idx + miniBatchSize - 1}}]))
            end
            for idx = 1, batchSize, miniBatchSize do
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
            end
            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            --[[ ------------------ method 4: mini-batch in batch (converges while the two lines below stay commented)
            local outputs = torch.Tensor(batchSize, 10):zero():cuda()
            local loss = 0
            local gradInput = torch.Tensor(batchSize, 10):zero():cuda()
            for idx = 1, batchSize, miniBatchSize do
                outputs[{{idx, idx + miniBatchSize - 1}}]:copy(model:forward(inputs[{{idx, idx + miniBatchSize - 1}}]))
                loss = loss + criterion:forward(outputs[{{idx, idx + miniBatchSize - 1}}],
                    targets[{{idx, idx + miniBatchSize - 1}}])
                gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(
                    outputs[{{idx, idx + miniBatchSize - 1}}], targets[{{idx, idx + miniBatchSize - 1}}]))
            --  end
            --  for idx = 1, batchSize, miniBatchSize do
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
            end

            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            ---[[ ------------------ method 5: mini-batch in batch, re-forward before each backward
            local loss = 0
            local gradInput = torch.Tensor(batchSize, 10):zero():cuda()

            for idx = 1, batchSize, miniBatchSize do
                local outputs = model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])
                loss = loss + criterion:forward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(outputs, targets[{{idx, idx + miniBatchSize - 1}}]))
            end

            for idx = 1, batchSize, miniBatchSize do
                model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])  -- re-forward so the stored intermediates match the backward below
                --model:forward(inputs[{{batchSize - miniBatchSize + 1, batchSize}}])
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
            end

            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            return loss, dl_dx
        end

        _, fs = optim.sgd(feval, x, sgd_params)

        count = count + 1
        current_loss = current_loss + fs[1]
    end

    return current_loss / count   -- normalize loss
end

local eval = function(dataset, batchSize)
    local count = 0
    batchSize = batchSize or 200

    for i = 1, dataset.size, batchSize do
        -- off-by-one fixed: the original "- i" dropped one sample per batch
        local size = math.min(i + batchSize - 1, dataset.size) - i + 1
        local inputs = dataset.data[{{i,i+size-1}}]:cuda()
        local targets = dataset.label[{{i,i+size-1}}]:cuda()  -- move to GPU (the test-set labels are still on the CPU)
        local outputs = model:forward(inputs)
        local _, indices = torch.max(outputs, 2)
        indices:add(-1)   -- back to 0-based labels
        indices = indices:cuda()
        local guessed_right = indices:eq(targets):sum()
        count = count + guessed_right
    end

    return count / dataset.size
end

local max_iters = 50
local last_accuracy = 0
local decreasing = 0
local threshold = 1 -- how many decreasing epochs we allow
for i = 1, max_iters do
    -- timer = torch.Timer()

    model:training()
    local loss = training()

    model:evaluate()
    local accuracy = eval(validationset)
    print(string.format('Epoch: %d Current loss: %4f; validation set accu: %4f', i, loss, accuracy))
    if accuracy < last_accuracy then
        if decreasing > threshold then break end
        decreasing = decreasing + 1
    else
        decreasing = 0
    end
    last_accuracy = accuracy

    --print('    Time elapsed: ' .. i .. 'iter: ' .. timer:time().real .. ' seconds')
end

testset.data = testset.data:double()
print(string.format('Test set accu: %4f', eval(testset)))
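
The script expects two local dependencies: the mnist data package (which, I believe, can be installed with luarocks install mnist) and the weight-init.lua file listed below, which must sit on the Lua package path.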

weight-init.lua

--
-- Different weight initialization methods
--
-- > model = require('weight-init')(model, 'heuristic')
--
require("nn")


-- "Efficient backprop"
-- Yann Lecun, 1998
local function w_init_heuristic(fan_in, fan_out)
   return math.sqrt(1/(3*fan_in))
end

-- "Understanding the difficulty of training deep feedforward neural networks"
-- Xavier Glorot, 2010
local function w_init_xavier(fan_in, fan_out)
   return math.sqrt(2/(fan_in + fan_out))
end

-- "Understanding the difficulty of training deep feedforward neural networks"
-- Xavier Glorot, 2010
local function w_init_xavier_caffe(fan_in, fan_out)
   return math.sqrt(1/fan_in)
end

-- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"
-- Kaiming He, 2015
local function w_init_kaiming(fan_in, fan_out)
   return math.sqrt(4/(fan_in + fan_out))
end

local function w_init(net, arg)
   -- choose initialization method
   local method = nil
   if     arg == 'heuristic'    then method = w_init_heuristic
   elseif arg == 'xavier'       then method = w_init_xavier
   elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe
   elseif arg == 'kaiming'      then method = w_init_kaiming
   else
      assert(false, 'unknown initialization method: ' .. tostring(arg))
   end

   -- loop over all convolutional modules
   for i = 1, #net.modules do
      local m = net.modules[i]
      if m.__typename == 'nn.SpatialConvolution' then
         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
      elseif m.__typename == 'nn.SpatialConvolutionMM' then
         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
      elseif m.__typename == 'cudnn.SpatialConvolution' then
         m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
      elseif m.__typename == 'nn.LateralConvolution' then
         m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1))
      elseif m.__typename == 'nn.VerticalConvolution' then
         m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
      elseif m.__typename == 'nn.HorizontalConvolution' then
         m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
      elseif m.__typename == 'nn.Linear' then
         m:reset(method(m.weight:size(2), m.weight:size(1)))
      elseif m.__typename == 'nn.TemporalConvolution' then
         m:reset(method(m.weight:size(2), m.weight:size(1)))
      end

      if m.bias then
         m.bias:zero()
      end
   end
   return net
end

return w_init
Method 1

Epoch: 1 Current loss: 0.616950; validation set accu: 0.920900	
Epoch: 2 Current loss: 0.228665; validation set accu: 0.942400	
Epoch: 3 Current loss: 0.168047; validation set accu: 0.957900	
Epoch: 4 Current loss: 0.134796; validation set accu: 0.961800	
Epoch: 5 Current loss: 0.113071; validation set accu: 0.966200	
Epoch: 6 Current loss: 0.098782; validation set accu: 0.968800	
Epoch: 7 Current loss: 0.088252; validation set accu: 0.970000	
Epoch: 8 Current loss: 0.080225; validation set accu: 0.971200	
Epoch: 9 Current loss: 0.073702; validation set accu: 0.972200	
Epoch: 10 Current loss: 0.068171; validation set accu: 0.972400	

Method 2
Epoch: 1 Current loss: 0.624633; validation set accu: 0.922200	
Epoch: 2 Current loss: 0.238459; validation set accu: 0.945200	
Epoch: 3 Current loss: 0.174089; validation set accu: 0.959000	
Epoch: 4 Current loss: 0.140234; validation set accu: 0.963800	
Epoch: 5 Current loss: 0.116498; validation set accu: 0.968000	
Epoch: 6 Current loss: 0.101376; validation set accu: 0.968800	
Epoch: 7 Current loss: 0.089484; validation set accu: 0.972600	
Epoch: 8 Current loss: 0.080812; validation set accu: 0.973000	
Epoch: 9 Current loss: 0.073929; validation set accu: 0.975100	
Epoch: 10 Current loss: 0.068330; validation set accu: 0.975400	

Method 3
Epoch: 1 Current loss: 2.202240; validation set accu: 0.548500	
Epoch: 2 Current loss: 2.049710; validation set accu: 0.669300	
Epoch: 3 Current loss: 1.993560; validation set accu: 0.728900	
Epoch: 4 Current loss: 1.959818; validation set accu: 0.774500	
Epoch: 5 Current loss: 1.945992; validation set accu: 0.757600	
Epoch: 6 Current loss: 1.930599; validation set accu: 0.809600	
Epoch: 7 Current loss: 1.911803; validation set accu: 0.837200	
Epoch: 8 Current loss: 1.904754; validation set accu: 0.842100	
Epoch: 9 Current loss: 1.903705; validation set accu: 0.846400	
Epoch: 10 Current loss: 1.903911; validation set accu: 0.848100	

Method 4
Epoch: 1 Current loss: 0.624240; validation set accu: 0.924900	
Epoch: 2 Current loss: 0.213469; validation set accu: 0.948500	
Epoch: 3 Current loss: 0.156797; validation set accu: 0.959800	
Epoch: 4 Current loss: 0.126438; validation set accu: 0.963900	
Epoch: 5 Current loss: 0.106664; validation set accu: 0.965900	
Epoch: 6 Current loss: 0.094166; validation set accu: 0.967200	
Epoch: 7 Current loss: 0.084848; validation set accu: 0.971200	
Epoch: 8 Current loss: 0.077244; validation set accu: 0.971800	
Epoch: 9 Current loss: 0.071417; validation set accu: 0.973300	
Epoch: 10 Current loss: 0.065737; validation set accu: 0.971600	


Method 4, with the two marked lines uncommented
Epoch: 1 Current loss: 2.178319; validation set accu: 0.542200	
Epoch: 2 Current loss: 2.031493; validation set accu: 0.648700	
Epoch: 3 Current loss: 1.982282; validation set accu: 0.703700	
Epoch: 4 Current loss: 1.956709; validation set accu: 0.762700	
Epoch: 5 Current loss: 1.927590; validation set accu: 0.808100	
Epoch: 6 Current loss: 1.924535; validation set accu: 0.817200	
Epoch: 7 Current loss: 1.911364; validation set accu: 0.820100	
Epoch: 8 Current loss: 1.898206; validation set accu: 0.855400	
Epoch: 9 Current loss: 1.885394; validation set accu: 0.836500	
Epoch: 10 Current loss: 1.880787; validation set accu: 0.870200	


Method 5

Epoch: 1 Current loss: 0.619814; validation set accu: 0.924300	
Epoch: 2 Current loss: 0.232870; validation set accu: 0.948800	
Epoch: 3 Current loss: 0.172606; validation set accu: 0.954900	
Epoch: 4 Current loss: 0.137763; validation set accu: 0.961800	
Epoch: 5 Current loss: 0.116268; validation set accu: 0.967700	
Epoch: 6 Current loss: 0.101985; validation set accu: 0.968800	
Epoch: 7 Current loss: 0.091154; validation set accu: 0.970900	
Epoch: 8 Current loss: 0.083219; validation set accu: 0.972700	
Epoch: 9 Current loss: 0.074921; validation set accu: 0.972800	
Epoch: 10 Current loss: 0.070208; validation set accu: 0.972800	


Method 5, with the commented line uncommented and the line above it commented out

Epoch: 1 Current loss: 2.161032; validation set accu: 0.497500	
Epoch: 2 Current loss: 2.027255; validation set accu: 0.690900	
Epoch: 3 Current loss: 1.972939; validation set accu: 0.767600	
Epoch: 4 Current loss: 1.940982; validation set accu: 0.766000	
Epoch: 5 Current loss: 1.933135; validation set accu: 0.812800	
Epoch: 6 Current loss: 1.913039; validation set accu: 0.799300	
Epoch: 7 Current loss: 1.896871; validation set accu: 0.848800	
Epoch: 8 Current loss: 1.899655; validation set accu: 0.854400	
Epoch: 9 Current loss: 1.889465; validation set accu: 0.845700	
Epoch: 10 Current loss: 1.878703; validation set accu: 0.846400	

End of the 2017-04-05 update.

=========================================================

 

