【GiantPandaCV導語】Pytorch Lightning是在Pytorch基礎上進行封裝的庫,為了讓用戶能夠脫離PyTorch一些繁瑣的細節,專注於核心代碼的構建,提供了許多實用工具,可以讓實驗更加高效。本文將介紹安裝方法、設計邏輯、轉化的例子等內容。
PyTorch Lightning中提供了以下比較方便的功能:
- multi-GPU訓練
- 半精度訓練
- TPU 訓練
- 將訓練細節進行抽象,從而可以快速迭代
1. 簡單介紹
PyTorch lightning 是為AI相關的專業的研究人員、研究生、博士等人群開發的。PyTorch就是William Falcon在他的博士階段創建的,目標是讓AI研究擴展性更強,忽略一些耗費時間的細節。
目前PyTorch Lightning庫已經有了一定的影響力,star已經1w+,同時有超過1千多的研究人員在一起維護這個框架。
同時PyTorch Lightning也在隨着PyTorch版本的更新也在不停迭代。
官方文檔也有支持,正在不斷更新:
下面介紹一下如何安裝。
2. 安裝方法
Pytorch Lightning安裝非常方便,推薦使用conda環境進行安裝。
source activate you_env
pip install pytorch-lightning
或者直接用pip安裝:
pip install pytorch-lightning
或者通過conda安裝:
conda install pytorch-lightning -c conda-forge
3. Lightning的設計思想
Lightning將大部分AI相關代碼分為三個部分:
-
研究代碼,主要是模型的結構、訓練等部分。被抽象為LightningModule類。
-
工程代碼,這部分代碼重復性強,比如16位精度,分布式訓練。被抽象為Trainer類。
-
非必要代碼,這部分代碼和實驗沒有直接關系,不加也可以,加上可以輔助,比如梯度檢查,log輸出等。被抽象為Callbacks類。
Lightning將研究代碼划分為以下幾個組件:
- 模型
- 數據處理
- 損失函數
- 優化器
以上四個組件都將集成到LightningModule類中,是在Module類之上進行了擴展,進行了功能性補充,比如原來優化器使用在main函數中,是一種面向過程的用法,現在集成到LightningModule中,作為一個類的方法。
4. LightningModule生命周期
這部分參考了https://zhuanlan.zhihu.com/p/120331610 和 官方文檔 https://pytorch-lightning.readthedocs.io/en/latest/trainer.html
在這個模塊中,將PyTorch代碼按照五個部分進行組織:
- Computations(init) 初始化相關計算
- Train Loop(training_step) 每個step中執行的代碼
- Validation Loop(validation_step) 在一個epoch訓練完以后執行Valid
- Test Loop(test_step) 在整個訓練完成以后執行Test
- Optimizer(configure_optimizers) 配置優化器等
展示一個最簡代碼:
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
... def __init__(self):
... super().__init__()
... self.l1 = torch.nn.Linear(28 * 28, 10)
...
... def forward(self, x):
... return torch.relu(self.l1(x.view(x.size(0), -1)))
...
... def training_step(self, batch, batch_idx):
... x, y = batch
... y_hat = self(x)
... loss = F.cross_entropy(y_hat, y)
... return loss
...
... def configure_optimizers(self):
... return torch.optim.Adam(self.parameters(), lr=0.02)
那么整個生命周期流程是如何組織的?
4.1 准備工作
這部分包括LightningModule的初始化、准備數據、配置優化器。每次只執行一次,相當於構造函數的作用。
__init__()
(初始化 LightningModule )prepare_data()
(准備數據,包括下載數據、預處理等等)configure_optimizers()
(配置優化器)
4.2 測試 驗證部分
實際運行代碼前,會隨即初始化模型,然后運行一次驗證代碼,這樣可以防止在你訓練了幾個epoch之后要進行Valid的時候發現驗證部分出錯。主要測試下面幾個函數:
val_dataloader()
validation_step()
validation_epoch_end()
4.3 加載數據
調用以下方法進行加載數據。
train_dataloader()
val_dataloader()
4.4 訓練
-
每個batch的訓練被稱為一個step,故先運行train_step函數。
-
當經過多個batch, 默認49個step的訓練后,會進行驗證,運行validation_step函數。
-
當完成一個epoch的訓練以后,會對整個epoch結果進行驗證,運行validation_epoch_end函數
-
(option)如果需要的話,可以調用測試部分代碼:
- test_dataloader()
- test_step()
- test_epoch_end()
5. 示例
以MNIST為例,將PyTorch版本代碼轉為PyTorch Lightning。
5.1 PyTorch版本訓練MNIST
對於一個PyTorch的代碼來說,一般是這樣構建網絡(源碼來自PyTorch中的example庫)。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
還有兩個主要工作是構建訓練函數和測試函數。
在訓練函數中需要完成:
- 數據獲取
data, target = data.to(device), target.to(device)
- 清空優化器梯度
optimizer.zero_grad()
- 前向傳播
output = model(data)
- 計算損失函數
loss = F.nll_loss(output, target)
- 反向傳播
loss.backward()
- 優化器進行單次優化
optimizer.step()
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
if args.dry_run:
break
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
其他部分比如數據加載、數據增廣、優化器、訓練流程都是在main中執行的,采用的是一種面向過程的方法。
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step()
if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
5.2 Lightning版本訓練MNIST
第一部分,也就是歸為研究代碼,主要是模型的結構、訓練等部分。被抽象為LightningModule類。
class LitClassifier(pl.LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('valid_loss', loss)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('test_loss', loss)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parser
可以看出,和PyTorch版本最大的不同之處在於多了幾個流程處理函數:
- training_step,相當於訓練過程中處理一個batch的內容
- validation_step,相當於驗證過程中處理一個batch的內容
- test_step, 同上
- configure_optimizers, 這部分用於處理optimizer和scheduler
- add_module_specific_args代表這部分控制的是與模型相關的參數
除此以外,main函數主要有以下幾個部分:
- args參數處理
- data部分
- model部分
- 訓練部分
- 測試部分
def cli_main():
pl.seed_everything(1234) # 這個是用於固定seed用
# args
parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
parser = MNISTDataModule.add_argparse_args(parser)
args = parser.parse_args()
# data
dm = MNISTDataModule.from_argparse_args(args)
# model
model = LitClassifier(args.hidden_dim, args.learning_rate)
# training
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, datamodule=dm)
result = trainer.test(model, datamodule=dm)
pprint(result)
可以看出Lightning版本的代碼代碼量略低於PyTorch版本,但是同時將一些細節忽略了,比如訓練的具體流程直接使用fit搞定,這樣不會出現忘記清空optimizer等低級錯誤。
6. 評價
總體來說,PyTorch Lightning是一個發展迅速的框架,如同fastai、keras、ignite等二次封裝的框架一樣,雖然易用性得到了提升,讓用戶可以通過更短的代碼完成任務,但是遇到錯誤的時候,往往就需要查看API甚至涉及框架源碼才能夠解決。前者降低門檻,后者略微提升了門檻。
筆者使用這個框架大概一周了,從使用者角度來談談優缺點:
6.1 優點
- 簡化了部分代碼,之前如果要轉到GPU上,需要用to(device)方法判斷,然后轉過去。有了PyTorch lightning的幫助,可以自動幫你處理,通過設置trainer中的gpus參數即可。
- 提供了一些有用的工具,比如混合精度訓練、分布式訓練、Horovod
- 代碼移植更加容易
- API比較完善,大部分都有例子,少部分講的不夠詳細。
- 社區還是比較活躍的,如果有問題,可以在issue中提問。
- 實驗結果整理的比較好,將每次實驗划分為version 0-n,同時可以用tensorboard比較多個實驗,非常友好。
6.2 缺點
- 引入了一些新的概念,進一步加大了使用者的學習成本,比如pl_bolts
- 很多原本習慣於在Pytorch中使用的功能,在PyTorch Lightning中必須查API才能使用,比如我想用scheduler,就需要去查API,然后發現在configure_optimizers函數中實現,然后模仿demo實現,因此也帶來了一定的門檻。
- 有些報錯比較迷,筆者曾遇到過執行的時候發現多線程出問題,比較難以排查,最后通過更改distributed_backend得到了解決。遇到新的坑要去API里找答案,如果沒有解決繼續去Issue里找答案。