Hook
hook이란, 프로그래머가 본인의 패키지에 사용자가 custom 코드를 실행시킬 수 있도록 만든 인터페이스이다.
def program_A(x):
print('program A processing!')
return x + 3
def program_B(x):
print('program B processing!')
return x - 3
class Package(object):
"""프로그램 A와 B를 묶어놓은 패키지 코드"""
def __init__(self):
self.programs = [program_A, program_B]
self.hooks = []
def __call__(self, x):
for program in self.programs:
x = program(x)
# Package를 사용하는 사람이 자신만의 custom program을
# 등록할 수 있도록 미리 만들어놓은 인터페이스 hook
if self.hooks:
for hook in self.hooks:
output = hook(x)
# return 값이 있는 hook의 경우에만 x를 업데이트 한다
if output:
x = output
return x
# Hook - 프로그램의 실행 로직 분석 사용 예시
def hook_analysis(x):
print(f'hook for analysis, current value is {x}')
# 패키지 생성
package = Package()
# 생성된 패키지에 hook 추가
package.hooks = []
package.hooks.append(hook_analysis)
# 패키지 실행
input = 3
output = package(input)
# 패키지 결과
print(f"Package Process Result! [ input {input} ] [ output {output} ]")
# 출력:
# program A processing!
# hook for analysis, current value is 6
# program B processing!
# hook for analysis, current value is 3
# Package Process Result! [ input 3 ] [ output 3 ]
위와 같이 사용자가 Package 클래스를 가져다가 사용할 때, hook을 사용해 사용자가 원하는 기능을 추가할 수 있다.
Module hook
Module에 hook을 적용할 수 있는 위치는 크게 3곳이 있다.
- forward 실행 직전
- forward 실행 직후
- backward 실행 직후
register_forward_pre_hook
- foward 실행 직전에 실행될 hook을 등록하는 함수이다.
- hook은 아래와 같은 형태를 가져야 한다.
hook(module, input) → None or modified input
- hook이 modified input을 반환하면 input이 수정된다.
import torch
from torch import nn
class Add(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2):
output = torch.add(x1, x2)
return output
# 모델 생성
add = Add()
answer = []
# pre_hook를 이용해서 x1, x2 값을 알아내 answer에 저장한다.
def pre_hook(module, input):
answer.extend(input)
add.register_forward_pre_hook(pre_hook)
register_forward_hook
- forward 실행 직후 실행할 hook을 등록하는 함수이다.
- hook은 아래와 같은 형태를 가져야 한다.
hook(module, input, output) -> None or modified output
- hook이 modified output을 반환하면 output이 수정된다.
import torch
from torch import nn
class Add(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2):
output = torch.add(x1, x2)
return output
# 모델 생성
add = Add()
# hook를 이용해서 전파되는 output 값에 5를 더한다.
def hook(module, input, output):
return output+5
add.register_forward_hook(hook)
register_full_backward_hook
- backward 실행 직후 실행할 hook을 등록하는 함수이다.
- hook은 아래와 같은 형태를 가져야 한다.
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
- grad_input과 grad_output은 각각 input과 output의 gradient들을 포함한 튜플이다.
- Tensor를 원소로 갖는 튜플을 반환하여 input의 gradient 값을 수정할 수 있다.
import torch
from torch import nn
from torch.nn.parameter import Parameter
class Model(nn.Module):
def __init__(self):
super().__init__()
self.W = Parameter(torch.Tensor([5]))
def forward(self, x1, x2):
output = x1 * x2
output = output * self.W
return output
# 모델 생성
model = Model()
# hook를 이용해서 module의 gradient 출력의 합이 1이 되도록 수정한다.
# ex) (1.5, 0.5) -> (0.75, 0.25)
def module_hook(module, grad_input, grad_output):
total = grad_input[0] + grad_input[1]
return (torch.tensor(grad_input[0] / total), torch.tensor(grad_input[1] / total))
model.register_full_backward_hook(module_hook)
Tensor hook
Tensor에는 Module과는 다르게, backward 직후에만 hook을 실행할 수 있다.
import torch
tensor = torch.rand(1, requires_grad=True)
def tensor_hook(grad):
pass
tensor.register_hook(tensor_hook)
print(tensor._backward_hooks)
# 출력: OrderedDict([(0, <function __main__.tensor_hook(grad)>)])
'AI > PyTorch' 카테고리의 다른 글
[PyTorch] Autograd (0) | 2023.03.20 |
---|---|
[PyTorch] 모듈들에 custom 함수 적용시키기 (0) | 2023.03.17 |
[Pytorch] Dataset과 DataLoader (0) | 2023.03.16 |
[PyTorch] 파라미터 구현하기 (0) | 2023.03.16 |
[PyTorch] 신경망 모델과 torch.nn.module (0) | 2023.03.16 |