Masked Siamese Networks: Objective and Training

Learn how to implement the loss function and training of Masked Siamese Networks (MSNs).

Similarity metric and predictions

To train the encoder, MSNs compute a distribution based on the similarity between a set of learnable prototypes $Q = \{q_1, q_2, ..., q_K\}$ (think of each prototype $q_i$ as a $d$-dimensional vector, so $Q$ is a $K \times d$ matrix) and each representation in the pair formed by an anchor view, $z_i^m$, and its target view, $z_i^+$. The encoder is then penalized when the two resulting distributions differ. More specifically, for an anchor representation $z_i^m$, we compute its “prediction,” $p_i^m$, by measuring its cosine similarity with each of the prototypes.
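As a rough illustration of the step described above, the sketch below computes a prediction over prototypes for a batch of anchor representations using NumPy. It rests on assumptions not stated in this lesson: the cosine similarities are turned into a distribution with a softmax, and a `temperature` parameter (the value 0.1 is arbitrary) controls how sharp that distribution is. The names `l2_normalize` and `prototype_predictions` are hypothetical helpers, and the prototypes are random rather than learned.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    # Scale each vector to unit length so dot products equal cosine similarities.
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)

def prototype_predictions(z, prototypes, temperature=0.1):
    # z: (B, d) representations; prototypes: (K, d) matrix Q.
    # Assumption: a softmax over temperature-scaled cosine similarities
    # produces the distribution over the K prototypes.
    z = l2_normalize(z)
    q = l2_normalize(prototypes)
    logits = z @ q.T / temperature                       # (B, K) scaled cosine similarities
    logits = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)          # each row sums to 1

rng = np.random.default_rng(0)
K, d, B = 8, 16, 4                    # illustrative sizes only
Q = rng.normal(size=(K, d))           # stand-in for the learnable prototypes
z_anchor = rng.normal(size=(B, d))    # stand-in for anchor representations z_i^m
p_anchor = prototype_predictions(z_anchor, Q)
```

Each row of `p_anchor` is a distribution over the $K$ prototypes. Applying the same function to the target representations $z_i^+$ would give the second distribution that the anchor prediction is compared against.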
