AI/PyTorch

[Pytorch] Dataset과 DataLoader

sangwonYoon 2023. 3. 16. 02:19

대용량 데이터의 일부만 메모리에 적재하여 효율적으로 학습시키기 위해 필요한 Dataset과 DataLoader에 대해 알아보자.


Dataset

  • 모델에 입력할 데이터의 형태를 정의한다.
  • torch.utils.data.Dataset을 상속받아 생성한 Custom Dataset 클래스는 __init__, __len__, __getitem__을 구현해야 한다.

 

  • __init__

CSV 파일이나 XML 파일과 같은 데이터를 불러온다.

데이터의 전처리를 수행할 수 있다.

 

  • __len__

데이터셋의 전체 개수를 반환한다.

 

  • __getitem__

데이터셋에서 주어진 인덱스에 해당하는 데이터를 반환한다.

 

import torch
import pandas as pd
from torch.utils.data import Dataset

class TitanicDataset(Dataset):
    def __init__(self, path, drop_features, train=True):
        self.data = pd.read_csv(path)
        self.data['Sex'] = self.data['Sex'].map({'male':0, 'female':1})
        self.data['Embarked'] = self.data['Embarked'].map({'S':0, 'C':1, 'Q':2})
        self.data.drop(drop_features, axis = 1, inplace = True)
        self.y = self.data.pop("Survived")
        self.X = self.data
        self.is_train = train
        self.features = self.X.columns
        self.classes = pd.unique(self.y)
        

    def __len__(self):
        len_dataset = len(self.y)
        return len_dataset

    def __getitem__(self, idx):
        X = self.X.iloc[idx].values
        if self.is_train:
          y = self.y.iloc[idx]
        else:
          y = None
        return torch.tensor(X), torch.tensor(y)

 

 

DataLoader

  • Dataset으로부터 모델에 전달할 데이터의 mini batch를 생성한다.
DataLoader(dataset, batch_size=1, shuffle=False, 
          sampler=None, batch_sampler=None, num_workers=0, 
          collate_fn=None, pin_memory=False, drop_last=False, 
          timeout=0, worker_init_fn=None)
  • batch_size

DataLoader 클래스가 생성하는 mini batch의 크기이다.

 

  • shuffle

데이터를 순서대로 사용할지, 섞어서 사용할 지 지정한다.

과적합(overfitting)을 방지하기 위해 사용된다.

 

  • sampler / batch_sampler

데이터셋에서 샘플을 뽑는 방식을 정의한다.

불균형 데이터셋의 경우에서 클래스의 비율에 맞게 데이터를 제공할 때 사용한다.

 

  • collate_fn

((x1, y1), (x2, y2))와 같은 배치 단위 데이터를 ((x1, x2), (y1, y2))와 같이 바꿀 수 있다.

def collate_fn(batch): # mini batch 단위의 데이터를 인자로 넘긴다.
    print('Original:\n', batch)
    print('-'*100)
    
    data_list, label_list = [], []
    
    for _data, _label in batch:
        data_list.append(_data)
        label_list.append(_label)
    
    print('Collated:\n', [torch.Tensor(data_list), torch.LongTensor(label_list)])
    print('-'*100)
    
    return torch.Tensor(data_list), torch.LongTensor(label_list)


print(next(iter(DataLoader(dataset_random, collate_fn=collate_fn, batch_size=4))))
# 출력: 
# Original:
#  [(tensor([0.0113]), tensor(1)), (tensor([0.2369]), tensor(0)), (tensor([0.7359]), tensor(1)), (tensor([0.4268]), tensor(2))]
# ----------------------------------------------------------------------------------------------------
# Collated:
#  [tensor([0.0113, 0.2369, 0.7359, 0.4268]), tensor([1, 0, 1, 2])]
# ----------------------------------------------------------------------------------------------------

 

  • drop_last

batch 단위로 데이터를 불러올 때, 데이터의 길이가 batch_size에 나누어 떨어지지 않으면, 마지막 batch의 길이가 달라진다.

이 때, batch의 길이가 달라지는 문제를 해결하기 위해 마지막 batch를 사용하지 않을 수 있게 지정하는 파라미터이다.

for data, label in DataLoader(dataset_random, batch_size=4):
    print(len(data))
# 출력:
# 4
# 4
# 2
for data, label in DataLoader(dataset_random, batch_size=4, drop_last=True):
    print(len(data))
# 출력:
# 4
# 4

 

  • time_out

DataLoader가 batch data를 불러오는데 주어지는 제한시간이다.