2024 딥러닝/논문 리뷰

[논문 리뷰] AdapTable: Test-Time Adaptation for Tabular via Shift-Aware Uncertainty

융딩2 2024. 10. 30. 16:17

방법론의 목적
1) Tabular data representation은 covariate shift와 concept drift의 얽힘으로 인해 방해받는 경우가 많음

  • (a), (b)를 비교해보면, 심층 학습 모델의 표현이 image데이터에서만 라벨에 따른 클러스터 가정을 따르며, tabular 데이터에서는 그렇지 않음
  • Tabular 형식 도메인에서는 잠재적 혼란 변수 Z로 인해 입력 X에서 출력 Y로의 약한 인과관계가 발생하며, 이는 유사한 입력에 대해 매우 다른 라벨을 생성하는 경우가 많음(Grinsztajn, Oyallon, and Varoquaux 2022; Liu et al. 2023)
  • 심층 신경망이 정확하게 모델링하기 어려운 고주파 함수로 이어지며, 심층 신경망은 저주파 함수에 편향되는 경향이 있음 (Beyazit et al. 2024)

 
2) Label distribution shift 고려 안함

  • 최근의 벤치마크 연구인 TableShift(Gardner, Popovic, and Schmidt 2023)는 라벨 분포 변화가 표 형식 데이터에서 성능 저하의 주요 원인임을 강조
    • *입력 공변량 변화(X-shift), 개념 변화(Y |X-shift), 라벨 분포 변화(Y-shift)*와 모델 성능 간의 관계를 조사한 결과, 라벨 분포 변화가 성능 저하와 강한 상관관계가 있음을 발견
  • 위 그림에서 모델 예측을 시각화한 결과, 출력 라벨의 주변 분포가 소스 라벨 분포에 편향되어 있음을 관찰

 
 

주목해야할 점

Tabular data를 어떻게 표현학습 보정을 하고 + Tabular data에서 Label distribution shift를 보정하는지
 
 
방법론

Shift-Aware Uncertainty Callibrator
1) Per-Sample Temperature Scaling
  • 목적: GNN을 통해 각 열의 변화를 반영한 temperature를 산출하여, 모델이 과도하게 확신한 잘못된 예측을 조정함. 이로 인해 타겟 데이터에서 예측의 신뢰성을 높임.
  • 방법: Shift trend계산(test 배치와 소스 데이터의 차이) → GNN통한 shift trend 통합→ temperature scaling(원래 모델 예측에 결합)→불확실성 조정

 

  • 효과:
(a)Reliability Diagram Before Adaptation

: 기대된 정확도와 실제 정확도 간의 큰 차이가 보이는데, 이는 모델이 특정 신뢰도 구간에서 과신하거나 과소신뢰할 수 있음
(b) Reliability Diagram After Adaptation
: 보정은 모델의 신뢰도와 실제 정확도를 일치시키는 데 효과적이었음을 보여줌
(c) MMD vs. Avg Temperature
: 그래프에서 MMD가 증가함에 따라 평균 온도도 점진적으로 증가하는 경향을 보여, 모델이 분포 차이가 큰 경우 더 높은 온도를 설정하여 예측의 확신도를 낮추려는 보정 메커니즘이 작동하고 있음
 
 

Label Distribution Handler

공통 목적: 현재 테스트 배치에 대한 타겟 레이블 분포를 정확하게 추정하고 모델 예측을 해당 분포에 맞춰 조정하기 위함. (이는 도메인 이동 시 타겟 도메인의 실제 레이블 분포와 소스 도메인의 레이블 분포가 다를 수 있다는 점에 주목하여 모델의 예측을 조정하고, 예측 불확실성을 개선하기 위한 것)
 
1) Adjusted Prediction Probability

  • 목적: 원래 모델의 예측 확률을 보정하여 타겟레이블 분포에 맞추기 위함.
    -일반적으로 소스와 타겟 분포가 다를 때, 예측 확률을 단순히 p_t(y) /p_s(y)  로 보정할 수 있지만, 이는 잘못된 예측일 때 과도한 확신을 초래할 수 있기에 다른 조정방법 시도.
  • 방법:

 
2) Temperature Scaling

  • 목적: 샘플별로 temperature를 조정하여 불확실성 반영
  • 방법: 높은 불확실성의 샘플은 확신도를 낮추고, 낮은 불확실성의 샘플은 확신도를 높임

 

  • 효과 
    • JS Divergence 값이 낮을수록 모델이 타겟 분포를 정확하게 추정하고 있다는 것
    • AdapTable의 라벨 분포 핸들러를 통해 타겟 라벨 분포 추정이 개선되었음을 보여주며, 이는 기존 소스 분포만 사용할 때보다 타겟 분포와의 차이를 줄여주어 성능을 높일 수 있음을 의미.