AI/PyTorch

[PyTorch] hook

sangwonYoon 2023. 3. 17. 00:33

Hook

hook이란, 프로그래머가 본인의 패키지에 사용자가 custom 코드를 실행시킬 수 있도록 만든 인터페이스이다.

def program_A(x):
    print('program A processing!')
    return x + 3

def program_B(x):
    print('program B processing!')
    return x - 3

class Package(object):
    """프로그램 A와 B를 묶어놓은 패키지 코드"""
    def __init__(self):
        self.programs = [program_A, program_B]
        self.hooks = []

    def __call__(self, x):
        for program in self.programs:
            x = program(x)

            # Package를 사용하는 사람이 자신만의 custom program을
            # 등록할 수 있도록 미리 만들어놓은 인터페이스 hook
            if self.hooks:
                for hook in self.hooks:
                    output = hook(x)

                    # return 값이 있는 hook의 경우에만 x를 업데이트 한다
                    if output:
                        x = output

        return x
# Hook - 프로그램의 실행 로직 분석 사용 예시
def hook_analysis(x):
    print(f'hook for analysis, current value is {x}')

# 패키지 생성
package = Package()

# 생성된 패키지에 hook 추가
package.hooks = []
package.hooks.append(hook_analysis)

# 패키지 실행
input = 3
output = package(input)

# 패키지 결과
print(f"Package Process Result! [ input {input} ] [ output {output} ]")
# 출력:
# program A processing!
# hook for analysis, current value is 6
# program B processing!
# hook for analysis, current value is 3
# Package Process Result! [ input 3 ] [ output 3 ]

위와 같이 사용자가 Package 클래스를 가져다가 사용할 때, hook을 사용해 사용자가 원하는 기능을 추가할 수 있다.


Module hook

Module에 hook을 적용할 수 있는 위치는 크게 3곳이 있다.

  1. forward 실행 직전
  2. forward 실행 직후
  3. backward 실행 직후

 

register_forward_pre_hook

  • foward 실행 직전에 실행될 hook을 등록하는 함수이다.
  • hook은 아래와 같은 형태를 가져야 한다.
hook(module, input) → None or modified input
  • hook이 modified input을 반환하면 input이 수정된다.
import torch
from torch import nn


class Add(nn.Module):
    def __init__(self):
        super().__init__() 

    def forward(self, x1, x2):
        output = torch.add(x1, x2)

        return output

# 모델 생성
add = Add()

answer = []


# pre_hook를 이용해서 x1, x2 값을 알아내 answer에 저장한다.
def pre_hook(module, input):
    answer.extend(input)

add.register_forward_pre_hook(pre_hook)

 

register_forward_hook

  • forward 실행 직후 실행할 hook을 등록하는 함수이다.
  • hook은 아래와 같은 형태를 가져야 한다.
hook(module, input, output) -> None or modified output
  • hook이 modified output을 반환하면 output이 수정된다.
import torch
from torch import nn


class Add(nn.Module):
    def __init__(self):
        super().__init__() 

    def forward(self, x1, x2):
        output = torch.add(x1, x2)

        return output

# 모델 생성
add = Add()

# hook를 이용해서 전파되는 output 값에 5를 더한다.
def hook(module, input, output):
    return output+5

add.register_forward_hook(hook)

 

register_full_backward_hook

  • backward 실행 직후 실행할 hook을 등록하는 함수이다.
  • hook은 아래와 같은 형태를 가져야 한다.
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
  • grad_input과 grad_output은 각각 input과 output의 gradient들을 포함한 튜플이다.
  • Tensor를 원소로 갖는 튜플을 반환하여 input의 gradient 값을 수정할 수 있다.
import torch
from torch import nn
from torch.nn.parameter import Parameter

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = Parameter(torch.Tensor([5]))

    def forward(self, x1, x2):
        output = x1 * x2
        output = output * self.W

        return output

# 모델 생성
model = Model()

# hook를 이용해서 module의 gradient 출력의 합이 1이 되도록 수정한다.
#        ex) (1.5, 0.5) -> (0.75, 0.25)
def module_hook(module, grad_input, grad_output):
    total = grad_input[0] + grad_input[1]
    return (torch.tensor(grad_input[0] / total), torch.tensor(grad_input[1] / total))

model.register_full_backward_hook(module_hook)

 

Tensor hook

Tensor에는 Module과는 다르게, backward 직후에만 hook을 실행할 수 있다.

import torch

tensor = torch.rand(1, requires_grad=True)

def tensor_hook(grad):
    pass

tensor.register_hook(tensor_hook)

print(tensor._backward_hooks)
# 출력: OrderedDict([(0, <function __main__.tensor_hook(grad)>)])

'AI > PyTorch' 카테고리의 다른 글

[PyTorch] Autograd  (0) 2023.03.20
[PyTorch] 모듈들에 custom 함수 적용시키기  (0) 2023.03.17
[Pytorch] Dataset과 DataLoader  (0) 2023.03.16
[PyTorch] 파라미터 구현하기  (0) 2023.03.16
[PyTorch] 신경망 모델과 torch.nn.module  (0) 2023.03.16