AI/기타

Huggingface Trainer compute_metrics 파헤치기

sangwonYoon 2023. 6. 21. 02:11

Huggingface trainer로 모델을 학습시킬 때 학습중인 모델의 성능을 평가하기 위해서 반드시 필요한 compute_metrics는 어떤 방식으로 동작하고, 어떻게 구현해야 하는지 알아보자.

compute_metrics는 transformers 라이브러리의 Trainer 클래스의 객체를 생성할 때, 매개변수로 입력 받는다.

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    compute_metrics=compute_metrics
)

 

Huggingface의 trainer에 익숙하지 않다면, 아래 글을 읽는 것을 추천한다.

 

HuggingFace Trainer로 모델 학습 및 평가하기

HuggingFace Trainer는 PyTorch와 같은 딥러닝 프레임워크에서 모델을 학습하기 위한 편리한 인터페이스를 제공하는 파이썬 라이브러리이다. Trainer 클래스와 TrainingArguments 클래스는 Trainer를 사용하여

sangwonyoon.tistory.com

 

Huggingface 공식 문서에서 compute_metrics는 다음과 같이 설명되어 있다.

compute_metrics (Callable[[EvalPrediction], Dict]) — The function that will be used to compute metrics at evaluation. Must take a EvalPrediction and return a dictionary string to metric values.

즉, evaluation 시점에서 모델의 성능을 나타내는 지표를 계산하기 위한 함수이다. 입력으로는 EvalPrediction 클래스를 입력받고, string을 key로 갖고, 성능 지표 값을 value로 갖는 dictionary를 반환한다.

 

compute_metrics에 대해 더 깊게 알아보기 위해 직접 라이브러리의 소스 코드를 확인하고, 디버깅하면서 알아낸 내용들을 작성해보려고 한다.

이번 포스팅에서 알아볼 내용은 다음과 같다.

  • compute_metrics는 언제 실행되고, 어떤 방식으로 동작하는가
  • compute_metrics의 입력과 출력은 구체적으로 어떤 형태인가

 

compute_metrics는 언제 실행되고, 어떤 방식으로 동작하는가

공식 문서에서는 compute_metrics 함수가 evaluation 시점에서 동작한다고 쓰여있다. 그렇다면 trainer.evaluate()를 통해서 실행되는 것인지, 아니면 trainer.train()에서도 실행되는지 확인하기 위해 transformers 라이브러리의 Trainer 클래스가 정의되어 있는 trainer.py의 소스 코드를 확인했다.

# transformers/trainer.py

class Trainer:
    def __init__(
        self,
        ...
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        ...
    ):
        ...
        self.compute_metrics = compute_metrics

__init__ 함수에서는 compute_metrics를 입력받아 클래스의 attribute로 선언한다.

 

이후, trainer.train()을 호출할 때, compute_metrics 함수가 사용되는지를 확인하고 싶었기 때문에 먼저 train() 메소드의 소스 코드를 확인했다.

train() 메소드는 상당히 복잡한 구조로 되어 있다. 최대한 큰 맥락을 파악할 수 있도록 중요한 내용 위주로 작성하여 생략된 내용이 있을 수 있다. 자세한 내용을 확인하고 싶다면, 아래 GitHub 저장소의 소스 코드에서 내가 단계별로 제시하는 메소드의 이름을 검색하며 코드를 직접 확인해 보는 것을 추천한다. 
transformers/trainer.py

 

1. train() 메소드는 내부에서 _inner_training_loop() 메소드를 호출한다.

2. _inner_training_loop() 메소드에서 epoch의 마지막 step이거나, 학습이 완료된 이후, _maybe_log_save_evaluate() 메소드를 호출한다.

3. self.control.should_evaluate의 값이 True이면, evaluate() 메소드를 호출한다.

4. evaluate() 메소드는 evaluation_loop() 메소드를 호출하고, 그 결과값을 output 변수에 저장한다.

5. evaluation_loop() 메소드에서 드디어 compute_metrics가 등장한다.

def evaluation_loop(
    ...
):
    ...(model evaluation 진행)...

    if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
        if args.include_inputs_for_metrics:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
            )
        else:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
    else:
        metrics = {}
		
    ...(metrics 값 후처리)...
		
    return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

여기서 all_preds와 all_labels는 각각 evaluate 단계에서 예측한 값과 정답 값을 모아놓은 변수이다.
모든 eval data에 대한 evaluate가 끝난 뒤, compute_metrics 함수와 all_predsall_labels가 모두 None이 아니면 우리가 정의한 compute_metrics 함수에 EvalPrediction 클래스를 인자로 전달한다.

6. compute_metrics 함수에서 계산한 성능 지표 값을 반환한다.

7. 이후, EvalLoopOutput 클래스의 metrics에 compute_metrics 함수에서 계산한 성능 지표 값을 담아 반환한다.

8. evaluate() 메소드에서 compute_metrics 함수의 반환 값을 log() 메소드에 전달하여 터미널, WandB 등 여러 곳에서 성능 지표 값을 확인할 수 있도록 logging한다.

9. evaluate() 메소드는 최종적으로 compute_metrics 함수의 반환 값을 반환한다.

compute_metrics가 호출되는 flow chart

trainer.evaluate()를 호출할 때는 4번에서부터 진행한다고 생각하면 된다.

결론적으로, compute_metrics 함수는 trainer.train() 호출 시, 일반적으로 매 epoch마다 compute_metrics 함수를 호출하고, 그 결과 값을 logging한다.

 

compute_metrics의 입력과 출력은 구체적으로 어떤 형태인가

evaluation_loop() 메소드의 코드에서 compute_metrics의 입력으로 EvalPrediction 클래스의 객체가 주어지는 것을 확인할 수 있었다.

if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
    if args.include_inputs_for_metrics:
        metrics = self.compute_metrics(
            EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
        )
    else:
        metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))

args.include_inputs_for_metrics에 대한 정보는 공식 문서에서 확인할 수 있다.

include_inputs_for_metrics (bool, optional, defaults to False) — Whether or not the inputs will be passed to the compute_metrics function. This is intended for metrics that need inputs, predictions and references for scoring calculation in Metric class.

즉, compute_metrics 함수에 input에 대한 정보를 전달할 것인지 아닌지를 결정하는 인자이다.

 

compute_metrics 함수에 전달되는 EvalPrediction 클래스의 객체를 더 자세히 살펴보기 위해 compute metrics 함수 안에 중단점을 찍어 디버깅을 진행했다. 학습에 사용되는 모델은 encoder-decoder 구조의 T5이고, Question Answering task에 대한 학습을 진행했다. 따라서 모델의 예측 값(predictions)와 데이터의 라벨 값(labels)는 질문에 대한 답변일 것이다.

def compute_metrics(eval_preds):
    preds, labels = eval_preds # 여기에 중단점을 찍었다.
    ...

디버그 콘솔에 eval_preds를 입력해보니, 아래와 같은 값이 담겨 있는 것을 확인할 수 있었다.

 

EvalPrediction 클래스는 predictionslabel_idsinputs 3개의 attribute를 가지고 있다.

소스 코드를 확인해 보면, 세 attribute 모두 numpy의 ndarray 타입이거나, ndarray를 원소로 갖는 튜플 타입이어야 한다. 또한, inputs는 위에서 살펴본 include_inputs_for_metrics가 False인 경우 None이다.

# transformers/trainer_utils.py

class EvalPrediction:
    def __init__(
        self,
        predictions: Union[np.ndarray, Tuple[np.ndarray]],
        label_ids: Union[np.ndarray, Tuple[np.ndarray]],
        inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
    ):
        self.predictions = predictions
        self.label_ids = label_ids
        self.inputs = inputs

 

다시 디버깅 결과로 돌아가서, eval_preds의 predictions와 label_ids가 numpy의 ndarray 타입인지 확인해 보았다.

 

다음으로, eval_preds의 predictions와 label_ids를 decode하면 내가 예상한대로 모델이 생성한 예측 값과 데이터의 라벨 값이 나오는 지 확인해 보았다.

빠른 디버깅을 위해 학습을 거의 진행하지 않아, 모델이 생성한 예측 값이 엉뚱하긴 하지만.. 예상한 결과를 확인할 수 있었다!

 

그렇다면, compute_metrics의 출력은 어떤 형태여야 할까? 공식 문서에 따르면, 문자열(평가 지표의 이름)을 key로 갖고, 평가 지표 값을 value로 갖는 dictionary 타입이라고 쓰여있다. 따라서, 평가 지표가 EM(Exact Match)와 F1 score라면, 아래와 같은 출력 값을 가질 것이다.

{"exact_match": 0.0, "f1": 0.0}

 

compute_metrics의 입출력

 

(잘못된 내용을 지적해주시거나 내용에 관한 피드백은 언제나 환영입니다!)