- 모듈 안에는 하나 이상의 모듈들이 포함될 수 있고, 여러 모듈을 포함한 하나의 거대한 모듈을 모델이라고 부른다.
- 대부분의 torch.nn.Module의 method들은 모델에 method를 적용하면, 모델 내부의 모든 모듈들에도 적용시키는 기능을 지원한다.
- 그러나 우리가 직접 만든 custom 함수는 이 기능을 지원하지 않기 때문에, apply를 사용해 모델의 하위 모듈들에 함수를 재귀적으로 적용시킬 수 있다.
Apply
import torch
from torch import nn
from torch.nn.parameter import Parameter
# Function
class Function_A(nn.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.W = Parameter(torch.rand(1))
def forward(self, x):
return x + self.W
class Function_B(nn.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.W = Parameter(torch.rand(1))
def forward(self, x):
return x - self.W
class Function_C(nn.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.W = Parameter(torch.rand(1))
def forward(self, x):
return x * self.W
class Function_D(nn.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.W = Parameter(torch.rand(1))
def forward(self, x):
return x / self.W
# Layer
class Layer_AB(nn.Module):
def __init__(self):
super().__init__()
self.a = Function_A('plus')
self.b = Function_B('substract')
def forward(self, x):
x = self.a(x)
x = self.b(x)
return x
class Layer_CD(nn.Module):
def __init__(self):
super().__init__()
self.c = Function_C('multiply')
self.d = Function_D('divide')
def forward(self, x):
x = self.c(x)
x = self.d(x)
return x
# Model
class Model(nn.Module):
def __init__(self):
super().__init__()
self.ab = Layer_AB()
self.cd = Layer_CD()
def forward(self, x):
x = self.ab(x)
x = self.cd(x)
return x
model = Model()
- Model
- Layer_AB
- Function_A
- Function_B
- Layer_CD
- Function_C
- Function_D
- Layer_AB
의 구조를 갖고 있는 모델이다.
def print_module(module):
print(module)
print("-" * 30)
model.apply(print_module)
# 출력:
# Function_A()
# ------------------------------
# Function_B()
# ------------------------------
# Layer_AB(
# (a): Function_A()
# (b): Function_B()
# )
# ------------------------------
# Function_C()
# ------------------------------
# Function_D()
# ------------------------------
# Layer_CD(
# (c): Function_C()
# (d): Function_D()
# )
# ------------------------------
# Model(
# (ab): Layer_AB(
# (a): Function_A()
# (b): Function_B()
# )
# (cd): Layer_CD(
# (c): Function_C()
# ...
# (d): Function_D()
# )
# )
# ------------------------------
# Model(
# (ab): Layer_AB(
# (a): Function_A()
# (b): Function_B()
# )
# (cd): Layer_CD(
# (c): Function_C()
# (d): Function_D()
# )
# )
apply는 출력 결과에서 알 수 있듯이, 후위 순회 방식으로 모듈들에 함수를 적용한다.
'AI > PyTorch' 카테고리의 다른 글
[PyTorch] 모델 저장하기 및 불러오기 (0) | 2023.03.20 |
---|---|
[PyTorch] Autograd (0) | 2023.03.20 |
[PyTorch] hook (0) | 2023.03.17 |
[Pytorch] Dataset과 DataLoader (0) | 2023.03.16 |
[PyTorch] 파라미터 구현하기 (0) | 2023.03.16 |