PyTorch Lightning for Dummies – A Tutorial and Overview
Takeaways
You’ll learn to use the features of PyTorch Lightning’s Core API by completing an applied project: training a language Transformer, written in PyTorch, on the WikiText2 dataset.
The code in this tutorial is available on GitHub in the text-lab repo. Clone the repo and follow along!

Introduction

Training deep learning models at scale is an incredibly interesting and complex task. Reproducibility is key for any project, and PyTorch Lightning helps us build reproducible code bases for training and finetuning. An added benefit of using PyTorch Lightning is that the framework is domain agnostic and complementary to PyTorch. In other words, it does not replace PyTorch, and it lets us train text, vision, audio, and multimodal models with the same framework: PyTorch Lightning.
The Research

Our research objective for this tutorial is to train a small language model using a Transformer on the WikiText2 dataset. Both the Transformer and the dataset are available to us in PyTorch Lightning at pytorch_lightning.demos.transformer. We’ll see later how to pull them into a Python module or Jupyter Notebook for use in our custom LightningDataModule and LightningModule.

PyTorch and PyTorch Lightning

PyTorch Lightning is not a replacement for PyTorch. Rather, PyTorch Lightning is an extension: a framework used to train models that have been implemented with PyTorch. The following snippet illustrates this relationship.

```python
import pytorch_lightning as pl
from pytorch_lightning.demos.transformer import Transformer


class LabModule(pl.LightningModule):
    def __init__(self, vocab_size: int = 33278):
        super().__init__()
        # a plain PyTorch model (torch.nn.Module) held by the LightningModule
        self.model = Transformer(vocab_size=vocab_size)
```

When we create self.model as shown above, we often refer to self.model as the internal module. Let’s keep reading to learn how to apply this interoperability between PyTorch and PyTorch Lightning!

PyTorch Lightning: The Core API

Okay, time to get to it! In the next sections, we will cover how to use the Core API of PyTorch Lightning.

What is the Core API?

First, let’s consider how we might organize any deep learning project into three sequential steps: processing the data, creating a model, and training that model on the given dataset. The Core API is structured around exactly these steps, with LightningDataModule, LightningModule, and Trainer.

LightningDataModule

LightningDataModule (LDM) wraps the data phase. It wraps a custom PyTorch Dataset and its DataLoaders, which lets Trainer handle data during training. For training and validation, LDM returns PyTorch DataLoaders from methods named train_dataloader and val_dataloader. If we need additional customization, LDM also exposes the prepare_data and setup hooks. The following skeleton shows how to import LightningDataModule and subclass it to create a custom class.

```python
import pytorch_lightning as pl


class LabDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
```

We will see examples of creating the train_dataloader and val_dataloader methods in LDM later in this tutorial.

LightningModule

LightningModule is the main training interface, and it wraps the PyTorch models we referred to earlier as internal modules. LightningModule is itself a subclass of torch.nn.Module, extended with dozens of additional hooks such as on_fit_start and on_fit_end. Overriding these hooks gives us finer control over Trainer’s flow and lets us add custom behavior. The following skeleton shows how to import LightningModule and subclass it to create a custom class.

```python
import pytorch_lightning as pl


class LabModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
```

Trainer

Trainer configures the training scope and manages the training loop with LightningModule and LightningDataModule. We configure Trainer by setting flags such as devices, accelerator, and strategy, and by passing in our choice of loggers, profilers, callbacks, and plugins. The snippet below uses the simplest configuration: Trainer’s defaults.

```python
import pytorch_lightning as pl

# instantiate the trainer (default flags)
trainer = pl.Trainer()

# instantiate the datamodule (LabDataModule from the skeleton above)
datamodule = LabDataModule()

# instantiate the model (LabModule from the skeleton above)
model = LabModule()

# call fit to start training; this only runs once LabModule implements
# training_step (and typically configure_optimizers) and LabDataModule
# implements train_dataloader
trainer.fit(model=model, datamodule=datamodule)
```

Rather see this explained in a video? Sebastian Raschka, our Lead AI Educator, breaks down how to get started with structuring PyTorch code using PyTorch Lightning.
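To make the division of labor concrete without installing anything, here is a deliberately simplified, framework-free sketch of how a trainer might drive a data module and a model through fit-time hooks. It is not Lightning’s implementation: MiniTrainer, ToyDataModule, and ToyModule are hypothetical names, and the real Trainer also handles optimizers, devices, validation, logging, and many more hooks. Only the names on_fit_start, on_fit_end, training_step, and train_dataloader are borrowed from the Lightning API.

```python
from typing import Iterable, List


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule: it only hands out batches."""

    def __init__(self, data: List[float], batch_size: int = 2):
        self.data = data
        self.batch_size = batch_size

    def train_dataloader(self) -> Iterable[List[float]]:
        # Yield consecutive mini-batches, like a non-shuffled DataLoader would.
        for i in range(0, len(self.data), self.batch_size):
            yield self.data[i : i + self.batch_size]


class ToyModule:
    """Hypothetical stand-in for a LightningModule: it defines what one step does."""

    def __init__(self):
        self.calls: List[str] = []  # records hook calls so we can inspect the order

    def on_fit_start(self) -> None:
        self.calls.append("on_fit_start")

    def training_step(self, batch: List[float], batch_idx: int) -> float:
        self.calls.append(f"training_step:{batch_idx}")
        # A stand-in "loss": the mean of the batch.
        return sum(batch) / len(batch)

    def on_fit_end(self) -> None:
        self.calls.append("on_fit_end")


class MiniTrainer:
    """Hypothetical stand-in for Trainer: it owns the loop and calls the hooks."""

    def __init__(self, max_epochs: int = 1):
        self.max_epochs = max_epochs

    def fit(self, model: ToyModule, datamodule: ToyDataModule) -> List[float]:
        losses = []
        model.on_fit_start()
        for _ in range(self.max_epochs):
            for batch_idx, batch in enumerate(datamodule.train_dataloader()):
                losses.append(model.training_step(batch, batch_idx))
        model.on_fit_end()
        return losses


model = ToyModule()
losses = MiniTrainer(max_epochs=1).fit(model, ToyDataModule([1.0, 2.0, 3.0, 4.0, 5.0]))
print(losses)       # → [1.5, 3.5, 5.0]
print(model.calls)  # → ['on_fit_start', 'training_step:0', 'training_step:1', 'training_step:2', 'on_fit_end']
```

The model never owns the loop, and the trainer never needs to know what a batch means, so either piece can be swapped without touching the other. That is the same separation the Core API draws between LightningModule, LightningDataModule, and Trainer.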
Getting Started: Hands-on Coding

Installing PyTorch Lightning

First, we will need to install PyTorch Lightning. The installation process itself shows how closely PyTorch Lightning is integrated with PyTorch: running the following command in the terminal installs PyTorch Lightning and also installs PyTorch into our virtual environment.

```bash
pip install pytorch-lightning
```

We also need to install TorchText in order to run the demo. Let’s do that with the following command in the terminal.

```bash
pip install torchtext
```

Do you need help creating a virtual environment? There’s a video for that too!
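After running these commands, one way to confirm that installing pytorch-lightning also pulled in PyTorch is to ask Python’s standard importlib.metadata module which versions are present. This is a small sketch that is not part of the text-lab repo: the helper name installed_versions is hypothetical, and it looks up installed distributions by name (using the underscore spelling pytorch_lightning) instead of importing the heavy packages. It assumes Python 3.8 or newer.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple


def installed_versions(
    packages: Tuple[str, ...] = ("pytorch_lightning", "torch", "torchtext"),
) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)  # reads package metadata; does not import it
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

If torch shows a version right after installing only pytorch-lightning, the dependency was resolved automatically.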
The Custom LightningDataModule

The dataset we will use is WikiText2. The demo code available to us in PyTorch Lightning will automatically fetch WikiText2 for us, so there’s no need to worry about downloading the dataset from torchtext. In the example below, WikiText2 is imported as LabDataset. If you wish, you can check out the code used to create the custom PyTorch Dataset in textlab.pipeline.dataset.py; for the purposes of this tutorial, however, we can ignore that implementation for now. Once again, here’s the skeleton for creating a LightningDataModule without any additional customization.

```python
import pytorch_lightning as pl


class LabDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
```

Creating the Custom Class

Compared with the skeleton above, we need to customize the __init__ method further to enable random splitting of the dataset and to let the LightningDataModule know the data source. This is also where we can provide domain-specific arguments, such as block_size for datasets used in text problems or image_size for vision problems. In the code blocks...
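The random split and the block_size argument mentioned above each control a small piece of data logic that can be sketched without any framework. The following is a minimal, hypothetical sketch, not the text-lab implementation. split_lengths and random_split_indices compute a reproducible train/validation split from a fraction; in PyTorch, the resulting lengths are what you might pass to torch.utils.data.random_split. make_blocks cuts a flat stream of token ids into block_size-long inputs paired with targets shifted one token ahead, the usual next-token language-modeling setup. The function names, the 0.8 default fraction, the fixed seed, and the non-overlapping stride are all assumptions.

```python
import random
from typing import List, Sequence, Tuple


def split_lengths(n: int, train_fraction: float = 0.8) -> Tuple[int, int]:
    """Hypothetical helper: sizes of the train/val splits for n samples."""
    n_train = int(n * train_fraction)
    return n_train, n - n_train


def random_split_indices(
    n: int, train_fraction: float = 0.8, seed: int = 42
) -> Tuple[List[int], List[int]]:
    """Shuffle indices with a fixed seed so the split is reproducible across runs."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_train, _ = split_lengths(n, train_fraction)
    return indices[:n_train], indices[n_train:]


def make_blocks(token_ids: Sequence[int], block_size: int) -> List[Tuple[List[int], List[int]]]:
    """Cut a token stream into (input, target) pairs; target is input shifted by one."""
    pairs = []
    # Stop early enough that every target block still has block_size tokens.
    for start in range(0, len(token_ids) - block_size, block_size):
        inputs = list(token_ids[start : start + block_size])
        targets = list(token_ids[start + 1 : start + block_size + 1])
        pairs.append((inputs, targets))
    return pairs


print(split_lengths(10))                         # → (8, 2)
print(make_blocks(list(range(10)), block_size=4))
# → [([0, 1, 2, 3], [1, 2, 3, 4]), ([4, 5, 6, 7], [5, 6, 7, 8])]
```

Seeding the shuffle is what keeps the split identical from run to run, which ties back to the reproducibility goal from the introduction.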