PyTorch Lightning工具學習


【GiantPandaCV導語】Pytorch Lightning是在Pytorch基礎上進行封裝的庫,為了讓用戶能夠脫離PyTorch一些繁瑣的細節,專注於核心代碼的構建,提供了許多實用工具,可以讓實驗更加高效。本文將介紹安裝方法、設計邏輯、轉化的例子等內容。

PyTorch Lightning中提供了以下比較方便的功能:

  • multi-GPU訓練
  • 半精度訓練
  • TPU 訓練
  • 將訓練細節進行抽象,從而可以快速迭代

Pytorch Lightning

1. 簡單介紹

PyTorch lightning 是為AI相關的專業的研究人員、研究生、博士等人群開發的。PyTorch就是William Falcon在他的博士階段創建的,目標是讓AI研究擴展性更強,忽略一些耗費時間的細節。

目前PyTorch Lightning庫已經有了一定的影響力,star已經1w+,同時有超過1千多的研究人員在一起維護這個框架。

PyTorch Lightning庫

同時PyTorch Lightning也在隨着PyTorch版本的更新也在不停迭代。

版本支持情況

官方文檔也有支持,正在不斷更新:

官方文檔

下面介紹一下如何安裝。

2. 安裝方法

Pytorch Lightning安裝非常方便,推薦使用conda環境進行安裝。

source activate you_env
pip install pytorch-lightning

或者直接用pip安裝:

pip install pytorch-lightning

或者通過conda安裝:

conda install pytorch-lightning -c conda-forge

3. Lightning的設計思想

Lightning將大部分AI相關代碼分為三個部分:

  • 研究代碼,主要是模型的結構、訓練等部分。被抽象為LightningModule類。

  • 工程代碼,這部分代碼重復性強,比如16位精度,分布式訓練。被抽象為Trainer類。

  • 非必要代碼,這部分代碼和實驗沒有直接關系,不加也可以,加上可以輔助,比如梯度檢查,log輸出等。被抽象為Callbacks類。

Lightning將研究代碼划分為以下幾個組件:

  • 模型
  • 數據處理
  • 損失函數
  • 優化器

以上四個組件都將集成到LightningModule類中,是在Module類之上進行了擴展,進行了功能性補充,比如原來優化器使用在main函數中,是一種面向過程的用法,現在集成到LightningModule中,作為一個類的方法。

4. LightningModule生命周期

這部分參考了https://zhuanlan.zhihu.com/p/120331610 和 官方文檔 https://pytorch-lightning.readthedocs.io/en/latest/trainer.html

在這個模塊中,將PyTorch代碼按照五個部分進行組織:

  • Computations(init) 初始化相關計算
  • Train Loop(training_step) 每個step中執行的代碼
  • Validation Loop(validation_step) 在一個epoch訓練完以后執行Valid
  • Test Loop(test_step) 在整個訓練完成以后執行Test
  • Optimizer(configure_optimizers) 配置優化器等

展示一個最簡代碼:

>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(28 * 28, 10)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
...     def training_step(self, batch, batch_idx):
...         x, y = batch
...         y_hat = self(x)
...         loss = F.cross_entropy(y_hat, y)
...         return loss
...
...     def configure_optimizers(self):
...         return torch.optim.Adam(self.parameters(), lr=0.02)

那么整個生命周期流程是如何組織的?

4.1 准備工作

這部分包括LightningModule的初始化、准備數據、配置優化器。每次只執行一次,相當於構造函數的作用。

  • __init__()(初始化 LightningModule )
  • prepare_data() (准備數據,包括下載數據、預處理等等)
  • configure_optimizers() (配置優化器)

4.2 測試 驗證部分

實際運行代碼前,會隨即初始化模型,然后運行一次驗證代碼,這樣可以防止在你訓練了幾個epoch之后要進行Valid的時候發現驗證部分出錯。主要測試下面幾個函數:

  • val_dataloader()
  • validation_step()
  • validation_epoch_end()

4.3 加載數據

調用以下方法進行加載數據。

  • train_dataloader()
  • val_dataloader()

4.4 訓練

  • 每個batch的訓練被稱為一個step,故先運行train_step函數。

  • 當經過多個batch, 默認49個step的訓練后,會進行驗證,運行validation_step函數。

  • 當完成一個epoch的訓練以后,會對整個epoch結果進行驗證,運行validation_epoch_end函數

  • (option)如果需要的話,可以調用測試部分代碼:

    • test_dataloader()
    • test_step()
    • test_epoch_end()

5. 示例

以MNIST為例,將PyTorch版本代碼轉為PyTorch Lightning。

5.1 PyTorch版本訓練MNIST

對於一個PyTorch的代碼來說,一般是這樣構建網絡(源碼來自PyTorch中的example庫)。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

還有兩個主要工作是構建訓練函數和測試函數。

在訓練函數中需要完成:

  • 數據獲取 data, target = data.to(device), target.to(device)
  • 清空優化器梯度 optimizer.zero_grad()
  • 前向傳播 output = model(data)
  • 計算損失函數 loss = F.nll_loss(output, target)
  • 反向傳播 loss.backward()
  • 優化器進行單次優化 optimizer.step()
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

其他部分比如數據加載、數據增廣、優化器、訓練流程都是在main中執行的,采用的是一種面向過程的方法。

def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")

5.2 Lightning版本訓練MNIST

第一部分,也就是歸為研究代碼,主要是模型的結構、訓練等部分。被抽象為LightningModule類。

class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser

可以看出,和PyTorch版本最大的不同之處在於多了幾個流程處理函數:

  • training_step,相當於訓練過程中處理一個batch的內容
  • validation_step,相當於驗證過程中處理一個batch的內容
  • test_step, 同上
  • configure_optimizers, 這部分用於處理optimizer和scheduler
  • add_module_specific_args代表這部分控制的是與模型相關的參數

除此以外,main函數主要有以下幾個部分:

  • args參數處理
  • data部分
  • model部分
  • 訓練部分
  • 測試部分
def cli_main():
    pl.seed_everything(1234) # 這個是用於固定seed用

    # args
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    # data
    dm = MNISTDataModule.from_argparse_args(args)

    # model
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    # training
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=dm)

    result = trainer.test(model, datamodule=dm)
    pprint(result)

可以看出Lightning版本的代碼代碼量略低於PyTorch版本,但是同時將一些細節忽略了,比如訓練的具體流程直接使用fit搞定,這樣不會出現忘記清空optimizer等低級錯誤。

6. 評價

總體來說,PyTorch Lightning是一個發展迅速的框架,如同fastai、keras、ignite等二次封裝的框架一樣,雖然易用性得到了提升,讓用戶可以通過更短的代碼完成任務,但是遇到錯誤的時候,往往就需要查看API甚至涉及框架源碼才能夠解決。前者降低門檻,后者略微提升了門檻。

筆者使用這個框架大概一周了,從使用者角度來談談優缺點:

6.1 優點

  • 簡化了部分代碼,之前如果要轉到GPU上,需要用to(device)方法判斷,然后轉過去。有了PyTorch lightning的幫助,可以自動幫你處理,通過設置trainer中的gpus參數即可。
  • 提供了一些有用的工具,比如混合精度訓練、分布式訓練、Horovod
  • 代碼移植更加容易
  • API比較完善,大部分都有例子,少部分講的不夠詳細。
  • 社區還是比較活躍的,如果有問題,可以在issue中提問。
  • 實驗結果整理的比較好,將每次實驗划分為version 0-n,同時可以用tensorboard比較多個實驗,非常友好。

6.2 缺點

  • 引入了一些新的概念,進一步加大了使用者的學習成本,比如pl_bolts
  • 很多原本習慣於在Pytorch中使用的功能,在PyTorch Lightning中必須查API才能使用,比如我想用scheduler,就需要去查API,然后發現在configure_optimizers函數中實現,然后模仿demo實現,因此也帶來了一定的門檻。
  • 有些報錯比較迷,筆者曾遇到過執行的時候發現多線程出問題,比較難以排查,最后通過更改distributed_backend得到了解決。遇到新的坑要去API里找答案,如果沒有解決繼續去Issue里找答案。

7. 參考


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM