AI/PyTorch

[PyTorch] 신경망 모델과 torch.nn.module

sangwonYoon 2023. 3. 16. 01:31

PyTorch에서 제공하는 신경망 모델의 기본 클래스인 torch.nn.module 클래스에 대해 알아보자.


torch.nn.module

  • 모델의 구성 요소인 layer와 parameter를 담는 컨테이너이다.
  • function이 모여 layer를 구성하고, layer가 모여 model이 만들어진다.
  • torch.nn.module을 상속받은 클래스는 function이 될 수도 있고, layer, model이 될 수도 있다.
import torch.nn as nn

# 모델 클래스 예시
class MyModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        # 부모 클래스인 nn.Module 클래스의 생성자를 호출한다.
        super(MyModel, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = nn.ReLU()(x)
        x = self.layer2(x)
        return x

 

torch.nn.module을 상속받은 클래스는 __init__()과 forward 함수를 override 해야한다.

 

모듈끼리 합치기

  • ModuleList 또는 ModuleList로 여러 모듈을 하나의 객체에서 관리하기
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer_list = nn.ModuleList([nn.Linear(10, 10) for i in range(5)])
        
    def forward(self, x):
        for layer in self.layer_list:
            x = layer(x)
        return x
  • Sequential 객체에 모듈을 담아, 여러 모듈들을 순서대로 연결하기
import torch
import torch.nn as nn

# 레이어 리스트 정의
model = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 30),
            nn.ReLU(),
            nn.Linear(30, 1)
        )

input_data = torch.randn(1, 10)

# model 안에 존재하는 모듈을 차례로 지나 결과물이 만들어진다.
output = model(input_data)
print(output.size())
# 출력: torch.Size([1, 1])