Introduction to PyTorch
Learn about the features of PyTorch.
PyTorch is an open-source machine learning framework for Python. It is specifically designed for deep learning applications, such as convolutional neural networks (CNNs), recurrent neural networks (RNNs), and generative adversarial networks (GANs), and it includes extensive layer definitions for these applications. Its built-in tensor operations are designed to be used in the same way as NumPy arrays, and they are also optimized to run on GPUs for fast computation. It provides an automatic computational graph scheme so that we won't need to calculate derivatives by hand.
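As a minimal sketch of the two points above, the snippet below creates a tensor much as we would create a NumPy array, then lets autograd compute a derivative for us (the values here are arbitrary, chosen only for illustration):

```python
import torch

# Create a tensor much like a NumPy array; requires_grad enables autograd.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()  # y = 1 + 4 + 9 = 14

# Backpropagation: PyTorch builds the graph and computes dy/dx = 2x for us.
y.backward()
print(x.grad)  # tensor([2., 4., 6.])
```

No derivative was written by hand: calling backward() walks the recorded computational graph and fills in x.grad automatically.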
After a few years of development and improvements, PyTorch now comes with a big package of features and functionalities. Don't worry about having to re-learn the tool; even across major new versions, PyTorch has always been good at keeping its core functionality consistent. In fact, unlike some other platforms, its core modules (torch.nn, torch.autograd, and torch.optim) haven't changed much since its alpha release (version 0.1.1). Now, let's take a look at some of the features in PyTorch.
Easy switching from eager mode to graph mode
When PyTorch first caught people's attention a few years ago, one of its biggest advantages over other deep learning tools was its dynamic graph support. This may be the main reason people ditched their old tools and embraced PyTorch. As we might have noticed, the authors of recent deep learning papers increasingly use PyTorch to implement their experiments.
However, it doesn’t mean that PyTorch is not fit for production environments. PyTorch provides a hybrid frontend that easily transfers code from eager mode (dynamic graph) to graph mode (static graph). We can write our code in as flexible a way as before. When we are satisfied with our code, just by changing a few lines of code in our model, it will be ready to be optimized for efficiency in graph mode. This process is accomplished by the torch.jit
compiler. JIT (Just-In-Time) compiler is designed to serialize and optimize PyTorch code into TorchScript, which can run without a Python interpreter.
This means that we can now export our model to an environment where Python is not available, or where efficiency is extremely important, and call our model from C++ code. Two approaches are provided to convert traditional PyTorch code to TorchScript: tracing and scripting. Tracing is perfect for directly transforming a fixed model scheme with fixed inputs into graph mode.
However, if there is any data-dependent control flow in our model (for example, in an RNN), scripting is designed for that type of scenario: all possible control-flow routes are converted into TorchScript. Bear in mind that, at the time of writing this course, scripting still has its limitations.
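To make the distinction concrete, here is a small sketch of scripting; clip_positive is a hypothetical function written only for illustration. Its branch depends on the input values, so tracing would record just one path, while scripting compiles both:

```python
import torch

@torch.jit.script
def clip_positive(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: which branch runs depends on the input,
    # so tracing would bake in only one path; scripting preserves both.
    if x.sum() > 0:
        return x
    return torch.zeros_like(x)

print(clip_positive(torch.tensor([1.0, -0.5])))   # input kept as-is
print(clip_positive(torch.tensor([-1.0, -2.0])))  # replaced with zeros
```

Both branches now live in the compiled TorchScript function, so it behaves correctly for any input, not just the one used at conversion time.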
A dynamic graph means that the computational graph is built each time we run our model and can change between runs. It's like everyone driving their own car around the streets: anyone can go anywhere each time they leave home. This is flexible for research purposes. However, the additional overhead of building the graph before each run cannot be overlooked, so it can be a little inefficient for production purposes.
A static graph means that the computational graph has to be established before the first run and cannot be changed once established. It's like everyone going to work on the bus: efficient, but if the passengers want to travel to different destinations, they have to talk to the bus driver, who will then talk to the public transportation authorities. Only then can the bus route be changed the next day.
Mode switching example
Here’s an example of how to change our models to graph mode. Assume that we already have the model on a given device: