An Introduction to the Basic Methods of PyTorch Lightning


Contents
LIGHTNINGMODULE
  Minimal Example
  Some basic methods
    Training
      Training loop
      Validation loop
      Test loop
    Inference
      Inference in research
      Inference in production
  LightningModule API (omitted)
LIGHTNINGMODULE
A LightningModule organizes PyTorch code into 5 sections:

  • Computations (__init__)
  • Train loop (training_step)
  • Validation loop (validation_step)
  • Test loop (test_step)
  • Optimizers (configure_optimizers)

Minimal Example
The methods a minimal LightningModule needs:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class LitModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # flatten each 28x28 image into a 784-dim vector
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
```

Train it with the following code:

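```python
import os
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer()
model = LitModel()
trainer.fit(model, train_loader)
```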
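The split between what you write and what the Trainer does can be sketched in plain Python. This is a toy model of the call order only, under loud assumptions: `ToyModule`, `ToyOptimizer` and `ToyTrainer` are hypothetical names, not Lightning classes, and because there is no autograd here, `training_step` returns a hand-computed gradient next to the loss. A real LightningModule returns only the loss, and Lightning itself calls `backward()` and `optimizer.step()`.

```python
# Hypothetical toy classes: they mimic the hook call order only and are not Lightning APIs.

class ToyOptimizer:
    """Plain gradient descent on the module's single float parameter `w`."""

    def __init__(self, module, lr):
        self.module = module
        self.lr = lr

    def step(self, grad):
        self.module.w -= self.lr * grad


class ToyModule:
    """Fits y = w * x. With no autograd, training_step also returns the gradient."""

    def __init__(self):
        self.w = 0.0
        self.calls = []  # records which hook ran, in order

    def training_step(self, batch, batch_idx):
        self.calls.append(("training_step", batch_idx))
        x, y = batch
        err = self.w * x - y
        return {"loss": err * err, "grad": 2 * err * x}  # d(err^2)/dw

    def configure_optimizers(self):
        self.calls.append(("configure_optimizers", None))
        return ToyOptimizer(self, lr=0.1)


class ToyTrainer:
    def __init__(self, max_epochs=1):
        self.max_epochs = max_epochs

    def fit(self, module, train_loader):
        optimizer = module.configure_optimizers()  # once, before any batch
        for _ in range(self.max_epochs):
            for batch_idx, batch in enumerate(train_loader):
                out = module.training_step(batch, batch_idx)
                optimizer.step(out["grad"])  # stands in for loss.backward() + optimizer.step()


data = [(1.0, 2.0), (2.0, 4.0)]  # samples of y = 2 * x
model = ToyModule()
ToyTrainer(max_epochs=20).fit(model, data)
print(round(model.w, 3))  # prints 2.0
```

configure_optimizers runs once before the first batch, while training_step runs for every batch of every epoch; that ordering, not the arithmetic, is what the sketch is meant to show.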

Some basic methods

Training

Training loop

Add a training loop with the training_step method:

```python
import torch.nn.functional as F
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
```

To compute a metric at the epoch level and record it, use the self.log method:

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
```
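With on_step=True and on_epoch=True, every step's value is recorded and an epoch-level aggregate is computed when the epoch ends. The toy logger below sketches that idea; `ToyLogger` and its `end_epoch` method are hypothetical. The `_step`/`_epoch` key suffixes mirror the names Lightning typically shows when both flags are set, and the plain mean is an assumption: Lightning's default reduction is a mean, but recent versions weight it by batch size.

```python
from statistics import mean


class ToyLogger:
    """Hypothetical logger: records per-step values and reduces them at epoch end."""

    def __init__(self):
        self.step_values = {}  # name -> values accumulated during the current epoch
        self.history = []      # (key, value) records, like what a logger would store

    def log(self, name, value, on_step=True, on_epoch=False):
        if on_step:
            self.history.append((f"{name}_step", value))
        if on_epoch:
            self.step_values.setdefault(name, []).append(value)

    def end_epoch(self):
        # toy reduction: unweighted mean over the steps of this epoch
        for name, values in self.step_values.items():
            self.history.append((f"{name}_epoch", mean(values)))
        self.step_values.clear()


logger = ToyLogger()
for step_loss in [0.9, 0.7, 0.5]:  # pretend losses from three training steps
    logger.log("train_loss", step_loss, on_step=True, on_epoch=True)
logger.end_epoch()
print(logger.history)
```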

To do something with the outputs of every training_step, override training_epoch_end:

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {'loss': loss, 'other_stuff': preds}

def training_epoch_end(self, training_step_outputs):
    for pred in training_step_outputs:
        # do something
        ...
```
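A common reason to override training_epoch_end is an exact epoch metric. In the pure-Python sketch below (`ToyAccuracyModule` and `run_epoch` are hypothetical), each training_step returns counts rather than a per-batch accuracy, so the epoch hook can sum them; averaging per-batch accuracies would be biased whenever the last batch is smaller. Note that training_epoch_end, validation_epoch_end, training_step_end and validation_step_end are Lightning 1.x hooks that were removed in 2.0, where you collect outputs yourself and process them in on_train_epoch_end.

```python
class ToyAccuracyModule:
    """Hypothetical module: each batch already holds (predictions, targets)."""

    def training_step(self, batch, batch_idx):
        preds, targets = batch
        correct = sum(p == t for p, t in zip(preds, targets))
        # 'loss' is a placeholder here; Lightning requires it in a returned dict
        return {"loss": 0.0, "correct": correct, "total": len(targets)}

    def training_epoch_end(self, training_step_outputs):
        # sum counts over all batches, then divide once
        correct = sum(out["correct"] for out in training_step_outputs)
        total = sum(out["total"] for out in training_step_outputs)
        self.epoch_acc = correct / total


def run_epoch(module, loader):
    """Hypothetical driver: collect every training_step output, then call the epoch hook."""
    outputs = [module.training_step(batch, i) for i, batch in enumerate(loader)]
    module.training_epoch_end(outputs)


module = ToyAccuracyModule()
run_epoch(module, [([0, 1, 1], [0, 1, 0]), ([2, 2], [2, 1])])
print(module.epoch_acc)
```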

If each batch is split across several GPUs (data-parallel training), use training_step_end to combine the per-GPU outputs:

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def training_step_end(self, batch_parts):
    # in DP mode each key holds the gathered outputs of all GPUs
    predictions = batch_parts['pred']
    losses = batch_parts['loss']

    gpu_0_prediction = predictions[0]
    gpu_1_prediction = predictions[1]

    # do something with both outputs
    return (losses[0] + losses[1]) / 2

def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        # do something with preds
        ...
```

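What DP mode does around these two hooks can be simulated without GPUs. In this sketch, `split_batch` plays the role of scattering a batch across devices and `gather` turns the per-device dicts into one dict of lists, so `batch_parts["loss"]` holds one loss per device, the shape training_step_end works with above. All names here are hypothetical; Lightning's real DP mode builds on torch.nn.DataParallel and works on tensors.

```python
def split_batch(batch, num_devices):
    """Toy scatter: split a list batch into num_devices roughly equal chunks."""
    size = -(-len(batch) // num_devices)  # ceiling division
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def gather(outputs):
    """Toy gather: list of per-device dicts -> one dict of per-device lists."""
    return {key: [out[key] for out in outputs] for key in outputs[0]}


class ToyDPModule:
    def training_step(self, batch, batch_idx):
        # each "device" sees only its own sub-batch of (prediction, target) pairs
        loss = sum((p - t) ** 2 for p, t in batch) / len(batch)
        return {"loss": loss, "pred": [p for p, _ in batch]}

    def training_step_end(self, batch_parts):
        losses = batch_parts["loss"]  # one loss per device
        return sum(losses) / len(losses)


def dp_training_step(module, batch, batch_idx, num_devices=2):
    """Run training_step on every chunk, then combine with training_step_end."""
    parts = [module.training_step(chunk, batch_idx)
             for chunk in split_batch(batch, num_devices)]
    return module.training_step_end(gather(parts))


batch = [(1.0, 0.0), (3.0, 1.0), (2.0, 2.0), (0.0, 2.0)]
print(dp_training_step(ToyDPModule(), batch, 0))
```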
Validation loop

To add a validation loop, override validation_step in your LightningModule:

```python
class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss)
```

For epoch-level validation metrics, override validation_epoch_end:

```python
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return pred

def validation_epoch_end(self, validation_step_outputs):
    for pred in validation_step_outputs:
        # do something with a pred
        ...
```
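The validation hooks slot into fit roughly as follows: after each training epoch, validation_step runs on every validation batch and validation_epoch_end receives the list of their return values. This is a hypothetical toy (`ToyValModule`, `toy_fit`); it leaves out details a real Trainer adds, such as switching the model to eval mode, disabling gradients, and the short sanity-check validation run before training starts.

```python
class ToyValModule:
    """Hypothetical module that records the order in which its hooks are called."""

    def __init__(self):
        self.events = []

    def training_step(self, batch, batch_idx):
        self.events.append("train")
        return batch

    def validation_step(self, batch, batch_idx):
        self.events.append("val")
        return batch * 10  # stands in for a prediction

    def validation_epoch_end(self, validation_step_outputs):
        self.events.append(f"val_end:{sum(validation_step_outputs)}")


def toy_fit(module, train_loader, val_loader, max_epochs):
    """Hypothetical loop: one validation pass after every training epoch."""
    for _ in range(max_epochs):
        for i, batch in enumerate(train_loader):
            module.training_step(batch, i)
        outputs = [module.validation_step(b, i) for i, b in enumerate(val_loader)]
        module.validation_epoch_end(outputs)


module = ToyValModule()
toy_fit(module, train_loader=[1, 2], val_loader=[3], max_epochs=2)
print(module.events)
```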

If validation runs data-parallel across several GPUs, use validation_step_end:

```python
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def validation_step_end(self, batch_parts):
    # in DP mode each key holds the gathered outputs of all GPUs
    predictions = batch_parts['pred']
    losses = batch_parts['loss']

    gpu_0_prediction = predictions[0]
    gpu_1_prediction = predictions[1]

    # do something with both outputs
    return (losses[0] + losses[1]) / 2

def validation_epoch_end(self, validation_step_outputs):
    for out in validation_step_outputs:
        # do something with preds
        ...
```

Test loop

Adding a test loop works exactly like adding a validation loop; the only difference is that the test loop is called only when you use .test():

```python
model = Model()
trainer = Trainer()
trainer.fit(model)

# automatically loads the best weights for you
trainer.test(model)
```

There are two ways to call test():

```python
# call after training
trainer = Trainer()
trainer.fit(model)

# automatically auto-loads the best weights
trainer.test(test_dataloaders=test_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, test_dataloaders=test_dataloader)
```

Inference

For research, LightningModules are best structured as systems:

```python
import pytorch_lightning as pl
import torch
from torch import nn

class Autoencoder(pl.LightningModule):

    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # reconstruction
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        return reconstruction_loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        recons = self.decoder(z)
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        self.log('val_reconstruction', reconstruction_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)
```

It can be trained like this:

```python
autoencoder = Autoencoder()
trainer = pl.Trainer(gpus=1)
trainer.fit(autoencoder, train_dataloader, val_dataloader)
```

The methods above are part of the Lightning interface:

  • training_step
  • validation_step
  • test_step
  • configure_optimizers

Note that in this example the train loop and the val loop are exactly the same, so we can reuse that code:

```python
class Autoencoder(pl.LightningModule):

    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log('val_loss', loss)

    def shared_step(self, batch):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # loss
        return nn.functional.mse_loss(recons, x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)
```

Note: we created a new method, shared_step, that every loop can use; the name of this method is arbitrary.
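The same reuse works with no framework at all. In this pure-Python sketch, `ToyAutoencoder` is hypothetical: its encoder averages the input and its decoder repeats that mean, standing in for the nn.Sequential layers. It shows that training_step and validation_step differ only in what they do with the shared loss (one returns it, the other logs it). The `log` method here is a toy dictionary write, not LightningModule.log.

```python
class ToyAutoencoder:
    """Hypothetical stand-in: encode a vector to its mean, decode by repeating it."""

    def __init__(self):
        self.logged = {}

    def log(self, name, value):  # toy stand-in for LightningModule.log
        self.logged[name] = value

    def encoder(self, x):
        return sum(x) / len(x)  # 1-d "latent code"

    def decoder(self, z, n):
        return [z] * n

    def shared_step(self, batch):
        x, _ = batch
        z = self.encoder(x)
        recons = self.decoder(z, len(x))
        # mean squared reconstruction error
        return sum((r - v) ** 2 for r, v in zip(recons, x)) / len(x)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self.shared_step(batch))


ae = ToyAutoencoder()
print(ae.training_step(([1.0, 3.0], None), 0))  # prints 1.0
ae.validation_step(([1.0, 3.0], None), 0)
print(ae.logged)  # prints {'val_loss': 1.0}
```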

Inference in research

To run inference with the system, add a forward method to the LightningModule:

```python
class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)
```
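Calling a module object runs its forward method, which is why model(x) works in Lightning code. The sketch below imitates that with a hypothetical `CallableModule` base class whose __call__ simply delegates to forward (torch.nn.Module's real __call__ also runs hooks); `ToyDecoderSystem` and its arithmetic decoder are made up for illustration.

```python
class CallableModule:
    """Hypothetical base class: like torch.nn.Module, calling the object runs forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class ToyDecoderSystem(CallableModule):
    def __init__(self, scale):
        self.scale = scale

    def decoder(self, z):
        # toy decoder: map a 2-d latent code to a 4-d "image"
        return [self.scale * z[0], self.scale * z[1], z[0] + z[1], z[0] - z[1]]

    def forward(self, z):
        # inference for this system = decode a latent vector into a sample
        return self.decoder(z)


system = ToyDecoderSystem(scale=2.0)
print(system([1.0, 0.5]))  # prints [2.0, 1.0, 1.5, 0.5]
```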

The advantage of adding forward is that in a complex system it can hold a much more involved inference procedure, for example:

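```python
class Seq2Seq(pl.LightningModule):

    def forward(self, x):
        # calling self(x) here would recurse into forward() forever, so embed
        # with a submodule instead (self.embedding is a hypothetical attribute)
        embeddings = self.embedding(x)
        hidden_states = self.encoder(embeddings)
        for h in hidden_states:
            # decode
            ...
        return decoded
```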
Inference in production

For production-style use, you may want to iterate over different models inside one LightningModule:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
# PyTorch Lightning 1.x API; later releases moved metrics to the separate torchmetrics package
from pytorch_lightning.metrics import functional as FM

class ClassificationTask(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)

        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        # reuse the validation computation and rename the keys
        metrics = self.validation_step(batch, batch_idx)
        metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']}
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)
```

Then pass in any model that fits the task:

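```python
from pytorch_lightning import Trainer
from torchvision.models import resnet50, vgg16

# BidirectionalRNN stands for any other nn.Module you want to try (not a torchvision model)
for model in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(model)

    trainer = Trainer(gpus=2)
    trainer.fit(task, train_dataloader, val_dataloader)
```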
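The point of the task/model split is that the task's loop logic never changes while the wrapped model does. Here is a framework-free sketch: `ToyClassificationTask`, `always_class_0` and `threshold_model` are hypothetical, and each "model" is just a function that returns class scores for one input.

```python
class ToyClassificationTask:
    """Hypothetical analogue of ClassificationTask: owns the eval logic, wraps any model."""

    def __init__(self, model):
        self.model = model

    def validation_step(self, batch, batch_idx):
        x, y = batch
        scores = [self.model(sample) for sample in x]
        preds = [max(range(len(s)), key=s.__getitem__) for s in scores]  # argmax
        acc = sum(p == t for p, t in zip(preds, y)) / len(y)
        return {"val_acc": acc}


def always_class_0(sample):
    return [1.0, 0.0]


def threshold_model(sample):
    return [0.0, 1.0] if sample > 0.5 else [1.0, 0.0]


batch = ([0.2, 0.9, 0.7, 0.1], [0, 1, 1, 0])
for model in [always_class_0, threshold_model]:
    task = ToyClassificationTask(model)
    print(model.__name__, task.validation_step(batch, 0))
```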

Tasks can be arbitrarily complex; for example, you can implement GAN training, self-supervised learning, or even RL:

```python
class GANTask(pl.LightningModule):

    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    ...
```
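How a GAN task drives two optimizers depends on the Lightning version. In 1.x automatic optimization, when configure_optimizers returns several optimizers, training_step gets an extra optimizer_idx argument and is called once per optimizer for every batch; Lightning 2.0 removed this, and GANs there use manual optimization (self.automatic_optimization = False). The toy loop below sketches only the 1.x call order; `toy_fit_multi_opt` and `ToyGANTask` are hypothetical, and the optimizers are placeholder strings.

```python
def toy_fit_multi_opt(module, train_loader):
    """Hypothetical loop mimicking Lightning 1.x with several optimizers:
    training_step runs once per optimizer for every batch."""
    optimizers = module.configure_optimizers()
    for batch_idx, batch in enumerate(train_loader):
        for optimizer_idx, _ in enumerate(optimizers):
            module.training_step(batch, batch_idx, optimizer_idx)


class ToyGANTask:
    def __init__(self):
        self.calls = []

    def configure_optimizers(self):
        # placeholders for (generator optimizer, discriminator optimizer)
        return ["opt_g", "opt_d"]

    def training_step(self, batch, batch_idx, optimizer_idx):
        role = "generator" if optimizer_idx == 0 else "discriminator"
        self.calls.append((batch_idx, role))


gan = ToyGANTask()
toy_fit_multi_opt(gan, ["batch_0", "batch_1"])
print(gan.calls)
```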
