Heterogeneous Batching

Overview

Heterogeneous batching is crucial for 3D deep learning applications because 3D data rarely comes in a fixed shape. Let’s first review the standard batching approach and then see how heterogeneous batching differs from it.

Batching

Batching is one of the fundamental techniques in machine learning. Batching data during optimization matters for several reasons, the most important being:

  • Memory efficiency

  • Training speed

  • Regularization

Memory efficiency

For most real-world datasets, it is impractical to load the entire dataset into memory. Very large datasets might even be sharded, that is, partitioned across multiple servers to improve the performance of searches and other data operations. In these cases, batching lets us run each training iteration on a portion of the dataset and apply an update after each batch.
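The idea can be sketched in a few lines of Python, assuming the data arrives as a lazy iterable (for example, records streamed from disk or from a shard). `iter_batches` is a hypothetical helper for illustration, not part of any particular library; frameworks such as PyTorch provide their own loaders (for example, `torch.utils.data.DataLoader`) for this job.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def iter_batches(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group a (possibly lazy) stream of items into lists of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    it = iter(items)
    while True:
        # Pull at most batch_size items; this function holds only one batch at a time.
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch

# range() is lazy, so the full sequence is never built as a list here.
print(list(iter_batches(range(10), batch_size=4)))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

The last batch is simply shorter when the dataset size is not a multiple of the batch size; some loaders offer an option to drop it instead.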

Training speed

Updating the model weights after every batch, rather than once per epoch, also speeds up training. For very large datasets, batching lets us apply updates at regular intervals instead of waiting for a full pass through the data. Gradient descent is a first-order optimization method, so it can take many iterations to converge, with the model weights updated slowly and iteratively. Batching is a bit like sampling the dataset and applying updates as we go.
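To make the per-batch update concrete, here is a minimal sketch of mini-batch gradient descent on a toy, noise-free regression problem: fitting y ≈ w·x where the true weight is 2. The data, learning rate, epoch count, and the `minibatch_sgd` helper are all hypothetical choices for illustration. With 100 examples and a batch size of 10, the weight is updated 10 times per epoch instead of once.

```python
import random

def minibatch_sgd(xs, ys, batch_size, lr=0.1, epochs=20, seed=0):
    """Fit y ≈ w * x by minimizing mean-squared error, updating w once per mini-batch."""
    rng = random.Random(seed)
    w = 0.0
    indices = list(range(len(xs)))
    updates = 0
    for _ in range(epochs):
        rng.shuffle(indices)  # visit the examples in a new random order each epoch
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            # Gradient of mean((w * x - y) ** 2) over this batch only
            grad = sum(2 * (w * xs[i] - ys[i]) * xs[i] for i in batch) / len(batch)
            w -= lr * grad
            updates += 1
    return w, updates

xs = [i / 100 for i in range(100)]
ys = [2.0 * x for x in xs]
w, updates = minibatch_sgd(xs, ys, batch_size=10)
print(round(w, 3), updates)  # 2.0 200
```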

Regularization

Generally speaking, the larger the batch size, the more closely the batch gradient approximates the gradient over the full dataset. However, a more accurate gradient does not necessarily lead to better performance. Sampling data adds noise to the gradient estimate, which can help gradient descent escape local minima when optimizing non-convex loss functions. Smaller batches make each update cheaper, so updates happen more often, and the extra noise they introduce can act as a form of regularization that may improve the model’s generalization.
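One way to make this trade-off concrete, assuming the B indices in a batch are drawn independently and uniformly (with replacement) from N training examples with per-example losses ℓ_i:

```latex
\begin{aligned}
\hat{g}_B &= \frac{1}{B}\sum_{j=1}^{B} \nabla_\theta \ell_{i_j}(\theta),
\qquad L(\theta) = \frac{1}{N}\sum_{i=1}^{N} \ell_i(\theta), \\
\mathbb{E}\big[\hat{g}_B\big] &= \nabla_\theta L(\theta), \\
\mathbb{E}\,\big\|\hat{g}_B - \nabla_\theta L(\theta)\big\|^2 &= \frac{\sigma^2}{B},
\qquad \sigma^2 = \frac{1}{N}\sum_{i=1}^{N}\big\|\nabla_\theta \ell_i(\theta) - \nabla_\theta L(\theta)\big\|^2 .
\end{aligned}
```

Under this assumption the batch gradient is unbiased for every batch size, while its expected squared error shrinks as 1/B: halving the batch size doubles the gradient noise.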

Heterogeneous batching

In most ML applications, we have the luxury of regularly shaped data. For instance, in computer vision applications, we can resize and crop our input images to a standard image size with regular grid spacing. This allows us to use standardized input and output dimensions in our model.

In 3D, our data can vary significantly in size and shape. This can be an issue with point clouds but is especially problematic for 3D meshes, which can differ drastically in their numbers of vertices, faces, and edges. When our data can contain models with as few as 100 vertices or as many as 10,000 vertices, how do we standardize these inputs?
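To see why the naive fix of padding every mesh to the size of the largest one can be wasteful, consider the arithmetic for the example above. A batch holding one 100-vertex mesh and one 10,000-vertex mesh would need 2 × 10,000 vertex rows, of which 9,900 rows are filler. `padding_overhead` is a hypothetical helper for this calculation.

```python
def padding_overhead(vertex_counts):
    """Fraction of vertex rows that are filler when every mesh is padded to the largest one."""
    padded_rows = len(vertex_counts) * max(vertex_counts)
    real_rows = sum(vertex_counts)
    return (padded_rows - real_rows) / padded_rows

print(padding_overhead([100, 10_000]))  # 0.495
```

Almost half of such a batch would be padding, and the fraction grows as the size gap between meshes widens.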

Types of heterogeneous batching

The solution is a family of batching techniques called heterogeneous batching. These techniques provide ways to build batches from data with different (i.e., heterogeneous) shapes. The PyTorch3D Meshes class implements three types of batching:

  • List batching

  • Packed batching

  • Padded batching

Let’s learn the differences between these batching methods through an example. We will load a collection of OBJ files, create a Meshes object with them, and then use each batching method to observe the differences between them.
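Before working with real OBJ files, the three layouts can be illustrated with a dependency-free sketch. The toy vertex data and the helpers `to_packed` and `to_padded` are hypothetical; they only mimic the shapes returned by PyTorch3D’s `Meshes.verts_list()`, `Meshes.verts_packed()`, and `Meshes.verts_padded()` and are not the library’s implementation. The bookkeeping lists `mesh_idx` and `first_idx` are loosely analogous to index tensors that PyTorch3D exposes through methods such as `verts_packed_to_mesh_idx()`.

```python
from typing import List, Tuple

Vertex = Tuple[float, float, float]

# Two toy "meshes" described only by their vertices (faces omitted for brevity).
verts_list: List[List[Vertex]] = [
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)],  # 3 vertices
    [(0.0, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 1.0, 1.0),
     (0.0, 1.0, 1.0), (0.5, 0.5, 2.0)],                   # 5 vertices
]

def to_packed(meshes):
    """Concatenate all vertices into one (sum of V_i) x 3 list and record each row's mesh."""
    packed, mesh_idx, first_idx = [], [], []
    for m, verts in enumerate(meshes):
        first_idx.append(len(packed))      # row where mesh m starts in the packed list
        packed.extend(verts)
        mesh_idx.extend([m] * len(verts))  # which mesh owns each packed row
    return packed, mesh_idx, first_idx

def to_padded(meshes, pad_value=0.0):
    """Pad every mesh to the largest vertex count, giving a B x max(V_i) x 3 nested list."""
    max_v = max(len(v) for v in meshes)
    pad = (pad_value,) * 3
    return [list(v) + [pad] * (max_v - len(v)) for v in meshes]

packed, mesh_idx, first_idx = to_packed(verts_list)
padded = to_padded(verts_list)

print(len(verts_list), [len(v) for v in verts_list])  # 2 [3, 5]
print(len(packed), first_idx, mesh_idx)               # 8 [0, 3] [0, 0, 0, 1, 1, 1, 1, 1]
print(len(padded), len(padded[0]), padded[0][3])      # 2 5 (0.0, 0.0, 0.0)
```

The packed form stores no filler but needs index bookkeeping to recover individual meshes; the padded form has a regular B × max(V_i) × 3 shape at the cost of wasted rows.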

List batching

This method does not modify or concatenate the data at all. It simply returns the objects together in a Python list.
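A minimal sketch of what list batching means in practice, assuming each mesh is represented as a list of (x, y, z) tuples; `list_batch` and `centroids` are hypothetical helpers. In PyTorch3D, the corresponding accessor is `Meshes.verts_list()`, which returns one (V_i, 3) tensor per mesh.

```python
from typing import List, Sequence, Tuple

Vertex = Tuple[float, float, float]

def list_batch(meshes: Sequence[List[Vertex]]) -> List[List[Vertex]]:
    """List batching: keep each mesh's vertices as they are, one entry per mesh."""
    return list(meshes)  # no concatenation, no padding

def centroids(batch: List[List[Vertex]]) -> List[Vertex]:
    """Per-mesh computations on a list batch need a Python-level loop over the meshes."""
    out = []
    for verts in batch:
        n = len(verts)
        out.append(tuple(sum(v[k] for v in verts) / n for k in range(3)))
    return out

batch = list_batch([
    [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)],                   # 2 vertices
    [(0.0, 0.0, 0.0), (0.0, 3.0, 0.0), (0.0, 0.0, 3.0)],  # 3 vertices
])
print([len(v) for v in batch], centroids(batch))  # [2, 3] [(1.0, 0.0, 0.0), (0.0, 1.0, 1.0)]
```

The list form is the easiest to read and keeps every mesh intact, but the loop runs once per mesh, so nothing is vectorized across the batch.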
