PyTorch에서 모듈의 파라미터를 torch.nn.parameter.Parameter 클래스를 활용해 구현하는 방법에 대해 알아보자.
torch.nn.parameter.Parameter
- Tensor의 하위 클래스로, 모듈의 파라미터 역할을 하는 클래스이다.
- Tensor 클래스가 아닌, Parameter 클래스로 모듈의 파라미터를 만들어야 모듈의 파라미터 목록에 자동으로 등록된다.
import torch
from torch import nn
from torch.nn.parameter import Parameter
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.W1 = torch.Tensor([5])
self.W2 = Parameter(torch.Tensor([10]))
def forward(self):
pass
model = MyModel()
for (name, param) in model.named_parameters():
print(f"Parameter name : {name}")
print(param)
# 출력:
# Parameter name : W2
# Parameter containing:
# tensor([10.], requires_grad=True)
Tensor 클래스로 만들어진 파라미터 W1은 파라미터 목록에 등록되지 않은것을 볼 수 있다.
위와 같이 파라미터 목록에 등록되면
- optimizer에 의해 값이 자동으로 최적화될 수 있다.
optimizer.step()
- 모델이 저장될 때 값이 함께 저장된다.
일반적으로 Custom 모델을 만들때 torch.nn에 구현된 layer를 가져다 쓰기 때문에 Parameter 클래스를 직접 다룰 일은 거의 없다.
'AI > PyTorch' 카테고리의 다른 글
[PyTorch] hook (0) | 2023.03.17 |
---|---|
[Pytorch] Dataset과 DataLoader (0) | 2023.03.16 |
[PyTorch] 신경망 모델과 torch.nn.module (0) | 2023.03.16 |
[PyTorch] torch.reshape과 torch.Tensor.view의 차이 (0) | 2023.03.14 |
[PyTorch] torch.mm과 torch.matmul의 차이 (0) | 2023.03.14 |