eval() train() lightning

cdekelon / 2023-05-12 / 原文

self.model.train() training_step lightning

 

https://discuss.pytorch.org/t/why-not-use-model-eval-in-training-step-method-on-lightning/122425/5

Issue: 

I guess PL authors took care of switching between eval/train mode within pre defined callbacks… But problem is that when I try to predict test data in “on_fit_end” callback without using model.eval() it gives me different result than predicting outside training routine (and of course using model.eval in advance), thats why I’m wonder If ‘on_fit_end’ callback provided with model.eval … but I guess I chose inappropriate forum

 

self.model.eval() forward

https://zhuanlan.zhihu.com/p/494060986 (good)

修复很简单,我们将model.train() 向下移动一行,让其在训练循环中。理想的模式设置是尽可能接近推理步骤,以避免忘记设置它。修正后,我们的训练过程看起来更合理,没有中间的峰值出现。

5. 关于两种模式model.train()和model.eval()的汇总:

(1) model.train()

启用 Batch Normalization 和 Dropout。

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。

model.train()作用:对BN层,保证BN层能够用到每一批数据的均值和方差,并进行计算更新;对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

(2) model.eval()

不启用 Batch Normalization 和 Dropout。

如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。

model.eval()是保证BN层直接利用之前训练阶段得到的均值和方差,即测试过程中要保证BN层的均值和方差不变;对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

(3) 何时用model.eval()

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。在eval/test过程中,需要显示地让model调用eval(),此时模型会把BN和Dropout固定住,不会取平均,而是用训练好的值。

(4) with torch.no_grad()

无论是train() 还是eval() 模式,各层的gradient计算和存储都在进行且完全一致,只是在eval模式下不会进行反向传播。而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。若想节约算力,可在test阶段带上torch.no_grad(),示例代码:

def test(model,dataloader):
	model.eval()  # 切换到测试模式
	with torch.no_grad():  #with下内容不进行grad计算
		...

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。

(5) model.eval()和torch.no_grad()的区别

在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

1.主要用于通知dropout层和BN层在training和validation/test模式间切换:

在train模式下,dropout网络层会按照设定的参数p,设置保留激活单元的概率(保留概率=p)。BN层会继续计算数据的mean和var等参数并更新。

在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。

2.eval模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

 

 

validation_step rank_zero_only lightning

http://www.liuxiao.org/2020/07/pytorch-lighting-%E5%B8%B8%E8%A7%81%E9%97%AE%E9%A2%98%E6%95%B4%E7%90%86/   (FAQ)

http://www.liuxiao.org/2020/07/pytorch-lighting-%E5%B8%B8%E8%A7%81%E9%97%AE%E9%A2%98%E6%95%B4%E7%90%86/

 

 

what does validation_step do lightning

https://lightning.ai/forums/t/understanding-logging-and-validation-step-validation-epoch-end/291/2

 

torch.set_grad_enabled(True) lightning

https://github.com/Lightning-AI/lightning/issues/13948

 

https://pytorch-lightning.readthedocs.io/en/1.6.2/starter/core_guide.html  (good)

 

 

with torch.no_grad lightning

https://github.com/Lightning-AI/lightning/issues/2171

 

https://github.com/Lightning-AI/lightning/blob/10c643f162318b7fe2b4a041a1f2975468492a92/pytorch_lightning/trainer/evaluation_loop.py#L246 (code)

 

 

 

how to forward in validation_step lightning

https://lightning.ai/docs/pytorch/stable/common/lightning_module.html

 

# ...
for batch_idx, batch in enumerate(train_dataloader):
    loss = model.training_step(batch, batch_idx)
    loss.backward()
    # ...

    if validate_at_some_point:
        # disable grads + batchnorm + dropout
        torch.set_grad_enabled(False)
        model.eval()

        # ----------------- VAL LOOP ---------------
        for val_batch_idx, val_batch in enumerate(val_dataloader):
            val_out = model.validation_step(val_batch, val_batch_idx)
        # ----------------- VAL LOOP ---------------

        # enable grads + batchnorm + dropout
        torch.set_grad_enabled(True)
        model.train(