DINO
DINO is a self-supervised pretraining method for vision transformers. It is much more complex to implement and explain than supervised ViT pretraining. However, for a given parameter, data, and compute budget, DINO achieves much better performance.
Remember, DINO is a method for pretraining. The model that DINO is applied to is a vision transformer. The end result of DINO, just like any other pretraining method, is a really good embedding model. Of course, a DINO-pretrained model can be used for fine-tuning tasks too.
DINO-trained ViTs are particularly good as visual embedding models. Such ViTs can support linear classifiers or vector search algorithms on top of them.
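To make that concrete, here is a minimal linear-probe sketch. The torch.hub entry point is the one documented in the facebookresearch/dino repository; the rest (batch shapes, the linear head) is purely illustrative:

import torch

# Load a DINO-pretrained ViT-S/16; hub entry from the facebookresearch/dino README
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
backbone.eval()  # frozen embedding model

head = torch.nn.Linear(384, 1000)  # linear classifier on the 384-dim ViT-S embeddings

with torch.no_grad():
    x = torch.randn(8, 3, 224, 224)  # a batch of images (illustrative)
    emb = backbone(x)                # (8, 384) CLS-token embeddings
logits = head(emb)                   # only the head gets trained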
DINO stands for self-distillation with no labels.
What happens in DINO is this:
# gs, gt: student and teacher models (same architecture)
# C: center, a running mean of the teacher's outputs
# tps, tpt: student and teacher softmax temperatures (student's is higher)
# alpha, beta: teacher and center momentum rates
gt.params = gs.params # init "teacher" from "student"
for x in dataloader:
    x1, x2 = aug(x), aug(x) # perform random augs on image x
    s1, s2 = gs(x1), gs(x2) # pass both views through the "student"
    t1, t2 = gt(x1), gt(x2) # pass both views through the "teacher"
    loss = (H(s1, t2) + H(s2, t1)) / 2 # cross-view loss
    loss.backward()
    update(gs) # SGD, student only
    gt.params = alpha * gt.params + (1-alpha) * gs.params # teacher EMA
    C = beta * C + (1-beta) * cat([t1, t2]).mean(dim=0) # center EMA

def H(s, t):
    t = t.detach() # cut off teacher gradient
    s = softmax(s / tps, dim=1)
    t = softmax((t - C) / tpt, dim=1) # center and sharpen
    loss = - (t * log(s)).sum(dim=1).mean()
    return loss
There’s a lot to unpack here.
First, we have two identical models at the start. We call them 'student' and 'teacher' (not the best naming, but the alternatives are even worse). At any rate, the teacher gt and the student gs start out as the same randomly initialized ViT.
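In PyTorch terms, that initialization is just a weight copy followed by freezing the teacher (make_vit here is a hypothetical constructor for whatever ViT you use):

import copy

student = make_vit()              # hypothetical ViT constructor
teacher = copy.deepcopy(student)  # identical weights at the start
for p in teacher.parameters():
    p.requires_grad = False       # the teacher never receives gradients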
Next, we get an image of fixed shape from the dataset (ViTs need fixed-shape inputs). Say this image is a 224 x 224 RGB image of a cat, stored in tensor x.
aug is a function that applies random image augmentations, so x1 and x2 are two augmented versions of x. They could even be small, non-overlapping crops of x.
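A reduced sketch of aug using standard torchvision transforms might look like this; DINO's actual recipe is fancier (multi-crop, blur, solarization), but the idea is the same:

from torchvision import transforms

aug = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),  # random, possibly tiny, crop
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),           # brightness/contrast/saturation/hue
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])

x1, x2 = aug(img), aug(img)  # two independent random views of the same PIL image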
Now, both x1 and x2 pass through both networks, and both the teacher and the student apply a softmax at the end to produce a probability distribution.
The loss measures agreement. We want the two models to output the exact same probability distributions, and that is what the cross-entropy term -t*log(s) captures. We penalize deviations between student and teacher, but the penalty only flows to the student: we only update the student with update(gs).
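Written as runnable PyTorch, H looks roughly like this (log_softmax replaces the pseudocode's log(softmax(...)) for numerical stability; the default temperatures are the ballpark values from the DINO paper):

import torch.nn.functional as F

def H(s, t, C, tps=0.1, tpt=0.04):
    t = t.detach()                         # stop gradient: only the student learns
    log_s = F.log_softmax(s / tps, dim=1)  # student: higher temperature, no centering
    p_t = F.softmax((t - C) / tpt, dim=1)  # teacher: centered, then sharpened
    return -(p_t * log_s).sum(dim=1).mean()  # cross-entropy between the two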
The teacher just discards the gradient and instead does two extra steps. First, we subtract C from the teacher's logits and use a lower softmax temperature tpt. This makes the teacher's distribution more "peaky." The student's distribution is more spread out due to its higher temperature, and it also doesn't use the centering C.
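You can see the temperature effect directly by pushing the same logits through a softmax at both temperatures:

import torch

logits = torch.tensor([0.2, 0.1, 0.0])
print(torch.softmax(logits / 0.1, dim=0))   # student temp: ~[0.67, 0.24, 0.09], spread out
print(torch.softmax(logits / 0.04, dim=0))  # teacher temp: ~[0.92, 0.075, 0.006], peaky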
Teacher parameters are then updated from the student's via an exponential moving average: gt.params = alpha * gt.params + (1-alpha) * gs.params.
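Written over parameter tensors, the EMA update is a few lines (alpha is close to 1; the paper starts it around 0.996 and anneals it toward 1 during training):

import torch

@torch.no_grad()
def ema_update(teacher, student, alpha=0.996):
    for pt, ps in zip(teacher.parameters(), student.parameters()):
        pt.mul_(alpha).add_(ps, alpha=1 - alpha)  # pt = alpha*pt + (1-alpha)*ps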
You might ask: “Tornike, why this setup—complex as it is—instead of just letting both models collapse to outputting a uniform distribution, thus minimizing the loss but resulting in a useless model?”
So, what prevents the models from agreeing on outputting a uniform distribution?
I don't know for sure. It seems that the combination of the centering C and the sharpening temperature tpt prevents that. A nice article explains why we can avoid collapse:
Mode collapse is prevented precisely because all samples in the mini-batch cannot take on the same value after batch normalization
As far as I understand, subtracting C from the current teacher representation is precisely what prevents the collapse. For collapse to happen, a model has to output the exact same distribution for all augmented inputs. However, if we subtract C from all representations and also sharpen the distribution, the collapsed state becomes a really unstable parameter configuration: any slight parameter deviation from it would push the models into non-collapsed states. Basically, we rig the game so that approaching collapse becomes harder and harder.
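A tiny numeric check of that intuition, with made-up logits: if the teacher collapses and emits the same logits for every input, the running mean C converges to exactly those logits, the centered logits go to zero, and the teacher target becomes uniform, so a dominant-dimension collapse cannot persist; meanwhile, the low temperature tpt keeps the uniform target from being a stable solution either.

import torch

tpt = 0.04
t = torch.tensor([5.0, 1.0, 0.0])  # collapsed teacher: same logits for every sample
C = t.clone()                      # running mean of identical outputs equals those outputs
print(torch.softmax((t - C) / tpt, dim=0))  # tensor([0.3333, 0.3333, 0.3333]): uniform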