Please credit the source when reposting:
http://www.cnblogs.com/darkknightzh/p/6221622.html
References:
http://ju.outofmemory.cn/entry/284587
https://github.com/torch/nn/blob/master/doc/criterion.md
1. Using updateParameters
Suppose we already have model = setupmodel (a model we built ourselves), along with training data input, the ground-truth output outReal, and a loss function criterion (see the second reference URL). The Torch training procedure is then:
1 -- given model, criterion, input, outReal
2 model:training()
3 model:zeroGradParameters()
4 outPredict = model:forward(input)
5 err = criterion:forward(outPredict, outReal)
6 grad_criterion = criterion:backward(outPredict, outReal)
7 model:backward(input, grad_criterion)
8 model:updateParameters(learningRate)
Line 1 lists the quantities assumed to be given.
Line 2 puts the model in training mode.
Line 3 zeroes the gradients stored in every module of the model (so that gradients left over from the previous iteration do not contaminate this one).
Line 4 passes the input through the model to obtain the predicted output outPredict.
Line 5 uses the loss function to compute the error err between the model's prediction outPredict and the ground truth outReal under the current parameters.
Line 6 computes the gradient of the loss function, grad_criterion, from the predicted output outPredict and the ground truth outReal.
Line 7 back-propagates, computing the gradient of every module in the model.
Line 8 updates the parameters of every module in the model.
Lines 3 through 8 must be executed on every iteration.
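To make the eight lines above concrete, here is a minimal self-contained sketch; the tiny regression model, the random data, and the learning rate are made up purely for illustration:

require 'nn'

-- hypothetical stand-ins for setupmodel, input and outReal
local model = nn.Sequential()
model:add(nn.Linear(10, 5))
model:add(nn.Tanh())
model:add(nn.Linear(5, 1))
local criterion = nn.MSECriterion()
local input = torch.randn(10)
local outReal = torch.randn(1)
local learningRate = 0.01

for iter = 1, 100 do
    model:training()
    model:zeroGradParameters()                                      -- line 3: clear old gradients
    local outPredict = model:forward(input)                         -- line 4: prediction
    local err = criterion:forward(outPredict, outReal)              -- line 5: loss value
    local grad_criterion = criterion:backward(outPredict, outReal)  -- line 6: dLoss/dOutput
    model:backward(input, grad_criterion)                           -- line 7: accumulate gradients
    model:updateParameters(learningRate)                            -- line 8: plain SGD step
end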
=========================================================
2. Using optim
Update 170301:
http://x-wei.github.io/learn-torch-6-optim.html
describes a more convenient approach (whether it is really more convenient is debatable): the parameters can be updated with Torch's optim package. Calling model:updateParameters directly limits you to the most basic gradient descent, whereas optim wraps many algorithms (gradient descent, Adam, and so on).
params_new, fs, ... = optim.<method>(feval, params[, config][, state])
where:
params: the current parameter vector (a 1D tensor); it is updated in place during optimization
feval: a user-defined closure that behaves like f, df/dx = feval(x)
config: a table of algorithm parameters, such as the learning rate
state: a table of state variables
fs: a table of f values evaluated during the optimization; fs[#fs] is the final, optimized function value
params_new: the new parameter vector (a 1D tensor) that minimizes f
Note: because optim requires its inputs to be 1D tensors, the model's parameters must first be flattened, which the following function does:
params, gradParams = model:getParameters()
Both params and gradParams are 1D tensors. They are flattened views of the model's weights and gradients, so updates written into params take effect in the model directly.
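A quick sanity check on a toy module (hypothetical, for illustration) shows that params really is a view, not a copy. Note that getParameters should only be called once per model, since it reallocates and re-points the underlying storage:

require 'nn'

local net = nn.Linear(3, 2)                  -- toy module: 2*3 weights + 2 biases
local params, gradParams = net:getParameters()
print(params:nElement())                     -- 8
params:zero()                                -- write through the flat view
print(net.weight:sum(), net.bias:sum())      -- both 0: the module's own tensors changed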
With this in place, the program from section 1 can be rewritten as:
-- given model, criterion, input, outReal, optimState
local params, gradParams = model:getParameters()

local function feval()
    return criterion.output, gradParams
end

for ...
    model:training()
    model:zeroGradParameters()
    outPredict = model:forward(input)
    err = criterion:forward(outPredict, outReal)
    grad_criterion = criterion:backward(outPredict, outReal)
    model:backward(input, grad_criterion)
    optim.sgd(feval, params, optimState)
end
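For completeness, here is a self-contained sketch of the same loop (again with a made-up toy model and hyperparameters; only the optim.sgd call pattern matters). Two points worth noting: optimState must be the same table on every call, because optim.sgd keeps internal state such as the momentum buffer in it, and swapping optim.sgd for optim.adam requires no other changes:

require 'nn'
require 'optim'

local model = nn.Sequential()                -- toy model, illustration only
model:add(nn.Linear(10, 5))
model:add(nn.Tanh())
model:add(nn.Linear(5, 1))
local criterion = nn.MSECriterion()
local input = torch.randn(10)
local outReal = torch.randn(1)

local params, gradParams = model:getParameters()
local optimState = { learningRate = 0.01, momentum = 0.9 }

for iter = 1, 100 do
    model:training()
    local function feval(x)
        if params ~= x then params:copy(x) end   -- sync if optim hands us new values
        model:zeroGradParameters()
        local outPredict = model:forward(input)
        local err = criterion:forward(outPredict, outReal)
        model:backward(input, criterion:backward(outPredict, outReal))
        return err, gradParams                   -- loss and flattened gradient
    end
    optim.sgd(feval, params, optimState)         -- same optimState table each call
end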
End of the 170301 update
=========================================================
3. A pitfall when using model:backward
Update 170405:
Note that model:backward must always be paired with a matching model:forward.
In https://github.com/torch/nn/blob/master/doc/module.md, the entry for [gradInput] backward(input, gradOutput) says:
In general this method makes the assumption forward(input) has been called before, with the same input. This is necessary for optimization reasons. If you do not respect this rule, backward() will compute incorrect gradients.
This is presumably because backward reuses intermediate variables cached during forward, so forward must be executed (with the same input) right before backward; otherwise the cached intermediates do not match the backward call and the resulting gradients are wrong.
In my earlier program, the cached intermediates were those of the last forward call, so the backward results were always wrong (see the commented-out line in method5).
The only fix I found is rather clumsy: do all the forwards first, then run one extra forward immediately before each backward. This guarantees correct results, at the cost of one extra forward pass per mini-batch; see method5 in the code below.
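The effect is easy to reproduce on a toy model (hypothetical setup, for illustration): run a forward on different data between forward(inputA) and backward(inputA), and the accumulated gradients change, because modules such as Tanh back-propagate through their cached outputs, which by then belong to the other input.

require 'nn'

local model = nn.Sequential()
model:add(nn.Linear(4, 3))
model:add(nn.Tanh())                       -- backward uses the cached output
model:add(nn.Linear(3, 2))
local params, gradParams = model:getParameters()

local inputA, inputB = torch.randn(4), torch.randn(4)
local gradOutput = torch.randn(2)

-- correct: backward immediately follows forward on the same input
model:zeroGradParameters()
model:forward(inputA)
model:backward(inputA, gradOutput)
local goodGrad = gradParams:clone()

-- wrong: an intervening forward overwrites the cached intermediates
model:zeroGradParameters()
model:forward(inputA)
model:forward(inputB)                      -- clobbers the cached state
model:backward(inputA, gradOutput)         -- silently computes wrong gradients
local badGrad = gradParams:clone()

print((goodGrad - badGrad):abs():max())    -- nonzero: the gradients differ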
Notes on the methods in the code below:
method1 is the conventional batch approach. It works, but demands a lot of GPU memory.
method2 therefore splits the batch in the style of Caffe's iter_size (not exactly the same as Caffe's): each mini-batch is forwarded and backwarded in turn, accumulating gradients.
method3 is for when you need more samples and want the criterion to see as many samples as possible at once, which neither of the first two methods allows. In practice, though, method3 has a problem: the loss converges very slowly.
method4 refines and tests method3: with the two marked lines commented out it converges normally, but with them left in, convergence breaks down just as in method3.
method5 is the final fix and converges normally. To verify that forward and backward must be paired, uncomment the commented line in method5 and comment out the line above it; convergence then becomes slow again (as in method3 and method4).
The convergence of each method over the first 10 epochs is listed after the code.
The program is as follows:

require 'torch'
require 'nn'
require 'optim'
require 'cunn'
require 'cutorch'
local mnist = require 'mnist'

local fullset = mnist.traindataset()
local testset = mnist.testdataset()

-- first 50,000 MNIST images for training
local trainset = {
    size = 50000,
    data = fullset.data[{{1,50000}}]:double(),
    label = fullset.label[{{1,50000}}]
}
trainset.data = trainset.data - trainset.data:mean()
trainset.data = trainset.data:cuda()
trainset.label = trainset.label:cuda()

-- last 10,000 for validation
local validationset = {
    size = 10000,
    data = fullset.data[{{50001,60000}}]:double(),
    label = fullset.label[{{50001,60000}}]
}
validationset.data = validationset.data - validationset.data:mean()
validationset.data = validationset.data:cuda()
validationset.label = validationset.label:cuda()

-- LeNet-style network
local model = nn.Sequential()
model:add(nn.Reshape(1, 28, 28))
model:add(nn.MulConstant(1/256.0*3.2))
model:add(nn.SpatialConvolutionMM(1, 20, 5, 5, 1, 1, 0, 0))
model:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model:add(nn.SpatialConvolutionMM(20, 50, 5, 5, 1, 1, 0, 0))
model:add(nn.SpatialMaxPooling(2, 2, 2, 2, 0, 0))
model:add(nn.Reshape(4*4*50))
model:add(nn.Linear(4*4*50, 500))
model:add(nn.ReLU())
model:add(nn.Linear(500, 10))
model:add(nn.LogSoftMax())

model = require('weight-init')(model, 'xavier')
model = model:cuda()

x, dl_dx = model:getParameters()

local criterion = nn.ClassNLLCriterion():cuda()

local sgd_params = {
    learningRate = 1e-2,
    learningRateDecay = 1e-4,
    weightDecay = 1e-3,
    momentum = 1e-4
}

local training = function(batchSize)
    local current_loss = 0
    local count = 0
    local shuffle = torch.randperm(trainset.size)
    batchSize = batchSize or 200
    for t = 0, trainset.size-1, batchSize do
        -- setup inputs and targets for batch iteration
        local size = math.min(t + batchSize, trainset.size) - t
        local inputs = torch.Tensor(size, 28, 28):cuda()
        local targets = torch.Tensor(size):cuda()
        for i = 1, size do
            inputs[i] = trainset.data[shuffle[i+t]]
            targets[i] = trainset.label[shuffle[i+t]] + 1  -- ClassNLLCriterion wants 1-based labels
        end

        local feval = function(x_new)
            local miniBatchSize = 20
            if x ~= x_new then x:copy(x_new) end -- sync parameters if optim hands over a different tensor
            dl_dx:zero()

            --[[ ------------------ method 1 original batch
            local outputs = model:forward(inputs)
            local loss = criterion:forward(outputs, targets)
            local gradInput = criterion:backward(outputs, targets)
            model:backward(inputs, gradInput)
            --]]

            --[[ ------------------ method 2 iter-size with batch
            local loss = 0
            for idx = 1, batchSize, miniBatchSize do
                local outputs = model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])
                loss = loss + criterion:forward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                local gradInput = criterion:backward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput)
            end
            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            --[[ ------------------ method 3 mini-batch in batch
            local outputs = torch.Tensor(batchSize, 10):zero():cuda()
            for idx = 1, batchSize, miniBatchSize do
                outputs[{{idx, idx + miniBatchSize - 1}}]:copy(model:forward(inputs[{{idx, idx + miniBatchSize - 1}}]))
            end
            local loss = 0
            for idx = 1, batchSize, miniBatchSize do
                loss = loss + criterion:forward(outputs[{{idx, idx + miniBatchSize - 1}}],
                    targets[{{idx, idx + miniBatchSize - 1}}])
            end
            local gradInput = torch.Tensor(batchSize, 10):zero():cuda()
            for idx = 1, batchSize, miniBatchSize do
                gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(
                    outputs[{{idx, idx + miniBatchSize - 1}}], targets[{{idx, idx + miniBatchSize - 1}}]))
            end
            for idx = 1, batchSize, miniBatchSize do
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
            end
            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            --[[ ------------------ method 4 mini-batch in batch
            local outputs = torch.Tensor(batchSize, 10):zero():cuda()
            local loss = 0
            local gradInput = torch.Tensor(batchSize, 10):zero():cuda()
            for idx = 1, batchSize, miniBatchSize do
                outputs[{{idx, idx + miniBatchSize - 1}}]:copy(model:forward(inputs[{{idx, idx + miniBatchSize - 1}}]))
                loss = loss + criterion:forward(outputs[{{idx, idx + miniBatchSize - 1}}],
                    targets[{{idx, idx + miniBatchSize - 1}}])
                gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(
                    outputs[{{idx, idx + miniBatchSize - 1}}], targets[{{idx, idx + miniBatchSize - 1}}]))
            -- end
            -- for idx = 1, batchSize, miniBatchSize do
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
            end

            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]


            ---[[ ------------------ method 5 mini-batch in batch
            local loss = 0
            local gradInput = torch.Tensor(batchSize, 10):zero():cuda()

            for idx = 1, batchSize, miniBatchSize do
                local outputs = model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])
                loss = loss + criterion:forward(outputs, targets[{{idx, idx + miniBatchSize - 1}}])
                gradInput[{{idx, idx + miniBatchSize - 1}}]:copy(criterion:backward(outputs, targets[{{idx, idx + miniBatchSize - 1}}]))
            end

            for idx = 1, batchSize, miniBatchSize do
                model:forward(inputs[{{idx, idx + miniBatchSize - 1}}])  -- extra forward so backward sees matching intermediates
                --model:forward(inputs[{{batchSize - miniBatchSize + 1, batchSize}}])
                model:backward(inputs[{{idx, idx + miniBatchSize - 1}}], gradInput[{{idx, idx + miniBatchSize - 1}}])
            end

            dl_dx:mul(1.0 * miniBatchSize / batchSize)
            loss = loss * miniBatchSize / batchSize
            --]]

            return loss, dl_dx
        end

        _, fs = optim.sgd(feval, x, sgd_params)

        count = count + 1
        current_loss = current_loss + fs[1]
    end

    return current_loss / count -- normalize loss
end

local eval = function(dataset, batchSize)
    local count = 0
    batchSize = batchSize or 200

    for i = 1, dataset.size, batchSize do
        local size = math.min(i + batchSize - 1, dataset.size) - i + 1  -- number of samples in this batch
        local inputs = dataset.data[{{i,i+size-1}}]:cuda()
        local targets = dataset.label[{{i,i+size-1}}]:cuda()
        local outputs = model:forward(inputs)
        local _, indices = torch.max(outputs, 2)
        indices:add(-1)  -- back to 0-based labels
        indices = indices:cuda()
        local guessed_right = indices:eq(targets):sum()
        count = count + guessed_right
    end

    return count / dataset.size
end


local max_iters = 50
local last_accuracy = 0
local decreasing = 0
local threshold = 1 -- how many decreasing epochs we allow
for i = 1, max_iters do
    -- timer = torch.Timer()

    model:training()
    local loss = training()

    model:evaluate()
    local accuracy = eval(validationset)
    print(string.format('Epoch: %d Current loss: %4f; validation set accu: %4f', i, loss, accuracy))
    if accuracy < last_accuracy then
        if decreasing > threshold then break end
        decreasing = decreasing + 1
    else
        decreasing = 0
    end
    last_accuracy = accuracy

    --print(' Time elapsed: ' .. i .. 'iter: ' .. timer:time().real .. ' seconds')
end

testset.data = testset.data:double()
print(eval(testset))  -- test-set accuracy
weight-init.lua

--
-- Different weight initialization methods
--
-- > model = require('weight-init')(model, 'heuristic')
--
require("nn")


-- "Efficient backprop"
-- Yann Lecun, 1998
local function w_init_heuristic(fan_in, fan_out)
    return math.sqrt(1/(3*fan_in))
end

-- "Understanding the difficulty of training deep feedforward neural networks"
-- Xavier Glorot, 2010
local function w_init_xavier(fan_in, fan_out)
    return math.sqrt(2/(fan_in + fan_out))
end

-- "Understanding the difficulty of training deep feedforward neural networks"
-- Xavier Glorot, 2010
local function w_init_xavier_caffe(fan_in, fan_out)
    return math.sqrt(1/fan_in)
end

-- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"
-- Kaiming He, 2015
local function w_init_kaiming(fan_in, fan_out)
    return math.sqrt(4/(fan_in + fan_out))
end

local function w_init(net, arg)
    -- choose initialization method
    local method = nil
    if     arg == 'heuristic'    then method = w_init_heuristic
    elseif arg == 'xavier'       then method = w_init_xavier
    elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe
    elseif arg == 'kaiming'      then method = w_init_kaiming
    else
        assert(false)
    end

    -- loop over all convolutional modules
    for i = 1, #net.modules do
        local m = net.modules[i]
        if m.__typename == 'nn.SpatialConvolution' then
            m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
        elseif m.__typename == 'nn.SpatialConvolutionMM' then
            m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
        elseif m.__typename == 'cudnn.SpatialConvolution' then
            m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
        elseif m.__typename == 'nn.LateralConvolution' then
            m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1))
        elseif m.__typename == 'nn.VerticalConvolution' then
            m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
        elseif m.__typename == 'nn.HorizontalConvolution' then
            m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
        elseif m.__typename == 'nn.Linear' then
            m:reset(method(m.weight:size(2), m.weight:size(1)))
        elseif m.__typename == 'nn.TemporalConvolution' then
            m:reset(method(m.weight:size(2), m.weight:size(1)))
        end

        if m.bias then
            m.bias:zero()
        end
    end
    return net
end

return w_init

method 1
Epoch: 1 Current loss: 0.616950; validation set accu: 0.920900
Epoch: 2 Current loss: 0.228665; validation set accu: 0.942400
Epoch: 3 Current loss: 0.168047; validation set accu: 0.957900
Epoch: 4 Current loss: 0.134796; validation set accu: 0.961800
Epoch: 5 Current loss: 0.113071; validation set accu: 0.966200
Epoch: 6 Current loss: 0.098782; validation set accu: 0.968800
Epoch: 7 Current loss: 0.088252; validation set accu: 0.970000
Epoch: 8 Current loss: 0.080225; validation set accu: 0.971200
Epoch: 9 Current loss: 0.073702; validation set accu: 0.972200
Epoch: 10 Current loss: 0.068171; validation set accu: 0.972400

method 2
Epoch: 1 Current loss: 0.624633; validation set accu: 0.922200
Epoch: 2 Current loss: 0.238459; validation set accu: 0.945200
Epoch: 3 Current loss: 0.174089; validation set accu: 0.959000
Epoch: 4 Current loss: 0.140234; validation set accu: 0.963800
Epoch: 5 Current loss: 0.116498; validation set accu: 0.968000
Epoch: 6 Current loss: 0.101376; validation set accu: 0.968800
Epoch: 7 Current loss: 0.089484; validation set accu: 0.972600
Epoch: 8 Current loss: 0.080812; validation set accu: 0.973000
Epoch: 9 Current loss: 0.073929; validation set accu: 0.975100
Epoch: 10 Current loss: 0.068330; validation set accu: 0.975400

method 3
Epoch: 1 Current loss: 2.202240; validation set accu: 0.548500
Epoch: 2 Current loss: 2.049710; validation set accu: 0.669300
Epoch: 3 Current loss: 1.993560; validation set accu: 0.728900
Epoch: 4 Current loss: 1.959818; validation set accu: 0.774500
Epoch: 5 Current loss: 1.945992; validation set accu: 0.757600
Epoch: 6 Current loss: 1.930599; validation set accu: 0.809600
Epoch: 7 Current loss: 1.911803; validation set accu: 0.837200
Epoch: 8 Current loss: 1.904754; validation set accu: 0.842100
Epoch: 9 Current loss: 1.903705; validation set accu: 0.846400
Epoch: 10 Current loss: 1.903911; validation set accu: 0.848100

method 4
Epoch: 1 Current loss: 0.624240; validation set accu: 0.924900
Epoch: 2 Current loss: 0.213469; validation set accu: 0.948500
Epoch: 3 Current loss: 0.156797; validation set accu: 0.959800
Epoch: 4 Current loss: 0.126438; validation set accu: 0.963900
Epoch: 5 Current loss: 0.106664; validation set accu: 0.965900
Epoch: 6 Current loss: 0.094166; validation set accu: 0.967200
Epoch: 7 Current loss: 0.084848; validation set accu: 0.971200
Epoch: 8 Current loss: 0.077244; validation set accu: 0.971800
Epoch: 9 Current loss: 0.071417; validation set accu: 0.973300
Epoch: 10 Current loss: 0.065737; validation set accu: 0.971600

method 4, with the two marked lines uncommented
Epoch: 1 Current loss: 2.178319; validation set accu: 0.542200
Epoch: 2 Current loss: 2.031493; validation set accu: 0.648700
Epoch: 3 Current loss: 1.982282; validation set accu: 0.703700
Epoch: 4 Current loss: 1.956709; validation set accu: 0.762700
Epoch: 5 Current loss: 1.927590; validation set accu: 0.808100
Epoch: 6 Current loss: 1.924535; validation set accu: 0.817200
Epoch: 7 Current loss: 1.911364; validation set accu: 0.820100
Epoch: 8 Current loss: 1.898206; validation set accu: 0.855400
Epoch: 9 Current loss: 1.885394; validation set accu: 0.836500
Epoch: 10 Current loss: 1.880787; validation set accu: 0.870200

method 5
Epoch: 1 Current loss: 0.619814; validation set accu: 0.924300
Epoch: 2 Current loss: 0.232870; validation set accu: 0.948800
Epoch: 3 Current loss: 0.172606; validation set accu: 0.954900
Epoch: 4 Current loss: 0.137763; validation set accu: 0.961800
Epoch: 5 Current loss: 0.116268; validation set accu: 0.967700
Epoch: 6 Current loss: 0.101985; validation set accu: 0.968800
Epoch: 7 Current loss: 0.091154; validation set accu: 0.970900
Epoch: 8 Current loss: 0.083219; validation set accu: 0.972700
Epoch: 9 Current loss: 0.074921; validation set accu: 0.972800
Epoch: 10 Current loss: 0.070208; validation set accu: 0.972800

method 5, with the commented line uncommented and the line above it commented out
Epoch: 1 Current loss: 2.161032; validation set accu: 0.497500
Epoch: 2 Current loss: 2.027255; validation set accu: 0.690900
Epoch: 3 Current loss: 1.972939; validation set accu: 0.767600
Epoch: 4 Current loss: 1.940982; validation set accu: 0.766000
Epoch: 5 Current loss: 1.933135; validation set accu: 0.812800
Epoch: 6 Current loss: 1.913039; validation set accu: 0.799300
Epoch: 7 Current loss: 1.896871; validation set accu: 0.848800
Epoch: 8 Current loss: 1.899655; validation set accu: 0.854400
Epoch: 9 Current loss: 1.889465; validation set accu: 0.845700
Epoch: 10 Current loss: 1.878703; validation set accu: 0.846400
End of the 170405 update
=========================================================