AI/PyTorch

[PyTorch] 파라미터 구현하기

sangwonYoon 2023. 3. 16. 01:44

PyTorch에서 모듈의 파라미터를 torch.nn.parameter.Parameter 클래스를 활용해 구현하는 방법에 대해 알아보자.


torch.nn.parameter.Parameter

  • Tensor의 하위 클래스로, 모듈의 파라미터 역할을 하는 클래스이다.
  • Tensor 클래스가 아닌, Parameter 클래스로 모듈의 파라미터를 만들어야 모듈의 파라미터 목록에 자동으로 등록된다.
import torch
from torch import nn
from torch.nn.parameter import Parameter

class MyModel(nn.Module):
	def __init__(self):
		super().__init__()
		self.W1 = torch.Tensor([5])
		self.W2 = Parameter(torch.Tensor([10]))

	def forward(self):
		pass

model = MyModel()
for (name, param) in model.named_parameters():
    print(f"Parameter name : {name}")
    print(param)
# 출력:
# Parameter name : W2
# Parameter containing:
# tensor([10.], requires_grad=True)

Tensor 클래스로 만들어진 파라미터 W1은 파라미터 목록에 등록되지 않은것을 볼 수 있다.

 

위와 같이 파라미터 목록에 등록되면

  • optimizer에 의해 값이 자동으로 최적화될 수 있다.
optimizer.step()
  • 모델이 저장될 때 값이 함께 저장된다.

 

일반적으로 Custom 모델을 만들때 torch.nn에 구현된 layer를 가져다 쓰기 때문에 Parameter 클래스를 직접 다룰 일은 거의 없다.