본문 바로가기

ML Framework

[PyTorch] torch.no_grad()와 model.eval()의 차이

👋  Intro

안녕하세요, 유블린입니다.

이 글에서는 torch.no_grad()와 model.eval()의 차이점을 다뤄보려고 합니다.

머신러닝 프레임 워크인 Pytorch로 모델 학습을 진행한 뒤 모델의 Evaluation을 진행 할 때 아래 처럼 model.eval()을 했는데  torch.no_grad()을 또 사용 하는 경우를 보셨나요? 혹은, 모델 학습은 잘 진행 했는데 Inference를 하는 중에 Memory leak이 발생한 적이 있나요? 후자의 경우에도 해결책 중 하나로 아래 코드를 제시 합니다. 왜 그럴까요? 

 

model.eval()
with torch.no_grad():
    for batch in data_loader:

 

여기서 저와 같은 의문을 가지셨다면, 잘 찾아오셨습니다. 저 또한 처음에는 별생각 없이 사용하다가, 두 함수의 차이가 문득 궁금해져서 공부해 본 내용을 공유합니다.😊


⭐️  model.eval()

그림1. model.eval() PyTorch 공식 document

그림1은 PyTorch 1.10.0 공식 문서에서 model.eval()함수의 설명을 발췌한 것입니다.

 

위 그림에서 볼 수 있듯이 eval()함수는 해당 모델의 모든 레이어가 evaluation mode에 들어가게 해줍니다. 이말은 즉, 학습할 때만 필요한 Dropout, Batchnorm등의 기능을 비활성화 시킨다는 것입니다.

 

결론적으로, model.eval() 함수는 학습할 때만 필요했던 Dropout, Batchnorm등의 기능을 비활성화 해줘서 추론할 때의 모드로 작동하도록 조정해 주는 역할을 합니다. (메모리와는 관련이 없습니다.)

 

 

⭐️ torch.no_grad()

그림2. torch.no_grad() Pytroch 공식 document

그림2는 마찬가지로 Pytorch 1.10.0 공식 문서에서 torch.no_grad() 함수의 설명을 발췌한 것입니다.

 

설명을 보면 torch.no_grad()함수는 gradient계산 context를 비활성화 해주는 역할을 한다고 합니다. Inference나 validation을 할 때는 gradient 계산을 하지 않죠. 그래서 이함수를 사용해 줌으로써 PyTorch의 autograd engine(gradient를 계산해주는 context)를 비활성화 시켜서 더이상 gradient를 트래킹하지 않게 됩니다. 따라서 필요한 메모리가 줄어들고 연산속도가 증가하게 됩니다. 

 

즉, torch.no_grad()함수는 autograd engine(gradient를 계산해주는 context)을 비활성화 시켜 필요한 메모리를 줄어주고 연산속도를 증가시키는 역할을 합니다. 

 

 

😎 차이점

위에서 각 함수에 대해 알아봤는데요, 차이점을 요약해보면 아래와 같습니다.

 

1. model.eval()는 Dropout, Batchnorm등의 기능을 비활성화 시켜 추론 모드로 조정해 주는 역할을 합니다. 메모리와는 관련이 없습니다.

2. torch.no_grad()는 autograd engine을 비활성화 시켜 필요한 메모리를 줄어주고 연산속도를 증가시키는 역할을 합니다.

3. 하지만 torch.no_grad()함수가 model.eval() 함수처럼 dropout을 비활성화 시키진 않습니다. 

4. 인트로에서 말했듯 Inference를 하던중에 memory leak이 발생하면 해결책으로 torch.no_grad()함수를 추가하는것을 제시할 수 있습니다.

 


👍 Outro

지금까지 torch.no_grad()함수와 model.eval()함수의 차이점에 대해 정리해보았습니다. 제가 공부하며 정리한 글이기 때문에 틀린 내용이 존재할 수 있습니다. 발견하신다면 피드백은 언제든지 환영입니다. 😊 많은 분들께 도움이 되길 바라며 이만 글 줄이겠습니다. 

 


References

https://pytorch.org/docs/stable/generated/torch.no_grad.html?highlight=torch%20no_grad#torch.no_grad

https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=model%20eval

https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615/2