Pytorch训练时报nan-2. 规避

时间:2024-11-08 06:58:05

在这里介绍的方法是基于Pytorch 1.13的,Pytorch 2.x的用户也不想要担心,因为本教程中设置的参数在Pytorch 2.x里面已经设为默认参数,完全兼容。

# compute loss
# optimizer, model
clip_grad = 1.0 # maximum value to clip grad_norm
try:
	nn.utils.clip_grad_norm_(model.parameters(), clip_grad, norm_type=2, error_if_nonfinite=True) # 遇到nonfinite的梯度报错
	optimizer.step()
except:
	print("nan detected in grad, skip batch")
	optimizer.zero_grad()  # 所有梯度置0,保证下一个batch的正常训练
	continue  # 跳过这个batch的训练

这个代码的思想就是利用clip_grad_norm_自带的梯度检查功能在反向传播前对model的每个参数梯度进行检查,如若出现梯度异常值,则跳过batch(且不会对网络进行梯度更新)。需要的注意的是,optimizer.zero_grad()除了在本代码中出现,应该在主循环里面也另外有一个,但是此处省略了。