I've been trying to recreate the DINO v1 training setup for a personal project, and I've taken most of the code from this repo: https://github.com/facebookresearch/dino
I'm almost done with it except for one part. In main_dino.py there is a function called train_one_epoch, and on line 318 it has:
teacher_output = teacher(images[:2])  # only the 2 global views pass through the teacher
I know how PyTorch tensor indexing/slicing works, so if images is a batch of images with the structure:
(batch size, num crops, c, h, w)
- How would images[:2] give you the global crops of all the images in the batch?
- Are they processing images in batches here, or is "images" just a list containing multiple crops from a SINGLE input image?
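To check which of the two readings applies, here is a minimal sketch. It assumes the transform returns a Python list of crops per image (the 2 global crops first, then the local crops) and that batching is done by PyTorch's default_collate (importable as torch.utils.data.default_collate in recent versions). multi_crop, the crop sizes, and the labels are hypothetical stand-ins, not the repo's actual augmentation.

```python
import torch
from torch.utils.data import default_collate

# Hypothetical toy sizes (much smaller than a real setup) to keep shapes readable.
BATCH_SIZE, N_LOCAL = 4, 3
GLOBAL_SIZE, LOCAL_SIZE = 32, 16


def multi_crop(_image):
    """Hypothetical stand-in for a multi-crop transform: 2 global crops first, then N_LOCAL local crops."""
    global_crops = [torch.randn(3, GLOBAL_SIZE, GLOBAL_SIZE) for _ in range(2)]
    local_crops = [torch.randn(3, LOCAL_SIZE, LOCAL_SIZE) for _ in range(N_LOCAL)]
    return global_crops + local_crops


# One (list_of_crops, label) pair per image, like an ImageFolder whose transform returns a list.
samples = [(multi_crop(None), 0) for _ in range(BATCH_SIZE)]

# default_collate transposes list-valued samples: crop k of every image is stacked into one tensor.
images, labels = default_collate(samples)
print(len(images))  # 2 + N_LOCAL entries, one batch tensor per crop position
for crop_batch in images[:2]:
    print(crop_batch.shape)  # torch.Size([4, 3, 32, 32]): a global crop of all 4 images

# Contrast: if images were one 5D tensor (B, num_crops, C, H, W) with equal-sized crops,
# images[:2] would slice the batch dimension and keep only the first 2 images.
stacked = torch.randn(BATCH_SIZE, 2 + N_LOCAL, 3, GLOBAL_SIZE, GLOBAL_SIZE)
print(stacked[:2].shape)     # torch.Size([2, 5, 3, 32, 32])
print(stacked[:, :2].shape)  # torch.Size([4, 2, 3, 32, 32]): first 2 crops of every image
```

Under that assumption, images[:2] slices a Python list, so it selects the two global-crop batches covering every image in the batch, not the first two samples. A single 5D tensor also could not hold crops of two different resolutions, since tensors must be rectangular.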
Before train_one_epoch() is called, there is another modification to the models: both the student and teacher models are wrapped with the MultiCropWrapper class. Just take a look at the class docstring:
So this MultiCropWrapper class handles the forward passes, and the docstring also mentions that it does several forward passes, one for each input resolution.
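The behaviour described above (several forward passes, one per resolution) can be sketched roughly as follows. This is a simplified, hypothetical re-sketch, not the repo's exact code: ToyMultiCropWrapper, the toy backbone and head, and the crop sizes are made up, and it assumes crops of the same resolution sit next to each other in the input list (globals first, then locals). The sketch runs the head once on the concatenated features. The same model is reused for both calls only to compare shapes; in DINO the student and teacher are separate networks.

```python
import torch
import torch.nn as nn


class ToyMultiCropWrapper(nn.Module):
    """Hypothetical simplified wrapper: one backbone pass per run of same-resolution inputs."""

    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        if not isinstance(x, list):
            x = [x]
        # Group consecutive inputs by spatial size, e.g. sizes [32, 32, 16, 16, 16] -> counts [2, 3].
        sizes = torch.tensor([inp.shape[-1] for inp in x])
        _, counts = torch.unique_consecutive(sizes, return_counts=True)
        end_indices = torch.cumsum(counts, dim=0).tolist()

        features, start = [], 0
        for end in end_indices:
            # Concatenate along the batch dimension: a single backbone pass per resolution.
            features.append(self.backbone(torch.cat(x[start:end])))
            start = end
        # One head pass over all features; rows stay ordered crop by crop.
        return self.head(torch.cat(features))


# Toy backbone: adaptive pooling lets it accept both crop sizes.
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
model = ToyMultiCropWrapper(backbone, nn.Linear(8, 10))

# Record the batch size seen by every backbone call.
calls = []
backbone.register_forward_hook(lambda module, inputs, output: calls.append(output.shape[0]))

B = 4
crops = [torch.randn(B, 3, 32, 32) for _ in range(2)] + [torch.randn(B, 3, 16, 16) for _ in range(3)]
student_out = model(crops)      # all 5 crop batches
teacher_out = model(crops[:2])  # only the 2 global crop batches
print(student_out.shape, teacher_out.shape)  # torch.Size([20, 10]) torch.Size([8, 10])
print(calls)  # [8, 12, 8]
```

Because a wrapper like this takes a list, passing images[:2] hands it just the first two list entries: the teacher-style call sees a single resolution and makes one backbone pass, while the student-style call makes one pass per resolution.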