AI/PyTorch

[PyTorch] 모듈들에 custom 함수 적용시키기

sangwonYoon 2023. 3. 17. 01:31
  • 모듈 안에는 하나 이상의 모듈들이 포함될 수 있고, 여러 모듈을 포함한 하나의 거대한 모듈을 모델이라고 부른다.
  • 대부분의 torch.nn.Module의 method들은 모델에 method를 적용하면, 모델 내부의 모든 모듈들에도 적용시키는 기능을 지원한다.
  • 그러나 우리가 직접 만든 custom 함수는 이 기능을 지원하지 않기 때문에, apply를 사용해 모델의 하위 모듈들에 함수를 재귀적으로 적용시킬 수 있다.

Apply

import torch
from torch import nn
from torch.nn.parameter import Parameter

# Function
class Function_A(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x + self.W

class Function_B(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x - self.W

class Function_C(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x * self.W

class Function_D(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.W = Parameter(torch.rand(1))

    def forward(self, x):
        return x / self.W


# Layer
class Layer_AB(nn.Module):
    def __init__(self):
        super().__init__()

        self.a = Function_A('plus')
        self.b = Function_B('substract')

    def forward(self, x):
        x = self.a(x)
        x = self.b(x)

        return x

class Layer_CD(nn.Module):
    def __init__(self):
        super().__init__()

        self.c = Function_C('multiply')
        self.d = Function_D('divide')

    def forward(self, x):
        x = self.c(x)
        x = self.d(x)

        return x


# Model
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.ab = Layer_AB()
        self.cd = Layer_CD()

    def forward(self, x):
        x = self.ab(x)
        x = self.cd(x)

        return x


model = Model()
  • Model
    • Layer_AB
      • Function_A
      • Function_B
    • Layer_CD
      • Function_C
      • Function_D

의 구조를 갖고 있는 모델이다.

 

def print_module(module):
    print(module)
    print("-" * 30)

model.apply(print_module)
# 출력:
# Function_A()
# ------------------------------
# Function_B()
# ------------------------------
# Layer_AB(
#   (a): Function_A()
#   (b): Function_B()
# )
# ------------------------------
# Function_C()
# ------------------------------
# Function_D()
# ------------------------------
# Layer_CD(
#   (c): Function_C()
#   (d): Function_D()
# )
# ------------------------------
# Model(
#   (ab): Layer_AB(
#     (a): Function_A()
#     (b): Function_B()
#   )
#   (cd): Layer_CD(
#     (c): Function_C()
# ...
#     (d): Function_D()
#   )
# )
# ------------------------------
# Model(
#   (ab): Layer_AB(
#     (a): Function_A()
#     (b): Function_B()
#   )
#   (cd): Layer_CD(
#     (c): Function_C()
#     (d): Function_D()
#   )
# )

apply는 출력 결과에서 알 수 있듯이, 후위 순회 방식으로 모듈들에 함수를 적용한다.

'AI > PyTorch' 카테고리의 다른 글

[PyTorch] 모델 저장하기 및 불러오기  (0) 2023.03.20
[PyTorch] Autograd  (0) 2023.03.20
[PyTorch] hook  (0) 2023.03.17
[Pytorch] Dataset과 DataLoader  (0) 2023.03.16
[PyTorch] 파라미터 구현하기  (0) 2023.03.16