Temp

[Pytorch] How to use pytorch hooks?

ju_young 2022. 2. 20. 01:29
728x90

What is hook?

다음 링크에 들어가면 hook은 패키지화된 코드에서 프로그래머가 customized code를 삽입할 수 있도록 해주는 하나의 인터페이스 또는 하나의 장소라고 한다. 예를 들어 프로그램의 실행 로직을 분석하거나 추가적인 기능을 제공하고 싶을 때 사용한다고 한다.

 

What is hook? - Definition from WhatIs.com

 

whatis.techtarget.com

 

다음 코드에서 Package class 내에 있는 self.hooks가 바로 위에서 설명한 hook이다.

 

def program(x):
    print('program processing!')
    return x + 3

class Package(object):
    def __init__(self):
        self.program = program
        self.hooks = []

    def __call__(self, x):
        x = program(x)
        if self.hooks:
            for hook in self.hooks:
                output = hook(x)
                if output:
                    x = output
        return x

# Create Package
package = Package()

# Run
output = package(3)

 

현재로서는 self.hooks는 비어있기때문에 기존에 정의한 program 함수만 실행이 되지만 자신의 custom 코드를 다음과 같이 self.hooks에 추가하게되면 추가된 custom코드로 Package 내부에서 실행할 수 있게된다.

 

def hook_analysis(x):
    print(f'hook for analysis, current value is {x}')

# hook 추가
package.hooks = []
package.hooks.append(hook_analysis)

output = package(3)

 

다른 예시로 다음과 같이 기능을 추가할 수도 있다.

 

def hook_multiply(x):
    print('hook for multiply')
    return x * 3

package.hooks = []
package.hooks.append(hook_multiply)

output = package(3)

 

실행되는 순서는 "기존의 program -> 추가한 custom 코드"이다. 하지만 기존 program 이전과 이후에 custom 코드를 추가하고 싶다면 다음과 같이 작성할 수 있다.

 

def program(x):
    print('program processing!')
    return x + 3

class Package(object):
    def __init__(self):
        self.program = program

        # hooks
        self.pre_hooks = []
        self.hooks = []

    def __call__(self, x):
        # pre_hook
        if self.pre_hooks:
            for hook in self.pre_hooks:
                output = hook(x)
                if output:
                    x = output

        x = program(x)

        # hook
        if self.hooks:
            for hook in self.hooks:
                output = hook(x)
                if output:
                    x = output

        return x

 

Pytorch Hook

 

register_hook은 Tensor에 등록할 때 사용하는 함수이며 다음과 같이 backward hook만 존재한다.

 

import torch
tensor = torch.rand(1, requires_grad=True)

def tensor_hook(grad):
    pass

tensor.register_hook(tensor_hook)
tensor._backward_hooks

 

위 4가지 hook은 다음과 같이 nn.Module에 등록할 수 있는 hook이며 __dict__ 메소드를 통해 확인할 수 있다.

 

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()

def module_hook(grad):
    pass

model = Model()
model.register_forward_pre_hook(module_hook)
model.register_forward_hook(module_hook)
model.register_full_backward_hook(module_hook)

model.__dict__

 

이외에 state_dict_hooks라는 hook도 있는데 이것은 load_state_dict 함수가 실행될 때 내부적으로 사용한다고 한다.

우선 register_forward_pre_hook, register_forward_hook을 사용해보자.

 

import torch
from torch import nn
 
class Add(nn.Module):
    def __init__(self):
        super().__init__() 

    def forward(self, x1, x2):
        output = torch.add(x1, x2)
        return output

add = Add()

answer = []

# -- register_forward_pre_hook
def pre_hook(module, input):
    answer.extend(input)

add.register_forward_pre_hook(pre_hook)

# -- register_forward_hook
def hook(module, input, output):
    answer.extend(output)

add.register_forward_hook(hook)

x1 = torch.rand(1)
x2 = torch.rand(1)

output = add(x1, x2)

 

register_forward_pre_hook과 register_forward_hook에 등록되는 hook의 parameter들을 보면 "pre_hook"의 경우 프로그램이 실행되기 전 입력되는 값만을 받고 "hook"은 프로그램 실행 후 입력되는 값과 출력값 두 개를 받는다.

 

위 코드처럼 입력값과 출력값을 받을 수 있지만 다음처럼 forward되는 값을 수정할 수도 있다.

 

import torch
from torch import nn

class Add(nn.Module):
    def __init__(self):
        super().__init__() 

    def forward(self, x1, x2):
        output = torch.add(x1, x2)
        return output

add = Add()

def hook(module, input, output):
    return output + 5

add.register_forward_hook(hook)

x1 = torch.rand(1)
x2 = torch.rand(1)

output = add(x1, x2)

 

이제 register_full_backward_hook을 사용해보면 다음 코드에서 확인 할 수 있듯이 backpropagation에서의 gradient 값들을 얻을 수 있다.

 

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W
        return output

model = Model()

answer = []

def module_hook(module, grad_input, grad_output):
    answer.extend(grad_input)
    answer.extend(grad_output)

model.register_full_backward_hook(module_hook)

x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.retain_grad() # gradient 연산을 기억
output.backward()

 

retain_grad()는 계산된 gradient 값을 기억할 수 있게하는 함수이다. 실행하지않는다면 저장되지않아 grad_output 값이None으로 나온다.

 

위에서처럼 module을 기준으로 gradient 값을 알아낼 수 있지만 tensor를 기준으로 알아내려면 register_hook을 사용하면된다.

 

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W
        return output

model = Model()
answer = []

def tensor_hook(grad):
    answer.extend(grad)

model.W.register_hook(tensor_hook)

x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.backward()

 

또한 backward hook도 gradient 값을 다음과 같이 수정할 수 있다.

 

import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W
        return output

model = Model()

def module_hook(module, grad_input, grad_output):
    x1_grad, x2_grad = grad_input
    total_grad = x1_grad + x2_grad
    return x1_grad/total_grad, x2_grad/total_grad

model.register_full_backward_hook(module_hook)

x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)

output = model(x1, x2)
output.backward()

 

이외에도 gradient clipping을 간단하게 수행할 수 있다. gradient clipping은 exploding gradient를 해결하는데 잘 알려진 방법이고 pytorch에서도 자체적으로 이에 대한 메소드가 존재한다고 한다. 하지만 hook을 사용하여 다음과 같이 clipping 할 수 있다.

 

import torch
from torchvision.models import resnet50

def gradient_clipper(model, val):
    for parameter in model.parameters():
        parameter.register_hook(lambda grad: grad.clamp_(-val, val))
    return model

clipped_resnet = gradient_clipper(resnet50(), 0.01)
pred = clipped_resnet(dummy_input)
loss = pred.log().mean()
loss.backward()

print(clipped_resnet.fc.bias.grad[:25])

 

[reference]

https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904

https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/

https://discuss.pytorch.org/t/invoking-time-of-nn-module-register-state-dict-hook/108163

728x90

'Temp' 카테고리의 다른 글

[JSON] json.loads() -> Expecting value: line 1 column 2 (char 1)  (0) 2022.03.28
[Git] Pull Request 수정  (0) 2022.03.25
[Ubuntu/Linux] Change timezone  (0) 2022.02.08
[Git] git clone with ssh key  (0) 2022.02.07
[Pillow] image file is truncated  (0) 2022.01.13