AI/PyTorch

[PyTorch] Transfer Learning (전이 학습)

sangwonYoon 2023. 3. 23. 00:36

Transfer Learning

대량의 데이터 셋이 학습되어 있는 pre-trained model에 현재 데이터를 학습시켜 사용하는 방식이다.

freeze

  • pre-trained model을 가져와 일부 layer만 학습시키고 싶은 경우에 사용한다.
  • 학습시키지 않을 layer를 freeze하여 파라미터 값이 업데이트되지 않게 한다.

 

from torch import nn
from torchvision import models

class MyNewNet(nn.Module):   
    def __init__(self):
        super(MyNewNet, self).__init__()
        # pre-train된 vgg19 모델 불러오기
        self.vgg19 = models.vgg19(pretrained=True)
        self.linear_layers = nn.Linear(1000, 1)
   
    def forward(self, x):
        x = self.vgg19(x)
        # 모델의 마지막 layer에 linear layer 추가        
        return self.linear_layers(x)

my_model = MyNewNet()
my_model = my_model.to(device)

# 마지막 linear layer를 제외하고 freeze
for param in my_model.parameters():
    param.requires_grad = False

for param in my_model.linear_layers.parameters():
    param.requires_grad = True