Training with an EMA (Exponential Moving Average)
Learn to train an image classification model with an exponential moving average.
The PyTorch Image Models (timm) framework supports an exponential moving average (EMA) of model weights, which maintains moving averages of the trained variables by applying an exponential decay.
An EMA is implemented as follows:
- Create shadow copies of the trained weights during initialization.
- At each training step, update the shadow copies as an exponentially decayed moving average of the trained weights.
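The two steps above can be sketched in plain Python. This is a minimal illustration, not timm's implementation: `EmaTracker` is a hypothetical name, and each "weight" is a float keyed by name, standing in for a model's parameter tensors. The update rule is `shadow = decay * shadow + (1 - decay) * weight`.

```python
from typing import Dict


class EmaTracker:
    """Hypothetical helper: exponentially decayed shadow copies of named weights."""

    def __init__(self, weights: Dict[str, float], decay: float = 0.9998) -> None:
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must be in [0, 1]")
        self.decay = decay
        # Step 1: shadow copies of the trained weights, made once at initialization.
        self.shadow = dict(weights)

    def update(self, weights: Dict[str, float]) -> None:
        # Step 2: after each training step, blend the current weights into the
        # shadow copies with exponential decay.
        for name, value in weights.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * value


weights = {"w": 1.0, "b": 0.0}
ema = EmaTracker(weights, decay=0.9)
weights = {"w": 2.0, "b": 1.0}  # pretend an optimizer step changed the weights
ema.update(weights)
print(ema.shadow)
```

Because the shadow copies only move a fraction `1 - decay` toward the latest weights, they change much more slowly than the weights the optimizer is updating.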
Note: In practice, the decay rate is close to 1.0. Typical values have several nines, such as 0.99, 0.999, or 0.9999.
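One way to read the "number of nines" is as an averaging horizon. A common rule of thumb, used here as an approximation, is that an EMA with decay `d` averages over roughly `1 / (1 - d)` recent steps, and a step's contribution halves after `ln(0.5) / ln(d)` further steps. The function names below are illustrative.

```python
import math


def averaging_horizon(decay: float) -> float:
    """Rough number of recent steps the EMA averages over: 1 / (1 - decay)."""
    return 1.0 / (1.0 - decay)


def half_life(decay: float) -> float:
    """Steps after which a single update's weight in the EMA has halved."""
    return math.log(0.5) / math.log(decay)


for decay in (0.99, 0.999, 0.9999):
    print(f"decay={decay}: horizon ~{averaging_horizon(decay):.0f} steps, "
          f"half-life ~{half_life(decay):.0f} steps")
```

Each extra nine stretches the horizon by about a factor of ten: roughly 100, 1,000, and 10,000 steps for the three values above. A larger decay therefore only makes sense when training runs for many more steps than the horizon.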
Applying an EMA during training can improve model performance. Some architectures, such as MobileNet-V3, EfficientNet, and MNASNet, need EMA-smoothed weights to perform well.
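To illustrate what "smoothing" means, here is a toy sequence rather than a real model: a single weight that the optimizer keeps bouncing between 1.5 and 0.5. This only shows the averaging effect; it says nothing about accuracy on a real task, and `ema_of_sequence` is a hypothetical helper.

```python
def ema_of_sequence(values, decay):
    """Return the EMA after each step, starting from the first value."""
    ema = values[0]
    history = [ema]
    for v in values[1:]:
        ema = decay * ema + (1.0 - decay) * v
        history.append(ema)
    return history


# Toy "weight" oscillating around 1.0 because of noisy updates.
raw = [1.0 + 0.5 * (-1) ** t for t in range(200)]
smoothed = ema_of_sequence(raw, decay=0.9)
print(raw[-1], round(smoothed[-1], 3))
```

The raw value ends far from 1.0, while the EMA settles close to the center of the oscillation, which is the kind of noise reduction EMA weights provide late in training.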
Training with an EMA
Set the --model-ema flag and the --model-ema-decay argument to train with an EMA. The --model-ema-decay argument sets the decay rate for the EMA and accepts a floating-point value.
Call it as follows:
python train.py /app/dataset2 --model resnet50 --num-classes 4 --model-ema --model-ema-decay 0.99
By default, the value for --model-ema-decay is 0.9998.
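The relationship between the two options can be mirrored with Python's `argparse`. This is a hypothetical parser written for illustration, not the parser from timm's `train.py`; it only reuses the flag names and the 0.9998 default stated above.

```python
import argparse


def build_ema_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring only the EMA options described above.
    parser = argparse.ArgumentParser(description="EMA options (illustrative)")
    # Boolean switch: present means "track an EMA", absent means no EMA.
    parser.add_argument("--model-ema", action="store_true",
                        help="track an EMA of the model weights")
    # Float-valued decay rate, only meaningful when --model-ema is set.
    parser.add_argument("--model-ema-decay", type=float, default=0.9998,
                        help="decay rate for the EMA")
    return parser


args = build_ema_parser().parse_args(["--model-ema", "--model-ema-decay", "0.99"])
print(args.model_ema, args.model_ema_decay)
```

`argparse` turns the dashes into underscores, so the options are read back as `args.model_ema` and `args.model_ema_decay`. Passing `--model-ema-decay` alone without `--model-ema` would set a decay that is never used.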
Note: With a decay of 0.99, each EMA update in the command above keeps 99% of the existing EMA weights and blends in 1% of the current model weights. With the default of 0.9998, the split is 99.98% and 0.02%.
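The update for a single weight can be written out explicitly. The symbols are introduced here for illustration: $d$ is the decay, $\theta_t$ is the weight after training step $t$, and $\hat{\theta}_t$ is its EMA. The last line plugs in example values $\hat{\theta}_{t-1} = 2.0$ and $\theta_t = 1.0$ with the command's decay of 0.99.

```latex
\begin{aligned}
\hat{\theta}_t &= d\,\hat{\theta}_{t-1} + (1 - d)\,\theta_t \\
               &= 0.99\,\hat{\theta}_{t-1} + 0.01\,\theta_t \\
               &= 0.99 \cdot 2.0 + 0.01 \cdot 1.0 = 1.98 + 0.01 = 1.99
\end{aligned}
```

Even though the current weight has dropped to 1.0, the EMA moves only one hundredth of the way toward it in a single step.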
Track an EMA on the CPU
Training with an EMA requires additional memory to hold the shadow copy of the weights, which may cause an out-of-memory error on the GPU. The --model-ema-force-cpu flag keeps the EMA weights on the CPU instead.
python train.py /app/dataset2 --model resnet50 --num-classes 4 --model-ema --model-ema-decay 0.99 --model-ema-force-cpu
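A rough estimate of the extra memory is one additional copy of the weights. The sketch below assumes fp32 weights (4 bytes each) and the commonly cited parameter count of the standard 1000-class ResNet-50 (about 25.6 million); with --num-classes 4 the classifier head is smaller, and buffers such as batch-norm statistics add a little on top. `ema_extra_bytes` is a hypothetical helper.

```python
def ema_extra_bytes(num_params: int, bytes_per_param: int = 4) -> int:
    """Approximate extra memory for one shadow copy of the weights."""
    return num_params * bytes_per_param


# Assumed figure: parameter count of the standard 1000-class ResNet-50.
resnet50_params = 25_557_032
extra = ema_extra_bytes(resnet50_params)
print(f"~{extra / 2**20:.0f} MiB extra for the EMA copy")
```

On a large GPU a copy of this size is small, but for bigger models or tight memory budgets, moving it to host RAM can be the difference between fitting and running out of memory.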
Note: Training with the --model-ema-force-cpu flag disables validation of the EMA weights.
Most of the best-performing models in the PyTorch Image Models framework use an EMA during training.