深入理解model.eval()与torch.no_grad()

我们用pytorch搭建神经网络经常见到model.eval()与torch.no_grad(),它们有什么区别?是怎么工作的呢?现在就让我们来探究其中的奥秘

model.eval()

  • 使用model.eval()切换到测试模式,不会更新模型的k,b参数
  • 通知dropout层和batchnorm层在train和val中间进行切换
    在train模式,dropout层会按照设定的参数p设置保留激活单元的概率(保留概率=p,比如keep_prob=0.8),batchnorm层会继续计算数据的mean和var并进行更新
    在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值
  • model.eval()不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(backprobagation)

torch.no_grad()

使用方法:

1
2
with torch.no_grad():
# 代码块
  • 用于停止autograd模块的工作,起到加速和节省显存的作用(具体行为就是停止gradient计算,从而节省了GPU算力和显存)
  • 不会影响dropout和batchnorm层的行为

model.eval()torch.no_grad()可以同时用,更加节省cpu的算力

思考

在val模式下,为什么让dropout层所有的激活单元都通过,因为train阶段的dropout层已经屏蔽掉了一些激活单元,在val模式下,让所有的激活单元都通过还能预测数据吗?
在val模式下,让所有的激活单元都通过当然能预测数据了,相当于学习时限定你每次只能选择一份资料学,考试时开卷所有资料你都带着。val模式下,虽然让所有的激活单元都通过,但是对于各个神经元的输出, 要乘上训练时的删除比例后再输出。