What is hook?
다음 링크에 들어가면 hook은 패키지화된 코드에서 프로그래머가 customized code를 삽입할 수 있도록 해주는 하나의 인터페이스 또는 하나의 장소라고 한다. 예를 들어 프로그램의 실행 로직을 분석하거나 추가적인 기능을 제공하고 싶을 때 사용한다고 한다.
다음 코드에서 Package class 내에 있는 self.hooks가 바로 위에서 설명한 hook이다.
def program(x):
print('program processing!')
return x + 3
class Package(object):
def __init__(self):
self.program = program
self.hooks = []
def __call__(self, x):
x = program(x)
if self.hooks:
for hook in self.hooks:
output = hook(x)
if output:
x = output
return x
# Create Package
package = Package()
# Run
output = package(3)
현재로서는 self.hooks는 비어있기때문에 기존에 정의한 program 함수만 실행이 되지만 자신의 custom 코드를 다음과 같이 self.hooks에 추가하게되면 추가된 custom코드로 Package 내부에서 실행할 수 있게된다.
def hook_analysis(x):
print(f'hook for analysis, current value is {x}')
# hook 추가
package.hooks = []
package.hooks.append(hook_analysis)
output = package(3)
다른 예시로 다음과 같이 기능을 추가할 수도 있다.
def hook_multiply(x):
print('hook for multiply')
return x * 3
package.hooks = []
package.hooks.append(hook_multiply)
output = package(3)
실행되는 순서는 "기존의 program -> 추가한 custom 코드"이다. 하지만 기존 program 이전과 이후에 custom 코드를 추가하고 싶다면 다음과 같이 작성할 수 있다.
def program(x):
print('program processing!')
return x + 3
class Package(object):
def __init__(self):
self.program = program
# hooks
self.pre_hooks = []
self.hooks = []
def __call__(self, x):
# pre_hook
if self.pre_hooks:
for hook in self.pre_hooks:
output = hook(x)
if output:
x = output
x = program(x)
# hook
if self.hooks:
for hook in self.hooks:
output = hook(x)
if output:
x = output
return x
Pytorch Hook
register_hook은 Tensor에 등록할 때 사용하는 함수이며 다음과 같이 backward hook만 존재한다.
import torch
tensor = torch.rand(1, requires_grad=True)
def tensor_hook(grad):
pass
tensor.register_hook(tensor_hook)
tensor._backward_hooks
위 4가지 hook은 다음과 같이 nn.Module에 등록할 수 있는 hook이며 __dict__ 메소드를 통해 확인할 수 있다.
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
def module_hook(grad):
pass
model = Model()
model.register_forward_pre_hook(module_hook)
model.register_forward_hook(module_hook)
model.register_full_backward_hook(module_hook)
model.__dict__
이외에 state_dict_hooks라는 hook도 있는데 이것은 load_state_dict 함수가 실행될 때 내부적으로 사용한다고 한다.
우선 register_forward_pre_hook, register_forward_hook을 사용해보자.
import torch
from torch import nn
class Add(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2):
output = torch.add(x1, x2)
return output
add = Add()
answer = []
# -- register_forward_pre_hook
def pre_hook(module, input):
answer.extend(input)
add.register_forward_pre_hook(pre_hook)
# -- register_forward_hook
def hook(module, input, output):
answer.extend(output)
add.register_forward_hook(hook)
x1 = torch.rand(1)
x2 = torch.rand(1)
output = add(x1, x2)
register_forward_pre_hook과 register_forward_hook에 등록되는 hook의 parameter들을 보면 "pre_hook"의 경우 프로그램이 실행되기 전 입력되는 값만을 받고 "hook"은 프로그램 실행 후 입력되는 값과 출력값 두 개를 받는다.
위 코드처럼 입력값과 출력값을 받을 수 있지만 다음처럼 forward되는 값을 수정할 수도 있다.
import torch
from torch import nn
class Add(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2):
output = torch.add(x1, x2)
return output
add = Add()
def hook(module, input, output):
return output + 5
add.register_forward_hook(hook)
x1 = torch.rand(1)
x2 = torch.rand(1)
output = add(x1, x2)
이제 register_full_backward_hook을 사용해보면 다음 코드에서 확인 할 수 있듯이 backpropagation에서의 gradient 값들을 얻을 수 있다.
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Model(nn.Module):
def __init__(self):
super().__init__()
self.W = Parameter(torch.Tensor([5]))
def forward(self, x1, x2):
output = x1 * x2
output = output * self.W
return output
model = Model()
answer = []
def module_hook(module, grad_input, grad_output):
answer.extend(grad_input)
answer.extend(grad_output)
model.register_full_backward_hook(module_hook)
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)
output = model(x1, x2)
output.retain_grad() # gradient 연산을 기억
output.backward()
retain_grad()는 계산된 gradient 값을 기억할 수 있게하는 함수이다. 실행하지않는다면 저장되지않아 grad_output 값이None으로 나온다.
위에서처럼 module을 기준으로 gradient 값을 알아낼 수 있지만 tensor를 기준으로 알아내려면 register_hook을 사용하면된다.
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Model(nn.Module):
def __init__(self):
super().__init__()
self.W = Parameter(torch.Tensor([5]))
def forward(self, x1, x2):
output = x1 * x2
output = output * self.W
return output
model = Model()
answer = []
def tensor_hook(grad):
answer.extend(grad)
model.W.register_hook(tensor_hook)
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)
output = model(x1, x2)
output.backward()
또한 backward hook도 gradient 값을 다음과 같이 수정할 수 있다.
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Model(nn.Module):
def __init__(self):
super().__init__()
self.W = Parameter(torch.Tensor([5]))
def forward(self, x1, x2):
output = x1 * x2
output = output * self.W
return output
model = Model()
def module_hook(module, grad_input, grad_output):
x1_grad, x2_grad = grad_input
total_grad = x1_grad + x2_grad
return x1_grad/total_grad, x2_grad/total_grad
model.register_full_backward_hook(module_hook)
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)
output = model(x1, x2)
output.backward()
이외에도 gradient clipping을 간단하게 수행할 수 있다. gradient clipping은 exploding gradient를 해결하는데 잘 알려진 방법이고 pytorch에서도 자체적으로 이에 대한 메소드가 존재한다고 한다. 하지만 hook을 사용하여 다음과 같이 clipping 할 수 있다.
import torch
from torchvision.models import resnet50
def gradient_clipper(model, val):
for parameter in model.parameters():
parameter.register_hook(lambda grad: grad.clamp_(-val, val))
return model
clipped_resnet = gradient_clipper(resnet50(), 0.01)
pred = clipped_resnet(dummy_input)
loss = pred.log().mean()
loss.backward()
print(clipped_resnet.fc.bias.grad[:25])
[reference]
https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904
https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/
https://discuss.pytorch.org/t/invoking-time-of-nn-module-register-state-dict-hook/108163
'Temp' 카테고리의 다른 글
[JSON] json.loads() -> Expecting value: line 1 column 2 (char 1) (0) | 2022.03.28 |
---|---|
[Git] Pull Request 수정 (0) | 2022.03.25 |
[Ubuntu/Linux] Change timezone (0) | 2022.02.08 |
[Git] git clone with ssh key (0) | 2022.02.07 |
[Pillow] image file is truncated (0) | 2022.01.13 |