DINOv2 is a self-supervised pretraining method for vision transformers. It is much more complex to write down and explain than standard ViT pretraining, but DINO achieves much better performance for a given parameter, data, and compute budget.

Remember, DINO is a pretraining method; the model it is applied to is a vision transformer. The end result of DINO, just like that of any other pretraining method, is a really good embedding model. Of course, a DINO-pretrained model can also be fine-tuned on downstream tasks.

DINO-trained ViTs are particularly good as visual embedding models. Such ViTs can support linear classifiers or vector search algorithms on top of them.
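To make the idea of a linear classifier on top concrete, here is a minimal sketch of a linear probe on frozen DINO embeddings. It assumes the publicly released dino_vits16 checkpoint from torch.hub, uses random tensors as stand-ins for a real labeled dataset, and lets scikit-learn's LogisticRegression play the role of the linear classifier.

# A minimal linear-probe sketch on frozen DINO embeddings (dummy data).
import torch
from sklearn.linear_model import LogisticRegression

backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
backbone.eval()

@torch.no_grad()
def embed(images):                   # images: (N, 3, 224, 224)
    return backbone(images).numpy()  # ViT-S/16 CLS embeddings, shape (N, 384)

# Random stand-ins for a real train/val split of images and labels.
x_train, y_train = torch.randn(32, 3, 224, 224), torch.randint(0, 2, (32,)).numpy()
x_val, y_val = torch.randn(8, 3, 224, 224), torch.randint(0, 2, (8,)).numpy()

clf = LogisticRegression(max_iter=1000).fit(embed(x_train), y_train)
print("linear-probe accuracy:", clf.score(embed(x_val), y_val))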

DINO stands for self-distillation with no labels.

What happens in DINO is this:

# gs, gt: student and teacher networks (same architecture)
# C: center, a running mean of the teacher's outputs
# tps, tpt: student and teacher softmax temperatures (student's is higher)
# alpha, beta: teacher and center momentum rates

gt.params = gs.params # init "teacher" from "student"
for x in dataloader:
   x1, x2 = aug(x), aug(x) # Perform random augs on image x

   s1, s2 = gs(x1), gs(x2) # pass x1,2 aug into "student"
   t1, t2 = gt(x1), gt(x2) # pass x1,2 aug into "teacher"

   loss = (H(s1, t2) + H(s2, t1))/2
   loss.backward()
   update(gs) # SGD

   gt.params = alpha * gt.params + (1-alpha) * gs.params # EMA update of teacher
   C = beta * C + (1-beta) * cat([t1, t2]).mean(dim=0) # EMA update of center

def H(s, t):
  t = t.detach() # cut off teacher gradient
  s = softmax(s / tps, dim=1) 
  t = softmax( (t - C) / tpt, dim=1) # Center and sharpen
  loss = - (t * log(s)).sum(dim=1).mean()
  return loss

There’s a lot to unpack here.

First, we have two identical models at the start. We call them ‘student’ and ‘teacher’ (not the best naming, but the alternatives are even worse). At any rate, the teacher gt and the student gs start out as the same randomly initialized ViT.
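As a rough PyTorch sketch of that starting point (assuming timm's vit_small_patch16_224 as the backbone; the real DINO recipe also puts a projection head on top of each network, which I omit here), the teacher begins as an exact copy of the student and never receives gradients directly:

import copy
import timm

student = timm.create_model('vit_small_patch16_224', num_classes=0)
teacher = copy.deepcopy(student)   # identical weights at step 0
for p in teacher.parameters():
    p.requires_grad = False        # the teacher is only ever updated via EMA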

Next, we get an image of fixed shape from the dataset; ViTs need fixed-shape inputs. Say it is a 224 x 224 RGB image of a cat, stored in the tensor x.

aug is a function that applies random image augmentations, so x1 and x2 are two different augmented views of x. They could even be small, non-overlapping crops of x.
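A sketch of what aug could look like with torchvision transforms; the exact parameters below are illustrative, and the real DINO recipe uses a multi-crop scheme with two global crops plus several smaller local crops.

import numpy as np
from PIL import Image
from torchvision import transforms

aug = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=23),
    transforms.ToTensor(),
])

# A random PIL image stands in for the cat photo.
img = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
x1, x2 = aug(img), aug(img)   # two different random views of the same image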

Now, both x1 and x2 pass into both networks, and both the teacher's and the student's outputs are turned into probability distributions with a softmax (this happens inside the loss function H).
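Continuing the sketches above (this snippet reuses student, teacher, x1, and x2 from them, so it is not self-contained), the forward passes would look like this; only the student keeps a gradient graph:

import torch

s1, s2 = student(x1.unsqueeze(0)), student(x2.unsqueeze(0))       # gradients flow
with torch.no_grad():
    t1, t2 = teacher(x1.unsqueeze(0)), teacher(x2.unsqueeze(0))   # no gradients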

We then calculate a loss that rewards agreement. We want the two models to output the exact same probability distribution for different views of the same image; that is what -t * log(s) measures. We penalize deviations between student and teacher, but we only penalize the student: only the student receives gradients and is updated with update(gs).
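A tiny numeric illustration of that term (all numbers made up): the cross-entropy -t * log(s) is small when the student's distribution matches the teacher's and large when it doesn't.

import torch

t = torch.tensor([0.80, 0.15, 0.05])        # "peaky" teacher distribution
s_good = torch.tensor([0.75, 0.20, 0.05])   # student roughly agrees
s_bad = torch.tensor([0.05, 0.15, 0.80])    # student disagrees

print((-t * s_good.log()).sum())  # ~0.62, close to the minimum possible (~0.61)
print((-t * s_bad.log()).sum())   # ~2.69, heavily penalized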

The teacher just discards its gradient and instead gets two extra steps. First, we subtract C from the teacher's logits (centering). Second, we use a lower softmax temperature tpt (sharpening), which makes the teacher's distribution more "peaky." The student's distribution is more spread out due to its higher temperature, and it does not use the centering C.
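A quick demonstration of the sharpening effect: the same logits pushed through a softmax at the student temperature (0.1) versus the lower teacher temperature (0.04, the value the paper starts from). Centering would additionally subtract the running mean C from the teacher's logits before this softmax.

import torch

logits = torch.tensor([1.0, 0.8, 0.5])
print(torch.softmax(logits / 0.10, dim=0))  # ~[0.876, 0.118, 0.006]  student: spread out
print(torch.softmax(logits / 0.04, dim=0))  # ~[0.993, 0.007, 0.000]  teacher: peaky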

Teacher parameters are then updated from the student's parameters via an exponential moving average: gt.params = alpha * gt.params + (1-alpha) * gs.params.
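In explicit PyTorch terms, that EMA update might look like the sketch below; alpha is close to 1 (the paper starts it at 0.996 and anneals it toward 1 over training).

import torch

@torch.no_grad()
def ema_update(teacher, student, alpha=0.996):
    # gt.params = alpha * gt.params + (1 - alpha) * gs.params, parameter by parameter
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(alpha).add_(p_s, alpha=1.0 - alpha)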

You might ask: “Tornike, why this setup, complex as it is, instead of just letting both models collapse to outputting the same fixed distribution for every input, trivially driving the loss down but resulting in a useless model?”

So, what prevents the models from agreeing on outputting a uniform distribution?

I don’t know for sure. It seems that the combination of using C and tpt prevents that. A nice article explains why we can avoid collapse:

Mode collapse is prevented precisely because all samples in the mini-batch cannot take on the same value after batch normalization

As far as I understand, centering and sharpening pull in opposite directions, and that balance is what prevents the collapse. For collapse to happen, a model has to output the exact same distribution for every augmented input, and that constant output can take two forms: one dimension dominating, or a uniform spread over all dimensions. Subtracting the running mean C from the teacher's representations counteracts any dimension that starts to dominate, while sharpening with the low teacher temperature pushes the output away from uniform. Together they make the collapsed states really unstable: any slight parameter deviation from a collapsed state causes the models to diverge back into non-collapsed states. Basically, we rig the game so that approaching collapse becomes harder and harder.
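A toy illustration of the centering intuition (numbers made up, not from the paper): if the teacher starts emitting nearly the same dominant logit for every input, the running center C absorbs that shared bias, and after centering and sharpening the outputs no longer agree on a single dominant dimension.

import torch

collapsed_logits = torch.tensor([[5.0, 0.1, 0.2],
                                 [5.1, 0.0, 0.1],
                                 [4.9, 0.2, 0.0]])  # every sample favors dim 0

C = collapsed_logits.mean(dim=0)  # running center, here just the batch mean
print(torch.softmax((collapsed_logits - C) / 0.04, dim=1))
# Each row now peaks on a different dimension: the tiny per-sample differences
# that survive centering get amplified by the sharp teacher temperature.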