What does model.train() do in PyTorch?

model.train() puts your model in training mode. This matters for layers such as Dropout and BatchNorm, which are designed to behave differently during training and evaluation. For instance, in training mode, Dropout randomly zeroes activations and BatchNorm updates its running mean and variance on each new batch; in evaluation mode, Dropout passes its input through unchanged and BatchNorm uses the stored running statistics without updating them.
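To make the difference concrete, here is a minimal sketch that runs a standalone Dropout layer and a BatchNorm1d layer in both modes. The layer sizes, the seed, and the variable names are arbitrary choices for illustration, not anything prescribed by PyTorch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for repeatability

drop = nn.Dropout(p=0.5)
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3) + 5.0  # a batch whose features have mean around 5

# Training mode: Dropout zeroes some entries and rescales the rest by 1/(1-p);
# BatchNorm updates its running statistics from the batch.
drop.train()
bn.train()
dropped_train = drop(torch.ones(1000))
mean_before = bn.running_mean.clone()
bn(x)
mean_after_train = bn.running_mean.clone()

# Evaluation mode: Dropout is the identity; BatchNorm leaves its running
# statistics untouched and normalizes with the stored values.
drop.eval()
bn.eval()
dropped_eval = drop(torch.ones(1000))
bn(x)
mean_after_eval = bn.running_mean.clone()

print("zeros in train mode:", int((dropped_train == 0).sum()))
print("zeros in eval mode: ", int((dropped_eval == 0).sum()))
print("running_mean changed in train mode:", not torch.equal(mean_before, mean_after_train))
print("running_mean changed in eval mode: ", not torch.equal(mean_after_train, mean_after_eval))
```

The same forward call therefore gives different results depending on the mode, which is why forgetting to switch modes before evaluation can quietly change your metrics.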

More details:
model.train() sets the mode to train (see source code). To tell the model that you are testing, call either model.eval() or model.train(mode=False); the two are equivalent.
It is natural to expect a function named train to train the model, but it does not do that. It only sets the mode; the forward pass, loss computation, backward pass, and optimizer step are still yours to write.
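The following sketch checks this on a toy model. The model itself is a made-up example; the only things it relies on are the `training` attribute of `nn.Module` and the `train()` / `eval()` methods.

```python
import torch
import torch.nn as nn

# Hypothetical toy model, used only to inspect the mode flag.
model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))
weights_before = model[0].weight.detach().clone()

returned = model.train()               # sets the flag and returns the module
flag_after_train = model.training      # flag on the container
child_flag_train = model[1].training   # the flag is propagated to submodules

model.eval()
flag_after_eval = model.training

model.train()
model.train(mode=False)                # same effect as model.eval()
flag_after_train_false = model.training

# No parameters were touched by any of the calls above.
weights_unchanged = torch.equal(weights_before, model[0].weight)

print(flag_after_train, child_flag_train, flag_after_eval, flag_after_train_false)
print("weights unchanged:", weights_unchanged)
```

Because train() and eval() only flip a flag, a common pattern is to call model.train() at the start of each training epoch and model.eval() before validation, so the mode always matches what the loop is doing.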
