
[Python ML Guide] Section 4.4 (Classification): Understanding Boosting, and Gradient Boosting / XGBoost / LightGBM

Jae. 2023. 9. 12. 14:29

https://www.inflearn.com/course/%ED%8C%8C%EC%9D%B4%EC%8D%AC-%EB%A8%B8%EC%8B%A0%EB%9F%AC%EB%8B%9D-%EC%99%84%EB%B2%BD%EA%B0%80%EC%9D%B4%EB%93%9C

 

 

[개정판] 파이썬 머신러닝 완벽 가이드 - Inflearn course (www.inflearn.com)

 

 

 


1. Understanding Boosting

 

A boosting algorithm trains and predicts with several weak learners sequentially, assigning higher weights to misclassified data (or to the trained trees) so that each subsequent learner corrects the errors of the previous ones.

 

Representative boosting implementations are AdaBoost (Adaptive Boosting) and Gradient Boosting (e.g., XGBoost, LightGBM).

 

 

Voting = combining classifiers built with different algorithms.

 

Bagging = each classifier uses the same type of algorithm, but each one is trained individually on differently sampled data, and the final prediction is decided by voting across all classifiers (a sketch contrasting the two follows below).
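
As a rough illustration of the distinction above (a minimal sketch, not from the lecture; dataset and parameter values are placeholders), scikit-learn's VotingClassifier combines two different algorithms while BaggingClassifier trains many copies of one algorithm on bootstrap samples:

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import VotingClassifier, BaggingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score

X, y = load_breast_cancer(return_X_y=True)

# Voting: heterogeneous classifiers (logistic regression + decision tree), soft-voting on predicted probabilities
voting_clf = VotingClassifier(
    estimators=[('lr', LogisticRegression(max_iter=5000)), ('dt', DecisionTreeClassifier())],
    voting='soft')

# Bagging: many copies of the same algorithm (decision trees by default), each trained on a different bootstrap sample
bagging_clf = BaggingClassifier(n_estimators=100)

print('Voting  CV accuracy:', cross_val_score(voting_clf, X, y, cv=5).mean())
print('Bagging CV accuracy:', cross_val_score(bagging_clf, X, y, cv=5).mean())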

 

 

 

 

Types of Boosting

  • AdaBoost
  • Gradient Boosting / GBM (Gradient Boosting Machine)
    • XGBoost
    • LightGBM

 

 

 

Problems with GBM

  • Because the weak learners train sequentially, execution takes quite a long time
  • In contrast, Random Forest trains its trees in parallel, so it takes less time
  • There is a high chance of overfitting

 

 

 

 


2. AdaBoost (Adaptive Boosting)

 

AdaBoost training / prediction process
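
As a minimal sketch (not from the lecture; dataset and parameter values are placeholders), scikit-learn's AdaBoostClassifier below boosts decision stumps, re-weighting the samples that earlier learners misclassified:

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import cross_val_score

X, y = load_breast_cancer(return_X_y=True)

# AdaBoost with its default decision-stump weak learners: each round up-weights the samples
# the previous learners got wrong, and the final prediction is a weighted vote of all learners.
ada_clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)
print('AdaBoost CV accuracy:', cross_val_score(ada_clf, X, y, cv=5).mean())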

 

 

 


3. GBM (Gradient Boosting Machine)

 

GBM (Gradient Boosting Machine) is similar to AdaBoost,

but it updates the weights using gradient descent to obtain an optimized result.

 

Error = actual value (the target $y$) minus predicted value ($y_{pred}$).

 

If the actual classification result is $y$, the features are $X_{1}, X_{2}, \dots, X_{n}$, and the prediction function based on these features is $F(x)$, then the error is $h(x) = y - F(x)$.

 

Gradient descent iteratively updates the weights in the direction that minimizes this error $h(x) = y - F(x)$.

 

Weak learners are added one after another, and the weights are updated so that the loss function decreases.

 

Gradient descent derives the weight updates that minimize the error through repeated iterations, and it is one of the most important techniques in machine learning; a residual-fitting sketch of this idea follows below.
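
To make the residual-fitting idea concrete, here is a minimal sketch (my own illustration, assuming squared-error loss and synthetic data): each weak tree is fit to the current residual y - F(x) and added to the ensemble scaled by a learning rate.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=200)

learning_rate, n_estimators = 0.1, 100
F = np.full_like(y, y.mean())               # initial prediction F_0(x)
trees = []
for _ in range(n_estimators):
    residual = y - F                        # h(x) = y - F(x): the negative gradient of squared loss
    tree = DecisionTreeRegressor(max_depth=2).fit(X, residual)
    F += learning_rate * tree.predict(X)    # move F(x) in the direction that reduces the loss
    trees.append(tree)

print('final training MSE:', np.mean((y - F) ** 2))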

 

 

Scikit-Learn GBM: main hyperparameters and tuning

 

Scikit-Learn provides the GradientBoostingClassifier class for GBM classification.

 

  • loss: specifies the loss function used by gradient descent. Unless there is a specific reason, keep the default 'deviance' (renamed 'log_loss' in recent scikit-learn versions)

 

  • learning_rate: the learning rate applied at each stage of GBM training
    • The coefficient applied as the weak learners sequentially correct the error
    • A value between 0 and 1 can be specified; the default is 0.1
    • Smaller values can improve performance but make training slower
    • If the value is too small
      • Each update is small, so the chance of finding the minimum error and achieving high predictive performance increases
      • But the weak learners must iterate sequentially, so training takes a long time, and if the value is set too small the minimum error may not be found even after all weak-learner iterations complete
    • If the value is too large
      • The minimum error can be overshot and missed, so predictive performance is likely to drop
      • On the other hand, training finishes quickly

 

  • n_estimators: the number of weak learners
    • Because the weak learners correct errors sequentially, more of them can improve predictive performance up to a point
    • But more learners also means longer training time
    • The default is 100

 

  • subsample: the fraction of the training data each weak learner samples for learning
    • The default is 1, meaning each learner is trained on the full training data
    • 0.5 means each learner is trained on half of the training data
    • If overfitting is a concern, set subsample to a value smaller than 1

 

 

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score
import time
import warnings

warnings.filterwarnings("ignore")

# get_human_dataset() is the helper used earlier in these notes to load the human activity dataset.
X_train, X_test, y_train, y_test = get_human_dataset()

# Record the start time to measure how long GBM takes.
start_time = time.time()

gb_clf = GradientBoostingClassifier(random_state=0)
gb_clf.fit(X_train, y_train)
gb_pred = gb_clf.predict(X_test)
gb_accuracy = accuracy_score(y_test, gb_pred)

print("GBM accuracy: {0:.4f}".format(gb_accuracy))
print("GBM elapsed time: {0:.1f} seconds".format(time.time() - start_time))

# GBM accuracy: 0.9393
# GBM elapsed time: 963.5 seconds

### The example below (not covered in the lecture) tunes GBM hyperparameters with GridSearchCV.
### After scikit-learn was upgraded to 1.x, GBM training became noticeably slower,
### so this takes a long time to run and is provided for reference only.
from sklearn.model_selection import GridSearchCV

params = {
    'n_estimators': [100, 500],
    'learning_rate': [0.05, 0.1]
}
grid_cv = GridSearchCV(gb_clf, param_grid=params, cv=2, verbose=1)
grid_cv.fit(X_train, y_train)
print('Best hyperparameters:\n', grid_cv.best_params_)
print('Best CV accuracy: {0:.4f}'.format(grid_cv.best_score_))

# Predict with the best estimator found by GridSearchCV.
gb_pred = grid_cv.best_estimator_.predict(X_test)
gb_accuracy = accuracy_score(y_test, gb_pred)
print('GBM accuracy: {0:.4f}'.format(gb_accuracy))

 

 

 

Why GridSearchCV is not well suited to XGBoost and LightGBM

 

  • Gradient-boosting-based algorithms have many hyperparameters to tune, each with a wide range, so the number of possible combinations is extremely large
  • With this many combinations, if the dataset is also large, hyperparameter tuning with GridSearchCV requires a very long time (a rough count of the number of fits is sketched below)
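
As a rough illustration (the grid below is hypothetical, not from the lecture), the number of model fits GridSearchCV must run is the product of all parameter-list lengths times the number of CV folds, so it grows multiplicatively:

from itertools import product

param_grid = {
    'n_estimators': [100, 300, 500, 1000],
    'learning_rate': [0.01, 0.05, 0.1, 0.3],
    'max_depth': [3, 5, 7],
    'subsample': [0.6, 0.8, 1.0],
    'colsample_bytree': [0.6, 0.8, 1.0],
}
n_combinations = len(list(product(*param_grid.values())))   # 4 * 4 * 3 * 3 * 3 = 432
cv_folds = 3
print(n_combinations, 'combinations ->', n_combinations * cv_folds, 'model fits')
# 432 combinations -> 1296 model fits; with a large dataset each fit can take minutes or more.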

 


4. XGBoost (eXtreme Gradient Boosting)

 

XGBoost emerged as an improved form of GBM that addresses the problems of standard GBM.

 

 

Main advantages of XGBoost

 

  • Excellent predictive performance
  • Faster execution than GBM (though LightGBM and some other ensemble methods are faster still)
    • CPU parallelism and GPU support
  • Various performance-enhancing features
    • Built-in regularization (helps prevent overfitting)
    • Tree pruning (re-examines and prunes tree nodes)
  • Various convenience features
    • Early stopping
    • Built-in cross validation
    • Native handling of missing values

 

 

 

Implementing XGBoost in Python

 

[API] Python Wrapper XGB vs Scikit Learn Wrapper XGB

 

The point at which hyperparameters are supplied differs:

  • The Python wrapper XGB receives them when train() is called
  • The Scikit-Learn wrapper XGB receives them when the object is created and when fit() is called

 

What the training API returns differs:

  • With the Python wrapper XGB, you must capture the trained model object returned by train()
  • With the Scikit-Learn wrapper, simply calling the fit() method is enough

 

The prediction API returns different values:

  • The Python wrapper XGB returns the predicted probability for each label
  • The Scikit-Learn wrapper returns the predicted label values

 

 

The Python wrapper's prediction API is therefore similar to predict_proba(); a side-by-side sketch of the two APIs follows below.
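
A minimal side-by-side sketch of the two API styles (my own illustration; it assumes X_train, X_test, y_train, y_test have already been split as in the examples below, and the hyperparameter values are placeholders):

import xgboost as xgb
from xgboost import XGBClassifier

# Python (native) wrapper: hyperparameters go into train(), the data is a DMatrix,
# and predict() returns probabilities, similar to predict_proba().
params = {'max_depth': 3, 'eta': 0.1, 'objective': 'binary:logistic'}
dtrain = xgb.DMatrix(data=X_train, label=y_train)
dtest = xgb.DMatrix(data=X_test, label=y_test)
booster = xgb.train(params=params, dtrain=dtrain, num_boost_round=100)
proba = booster.predict(dtest)                       # predicted probabilities

# Scikit-learn wrapper: hyperparameters go into the constructor, fit()/predict() work like any estimator,
# and predict() returns class labels directly.
clf = XGBClassifier(max_depth=3, learning_rate=0.1, n_estimators=100)
clf.fit(X_train, y_train)
labels = clf.predict(X_test)                         # predicted label values (0/1)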

 

 

[Hyperparameter] Python Wrapper XGB  vs Scikit Learn Wrapper XGB

 

min_child_weight: the weight threshold used to decide whether or not a node is split into children (a child is created only if the sum of instance weights in it reaches this value)

 

 

When overfitting is severe

 

 

 

 

Regularization: training that focuses only on reducing the error makes overfitting likely, so regularization is applied to keep the model from concentrating too heavily on the training error (a sketch of the relevant XGBoost parameters follows below).
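
A minimal sketch (values are illustrative, not tuned, and this is not a configuration recommended by the lecture) of the XGBClassifier parameters that are commonly adjusted to regularize the model when overfitting is severe:

from xgboost import XGBClassifier

xgb_clf = XGBClassifier(
    n_estimators=400,
    learning_rate=0.05,      # smaller steps per boosting round
    max_depth=3,             # shallower trees
    min_child_weight=3,      # require more weight in a node before splitting it further
    gamma=0.1,               # minimum loss reduction required to make a split
    subsample=0.8,           # row sampling per tree
    colsample_bytree=0.8,    # column (feature) sampling per tree
    reg_alpha=0.1,           # L1 regularization on leaf weights
    reg_lambda=1.0,          # L2 regularization on leaf weights
)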

 

 

 

 

 

Early Stopping in XGBoost

 

 

If the loss function does not improve for a specified number of rounds, XGBoost can stop training before completing all of the specified boosting rounds.

 

  • It can shorten training time, and is especially useful during the hyperparameter tuning stage

 

  • Be careful: if the number of rounds is cut too short, training may end before predictive performance has been optimized

 

  • Main parameters for configuring early stopping
    • early_stopping_rounds: the number of consecutive rounds without improvement in the evaluation metric after which training stops
      • e.g., with early_stopping_rounds=10 and 1,000 scheduled rounds, if the metric stops improving after round 100 and still shows no improvement by round 110, training stops there
    • eval_metric: the cost/evaluation metric computed at each round
      • eval_metric = 'logloss', 'auc', etc.
    • eval_set: a separate validation set used for the evaluation
      • The cost is typically evaluated repeatedly on this validation set as training proceeds

 


5. Predicting Wisconsin Breast Cancer with XGBoost: using the Python-native XGBoost API

 

X_test, y_test: feature dataset and labels for the test data

X_train, y_train: feature dataset and labels for the training data

X_tr, y_tr: feature dataset and labels for the portion of X_train, y_train kept for training

X_val, y_val: feature dataset and labels for the portion of X_train, y_train split off for validation

 

Dataset Loading

import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
# load the xgboost package
import xgboost as xgb
from xgboost import plot_importance

import warnings
warnings.filterwarnings('ignore')

dataset = load_breast_cancer()
features= dataset.data
labels = dataset.target

cancer_df = pd.DataFrame(data=features, columns=dataset.feature_names)
cancer_df['target']= labels
cancer_df.head(3)
print(dataset.target_names)
print(cancer_df['target'].value_counts())

# ['malignant' 'benign']
# 1    357
# 0    212
# Name: target, dtype: int64

 

To use early stopping, split the training data again into training and validation data

# Extract the feature DataFrame and the label Series from cancer_df.
# The last column is the label, so slice all columns except the last one (:-1) for the features.
X_features = cancer_df.iloc[:, :-1]
y_label = cancer_df.iloc[:, -1]

# Split the whole dataset: 80% for training, 20% for testing.
X_train, X_test, y_train, y_test = train_test_split(X_features, y_label, test_size=0.2, random_state=156)

# Split X_train, y_train again: 90% for training and 10% for validation.
X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=156)

print(X_train.shape , X_test.shape)
print(X_tr.shape, X_val.shape)

# (455, 30) (114, 30)
# (409, 30) (46, 30)

 

Convert the train and test datasets into DMatrix objects

  • A DMatrix can be created from a numpy array or a DataFrame
  • xgb.DMatrix(data=feature_dataset, label=labels)
  • Create DMatrix objects for training, validation, and testing
# If an older XGBoost version cannot create a DMatrix from a DataFrame, convert to numpy with X_train.values.
# Create the training, validation, and test DMatrix objects.
dtr = xgb.DMatrix(data=X_tr, label=y_tr)
dval = xgb.DMatrix(data=X_val, label=y_val)
dtest = xgb.DMatrix(data=X_test , label=y_test)

 

Setting the hyperparameters

  • eta: learning rate
  • objective: loss function
  • eval_metric: the evaluation metric computed at each round
params = { 'max_depth':3,
           'eta': 0.05,
           'objective':'binary:logistic',
           'eval_metric':'logloss'
        }
num_rounds = 400

 

 

Pass the hyperparameters and the early stopping parameters to the train() function and run training

  • Unlike scikit-learn estimator classes, the Python-native XGBoost takes them as parameters of the train() function
# Label the training set 'train' and the evaluation set 'eval'.
eval_list = [(dtr, 'train'), (dval, 'eval')]  # listing only eval_list = [(dval, 'eval')] also works

# Pass the hyperparameters and the early stopping parameter to train().
xgb_model = xgb.train(params=params, dtrain=dtr, num_boost_round=num_rounds,
                      early_stopping_rounds=50, evals=eval_list)

 

 

xgb_model = xgb.train(params, dtrain, num_boost_round, early_stopping_rounds, evals)

 

  • params: the hyperparameters, passed as a dict
  • dtrain: the training DMatrix (DMatrix(data, label))
  • num_boost_round: how many boosting rounds to run
  • early_stopping_rounds: the number of consecutive rounds without improvement in the evaluation metric after which training stops
  • evals: the validation set(s) used for evaluation
    • [(training DMatrix, 'train'), (validation DMatrix, 'eval')], or just [(validation DMatrix, 'eval')]
    • the loss is reported at every round in the form eval-loss_function: value

 

[0]	train-logloss:0.65016	eval-logloss:0.66183
[1]	train-logloss:0.61131	eval-logloss:0.63609
[2]	train-logloss:0.57563	eval-logloss:0.61144
[3]	train-logloss:0.54310	eval-logloss:0.59204
[4]	train-logloss:0.51323	eval-logloss:0.57329
[5]	train-logloss:0.48447	eval-logloss:0.55037
[6]	train-logloss:0.45796	eval-logloss:0.52930
[7]	train-logloss:0.43436	eval-logloss:0.51534
[8]	train-logloss:0.41150	eval-logloss:0.49718
[9]	train-logloss:0.39027	eval-logloss:0.48154
[10]	train-logloss:0.37128	eval-logloss:0.46990
[11]	train-logloss:0.35254	eval-logloss:0.45474
[12]	train-logloss:0.33528	eval-logloss:0.44229
[13]	train-logloss:0.31892	eval-logloss:0.42961
[14]	train-logloss:0.30439	eval-logloss:0.42065
[15]	train-logloss:0.29000	eval-logloss:0.40958
[16]	train-logloss:0.27651	eval-logloss:0.39887
[17]	train-logloss:0.26389	eval-logloss:0.39050
[18]	train-logloss:0.25210	eval-logloss:0.38254
[19]	train-logloss:0.24123	eval-logloss:0.37393
[20]	train-logloss:0.23076	eval-logloss:0.36789
[21]	train-logloss:0.22091	eval-logloss:0.36017
[22]	train-logloss:0.21155	eval-logloss:0.35421
[23]	train-logloss:0.20263	eval-logloss:0.34683
[24]	train-logloss:0.19434	eval-logloss:0.34111
[25]	train-logloss:0.18637	eval-logloss:0.33634
[26]	train-logloss:0.17875	eval-logloss:0.33082
[27]	train-logloss:0.17167	eval-logloss:0.32675
[28]	train-logloss:0.16481	eval-logloss:0.32099
[29]	train-logloss:0.15835	eval-logloss:0.31671
[30]	train-logloss:0.15225	eval-logloss:0.31277
[31]	train-logloss:0.14650	eval-logloss:0.30882
[32]	train-logloss:0.14102	eval-logloss:0.30437
[33]	train-logloss:0.13590	eval-logloss:0.30103
[34]	train-logloss:0.13109	eval-logloss:0.29794
[35]	train-logloss:0.12647	eval-logloss:0.29499
[36]	train-logloss:0.12197	eval-logloss:0.29295
[37]	train-logloss:0.11784	eval-logloss:0.29043
[38]	train-logloss:0.11379	eval-logloss:0.28927
[39]	train-logloss:0.10994	eval-logloss:0.28578
[40]	train-logloss:0.10638	eval-logloss:0.28364
[41]	train-logloss:0.10302	eval-logloss:0.28183
[42]	train-logloss:0.09963	eval-logloss:0.28005
[43]	train-logloss:0.09649	eval-logloss:0.27972
[44]	train-logloss:0.09359	eval-logloss:0.27744
[45]	train-logloss:0.09080	eval-logloss:0.27542
[46]	train-logloss:0.08807	eval-logloss:0.27504
[47]	train-logloss:0.08541	eval-logloss:0.27458
[48]	train-logloss:0.08299	eval-logloss:0.27348
[49]	train-logloss:0.08035	eval-logloss:0.27247
[50]	train-logloss:0.07786	eval-logloss:0.27163
[51]	train-logloss:0.07550	eval-logloss:0.27094
[52]	train-logloss:0.07344	eval-logloss:0.26967
[53]	train-logloss:0.07147	eval-logloss:0.27008
[54]	train-logloss:0.06964	eval-logloss:0.26890
[55]	train-logloss:0.06766	eval-logloss:0.26854
[56]	train-logloss:0.06591	eval-logloss:0.26900
[57]	train-logloss:0.06433	eval-logloss:0.26790
[58]	train-logloss:0.06259	eval-logloss:0.26663
[59]	train-logloss:0.06107	eval-logloss:0.26743
[60]	train-logloss:0.05957	eval-logloss:0.26610
[61]	train-logloss:0.05817	eval-logloss:0.26644
[62]	train-logloss:0.05691	eval-logloss:0.26673
[63]	train-logloss:0.05550	eval-logloss:0.26550
[64]	train-logloss:0.05422	eval-logloss:0.26443
[65]	train-logloss:0.05311	eval-logloss:0.26500
[66]	train-logloss:0.05207	eval-logloss:0.26591
[67]	train-logloss:0.05093	eval-logloss:0.26501
[68]	train-logloss:0.04976	eval-logloss:0.26435
[69]	train-logloss:0.04872	eval-logloss:0.26360
[70]	train-logloss:0.04776	eval-logloss:0.26319
[71]	train-logloss:0.04680	eval-logloss:0.26255
[72]	train-logloss:0.04580	eval-logloss:0.26204
[73]	train-logloss:0.04484	eval-logloss:0.26254
[74]	train-logloss:0.04388	eval-logloss:0.26289
[75]	train-logloss:0.04309	eval-logloss:0.26249
[76]	train-logloss:0.04224	eval-logloss:0.26217
[77]	train-logloss:0.04133	eval-logloss:0.26166
[78]	train-logloss:0.04050	eval-logloss:0.26179
[79]	train-logloss:0.03967	eval-logloss:0.26103
[80]	train-logloss:0.03876	eval-logloss:0.26094
[81]	train-logloss:0.03806	eval-logloss:0.26148
[82]	train-logloss:0.03740	eval-logloss:0.26054
[83]	train-logloss:0.03676	eval-logloss:0.25967
[84]	train-logloss:0.03605	eval-logloss:0.25905
[85]	train-logloss:0.03545	eval-logloss:0.26007
[86]	train-logloss:0.03489	eval-logloss:0.25984
[87]	train-logloss:0.03425	eval-logloss:0.25933
[88]	train-logloss:0.03361	eval-logloss:0.25932
[89]	train-logloss:0.03311	eval-logloss:0.26002
[90]	train-logloss:0.03260	eval-logloss:0.25936
[91]	train-logloss:0.03202	eval-logloss:0.25886
[92]	train-logloss:0.03152	eval-logloss:0.25918
[93]	train-logloss:0.03107	eval-logloss:0.25864
[94]	train-logloss:0.03049	eval-logloss:0.25951
[95]	train-logloss:0.03007	eval-logloss:0.26091
[96]	train-logloss:0.02963	eval-logloss:0.26014
[97]	train-logloss:0.02913	eval-logloss:0.25974
[98]	train-logloss:0.02866	eval-logloss:0.25937
[99]	train-logloss:0.02829	eval-logloss:0.25893
[100]	train-logloss:0.02789	eval-logloss:0.25928
[101]	train-logloss:0.02751	eval-logloss:0.25955
[102]	train-logloss:0.02714	eval-logloss:0.25901
[103]	train-logloss:0.02668	eval-logloss:0.25991
[104]	train-logloss:0.02634	eval-logloss:0.25950
[105]	train-logloss:0.02594	eval-logloss:0.25924
[106]	train-logloss:0.02556	eval-logloss:0.25901
[107]	train-logloss:0.02522	eval-logloss:0.25738
[108]	train-logloss:0.02492	eval-logloss:0.25702
[109]	train-logloss:0.02453	eval-logloss:0.25789
[110]	train-logloss:0.02418	eval-logloss:0.25770
[111]	train-logloss:0.02384	eval-logloss:0.25842
[112]	train-logloss:0.02356	eval-logloss:0.25810
[113]	train-logloss:0.02322	eval-logloss:0.25848
[114]	train-logloss:0.02290	eval-logloss:0.25833
[115]	train-logloss:0.02260	eval-logloss:0.25820
[116]	train-logloss:0.02229	eval-logloss:0.25905
[117]	train-logloss:0.02204	eval-logloss:0.25878
[118]	train-logloss:0.02176	eval-logloss:0.25728
[119]	train-logloss:0.02149	eval-logloss:0.25722
[120]	train-logloss:0.02119	eval-logloss:0.25764
[121]	train-logloss:0.02095	eval-logloss:0.25761
[122]	train-logloss:0.02067	eval-logloss:0.25832
[123]	train-logloss:0.02045	eval-logloss:0.25808
[124]	train-logloss:0.02023	eval-logloss:0.25855
[125]	train-logloss:0.01998	eval-logloss:0.25714
[126]	train-logloss:0.01973	eval-logloss:0.25587
[127]	train-logloss:0.01946	eval-logloss:0.25640
[128]	train-logloss:0.01927	eval-logloss:0.25685
[129]	train-logloss:0.01908	eval-logloss:0.25665
[130]	train-logloss:0.01886	eval-logloss:0.25712
[131]	train-logloss:0.01863	eval-logloss:0.25609
[132]	train-logloss:0.01839	eval-logloss:0.25649
[133]	train-logloss:0.01816	eval-logloss:0.25789
[134]	train-logloss:0.01802	eval-logloss:0.25811
[135]	train-logloss:0.01785	eval-logloss:0.25794
[136]	train-logloss:0.01763	eval-logloss:0.25876
[137]	train-logloss:0.01748	eval-logloss:0.25884
[138]	train-logloss:0.01732	eval-logloss:0.25867
[139]	train-logloss:0.01719	eval-logloss:0.25876
[140]	train-logloss:0.01696	eval-logloss:0.25987
[141]	train-logloss:0.01681	eval-logloss:0.25960
[142]	train-logloss:0.01669	eval-logloss:0.25982
[143]	train-logloss:0.01656	eval-logloss:0.25992
[144]	train-logloss:0.01638	eval-logloss:0.26035
[145]	train-logloss:0.01623	eval-logloss:0.26055
[146]	train-logloss:0.01606	eval-logloss:0.26092
[147]	train-logloss:0.01589	eval-logloss:0.26137
[148]	train-logloss:0.01572	eval-logloss:0.25999
[149]	train-logloss:0.01556	eval-logloss:0.26028
[150]	train-logloss:0.01546	eval-logloss:0.26048
[151]	train-logloss:0.01531	eval-logloss:0.26142
[152]	train-logloss:0.01515	eval-logloss:0.26188
[153]	train-logloss:0.01501	eval-logloss:0.26227
[154]	train-logloss:0.01486	eval-logloss:0.26287
[155]	train-logloss:0.01476	eval-logloss:0.26299
[156]	train-logloss:0.01462	eval-logloss:0.26346
[157]	train-logloss:0.01448	eval-logloss:0.26379
[158]	train-logloss:0.01434	eval-logloss:0.26306
[159]	train-logloss:0.01424	eval-logloss:0.26237
[160]	train-logloss:0.01410	eval-logloss:0.26251
[161]	train-logloss:0.01401	eval-logloss:0.26265
[162]	train-logloss:0.01392	eval-logloss:0.26264
[163]	train-logloss:0.01380	eval-logloss:0.26250
[164]	train-logloss:0.01372	eval-logloss:0.26264
[165]	train-logloss:0.01359	eval-logloss:0.26255
[166]	train-logloss:0.01350	eval-logloss:0.26188
[167]	train-logloss:0.01342	eval-logloss:0.26203
[168]	train-logloss:0.01331	eval-logloss:0.26190
[169]	train-logloss:0.01319	eval-logloss:0.26184
[170]	train-logloss:0.01312	eval-logloss:0.26133
[171]	train-logloss:0.01304	eval-logloss:0.26148
[172]	train-logloss:0.01297	eval-logloss:0.26157
[173]	train-logloss:0.01285	eval-logloss:0.26253
[174]	train-logloss:0.01278	eval-logloss:0.26229
[175]	train-logloss:0.01267	eval-logloss:0.26086
[176]	train-logloss:0.01258	eval-logloss:0.26103

 

In this run, XGBoost stops early because the loss function fails to improve for the specified number of rounds (50) before all scheduled rounds complete.

The eval-logloss no longer decreases after round [126], which is why training ends at round [176].
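
As a small check (assuming the attributes XGBoost attaches to the trained Booster when early stopping triggers), the best round and its evaluation score can be read back directly:

# best_iteration: the round with the lowest eval-logloss (round [126] in the log above)
# best_score: the eval-logloss value at that round
print('best iteration:', xgb_model.best_iteration)
print('best score    :', xgb_model.best_score)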

 

 

Return prediction probabilities with predict() and convert them into predicted label values

xgb_model.predict(test-set DMatrix)

 

pred_probs = xgb_model.predict(dtest)
print('First 10 predict() results, shown as predicted probabilities')
print(np.round(pred_probs[:10], 3))

# If the predicted probability is greater than 0.5 predict 1, otherwise 0, and store the results in the list preds.
preds = [1 if x > 0.5 else 0 for x in pred_probs]
print('First 10 predicted labels:', preds[:10])

# First 10 predict() results, shown as predicted probabilities
# [0.845 0.008 0.68  0.081 0.975 0.999 0.998 0.998 0.996 0.001]
# First 10 predicted labels: [1, 0, 1, 0, 1, 1, 1, 1, 1, 0]

 

 

Evaluating the predictions with get_clf_eval()

from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import f1_score, roc_auc_score

def get_clf_eval(y_test, pred=None, pred_proba=None):
    confusion = confusion_matrix(y_test, pred)
    accuracy = accuracy_score(y_test, pred)
    precision = precision_score(y_test, pred)
    recall = recall_score(y_test, pred)
    f1 = f1_score(y_test, pred)
    # ROC-AUC added
    roc_auc = roc_auc_score(y_test, pred_proba)
    print('Confusion matrix')
    print(confusion)
    # ROC-AUC added to the printout
    print('Accuracy: {0:.4f}, Precision: {1:.4f}, Recall: {2:.4f},\
    F1: {3:.4f}, AUC: {4:.4f}'.format(accuracy, precision, recall, f1, roc_auc))

get_clf_eval(y_test, preds, pred_probs)

# Confusion matrix
# [[34  3]
#  [ 2 75]]
# Accuracy: 0.9561, Precision: 0.9615, Recall: 0.9740,    F1: 0.9677, AUC: 0.9937

 

 

Visualizing feature importance

import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(figsize=(10, 12))
plot_importance(xgb_model, ax=ax)
plt.savefig('p239_xgb_feature_importance.tif', format='tif', dpi=300, bbox_inches='tight')

 

 

 

 

 

 


6. Predicting Wisconsin Breast Cancer with XGBoost: using the Scikit-Learn wrapper XGBoost

 

     

Differences from the Python-native API

  • DMatrix is not used
  • Hyperparameters are passed directly as class (constructor) parameters rather than to train()

 

 

Importing the Scikit-Learn wrapper class, then training and predicting

verbose

  • When a function has a verbose argument, it controls whether detailed information generated during execution is written to standard output
  • Typically 0 prints nothing, 1 prints detailed output, and 2 prints only condensed information

 

 

# Import XGBClassifier, the scikit-learn wrapper class for XGBoost.
from xgboost import XGBClassifier

# eval_metric is passed to the XGBClassifier constructor to suppress a warning message; omitting it does not affect execution.
xgb_wrapper = XGBClassifier(n_estimators=400, learning_rate=0.05, max_depth=3, eval_metric='logloss')
xgb_wrapper.fit(X_train, y_train, verbose=True)
w_preds = xgb_wrapper.predict(X_test)
w_pred_proba = xgb_wrapper.predict_proba(X_test)[:, 1]

 

get_clf_eval(y_test , w_preds, w_pred_proba)

# Confusion matrix
# [[34  3]
#  [ 1 76]]
# Accuracy: 0.9649, Precision: 0.9620, Recall: 0.9870,    F1: 0.9744, AUC: 0.9954

 

 

Retrain / predict / evaluate with early stopping set to 50

XGBClassifier.fit(X_tr, y_tr, early_stopping_rounds, eval_metric, eval_set, verbose)

 

Parameters of the fit() function

  • early_stopping_rounds: the number of consecutive rounds without improvement in the evaluation metric after which training stops
  • eval_metric: the evaluation metric computed at each round
  • eval_set: a separate validation set used for evaluation
    • passed in the form [(X_tr, y_tr), (X_val, y_val)]

 

 

from xgboost import XGBClassifier

xgb_wrapper = XGBClassifier(n_estimators=400, learning_rate=0.05, max_depth=3)
evals = [(X_tr, y_tr), (X_val, y_val)]
xgb_wrapper.fit(X_tr, y_tr, early_stopping_rounds=50, eval_metric="logloss", 
                eval_set=evals, verbose=True)

ws50_preds = xgb_wrapper.predict(X_test)
ws50_pred_proba = xgb_wrapper.predict_proba(X_test)[:, 1]
[0]	validation_0-logloss:0.65016	validation_1-logloss:0.66183
[1]	validation_0-logloss:0.61131	validation_1-logloss:0.63609
[2]	validation_0-logloss:0.57563	validation_1-logloss:0.61144
[3]	validation_0-logloss:0.54310	validation_1-logloss:0.59204
[4]	validation_0-logloss:0.51323	validation_1-logloss:0.57329
[5]	validation_0-logloss:0.48447	validation_1-logloss:0.55037
[6]	validation_0-logloss:0.45796	validation_1-logloss:0.52929
[7]	validation_0-logloss:0.43436	validation_1-logloss:0.51534
[8]	validation_0-logloss:0.41150	validation_1-logloss:0.49718
[9]	validation_0-logloss:0.39027	validation_1-logloss:0.48154
[10]	validation_0-logloss:0.37128	validation_1-logloss:0.46990
[11]	validation_0-logloss:0.35254	validation_1-logloss:0.45474
[12]	validation_0-logloss:0.33528	validation_1-logloss:0.44229
[13]	validation_0-logloss:0.31893	validation_1-logloss:0.42961
[14]	validation_0-logloss:0.30439	validation_1-logloss:0.42065
[15]	validation_0-logloss:0.29000	validation_1-logloss:0.40958
[16]	validation_0-logloss:0.27651	validation_1-logloss:0.39887
[17]	validation_0-logloss:0.26389	validation_1-logloss:0.39050
[18]	validation_0-logloss:0.25210	validation_1-logloss:0.38254
[19]	validation_0-logloss:0.24123	validation_1-logloss:0.37393
[20]	validation_0-logloss:0.23076	validation_1-logloss:0.36789
[21]	validation_0-logloss:0.22091	validation_1-logloss:0.36017
[22]	validation_0-logloss:0.21155	validation_1-logloss:0.35421
[23]	validation_0-logloss:0.20263	validation_1-logloss:0.34683
[24]	validation_0-logloss:0.19434	validation_1-logloss:0.34111
[25]	validation_0-logloss:0.18637	validation_1-logloss:0.33634
[26]	validation_0-logloss:0.17875	validation_1-logloss:0.33082
[27]	validation_0-logloss:0.17167	validation_1-logloss:0.32675
[28]	validation_0-logloss:0.16481	validation_1-logloss:0.32099
[29]	validation_0-logloss:0.15835	validation_1-logloss:0.31671
[30]	validation_0-logloss:0.15225	validation_1-logloss:0.31277
[31]	validation_0-logloss:0.14650	validation_1-logloss:0.30882
[32]	validation_0-logloss:0.14102	validation_1-logloss:0.30437
[33]	validation_0-logloss:0.13590	validation_1-logloss:0.30103
[34]	validation_0-logloss:0.13109	validation_1-logloss:0.29794
[35]	validation_0-logloss:0.12647	validation_1-logloss:0.29499
[36]	validation_0-logloss:0.12197	validation_1-logloss:0.29295
[37]	validation_0-logloss:0.11784	validation_1-logloss:0.29043
[38]	validation_0-logloss:0.11379	validation_1-logloss:0.28927
[39]	validation_0-logloss:0.10994	validation_1-logloss:0.28578
[40]	validation_0-logloss:0.10638	validation_1-logloss:0.28364
[41]	validation_0-logloss:0.10302	validation_1-logloss:0.28183
[42]	validation_0-logloss:0.09963	validation_1-logloss:0.28005
[43]	validation_0-logloss:0.09649	validation_1-logloss:0.27972
[44]	validation_0-logloss:0.09359	validation_1-logloss:0.27744
[45]	validation_0-logloss:0.09080	validation_1-logloss:0.27542
[46]	validation_0-logloss:0.08807	validation_1-logloss:0.27504
[47]	validation_0-logloss:0.08541	validation_1-logloss:0.27458
[48]	validation_0-logloss:0.08299	validation_1-logloss:0.27348
[49]	validation_0-logloss:0.08035	validation_1-logloss:0.27247
[50]	validation_0-logloss:0.07786	validation_1-logloss:0.27163
[51]	validation_0-logloss:0.07550	validation_1-logloss:0.27094
[52]	validation_0-logloss:0.07344	validation_1-logloss:0.26967
[53]	validation_0-logloss:0.07147	validation_1-logloss:0.27008
[54]	validation_0-logloss:0.06964	validation_1-logloss:0.26890
[55]	validation_0-logloss:0.06766	validation_1-logloss:0.26854
[56]	validation_0-logloss:0.06592	validation_1-logloss:0.26900
[57]	validation_0-logloss:0.06433	validation_1-logloss:0.26790
[58]	validation_0-logloss:0.06259	validation_1-logloss:0.26663
[59]	validation_0-logloss:0.06107	validation_1-logloss:0.26743
[60]	validation_0-logloss:0.05957	validation_1-logloss:0.26610
[61]	validation_0-logloss:0.05817	validation_1-logloss:0.26644
[62]	validation_0-logloss:0.05691	validation_1-logloss:0.26673
[63]	validation_0-logloss:0.05550	validation_1-logloss:0.26550
[64]	validation_0-logloss:0.05422	validation_1-logloss:0.26443
[65]	validation_0-logloss:0.05311	validation_1-logloss:0.26500
[66]	validation_0-logloss:0.05207	validation_1-logloss:0.26591
[67]	validation_0-logloss:0.05093	validation_1-logloss:0.26501
[68]	validation_0-logloss:0.04976	validation_1-logloss:0.26435
[69]	validation_0-logloss:0.04872	validation_1-logloss:0.26360
[70]	validation_0-logloss:0.04776	validation_1-logloss:0.26319
[71]	validation_0-logloss:0.04680	validation_1-logloss:0.26255
[72]	validation_0-logloss:0.04580	validation_1-logloss:0.26204
[73]	validation_0-logloss:0.04484	validation_1-logloss:0.26254
[74]	validation_0-logloss:0.04388	validation_1-logloss:0.26289
[75]	validation_0-logloss:0.04309	validation_1-logloss:0.26249
[76]	validation_0-logloss:0.04224	validation_1-logloss:0.26217
[77]	validation_0-logloss:0.04133	validation_1-logloss:0.26166
[78]	validation_0-logloss:0.04050	validation_1-logloss:0.26179
[79]	validation_0-logloss:0.03967	validation_1-logloss:0.26103
[80]	validation_0-logloss:0.03877	validation_1-logloss:0.26094
[81]	validation_0-logloss:0.03806	validation_1-logloss:0.26148
[82]	validation_0-logloss:0.03740	validation_1-logloss:0.26054
[83]	validation_0-logloss:0.03676	validation_1-logloss:0.25967
[84]	validation_0-logloss:0.03605	validation_1-logloss:0.25905
[85]	validation_0-logloss:0.03545	validation_1-logloss:0.26007
[86]	validation_0-logloss:0.03488	validation_1-logloss:0.25984
[87]	validation_0-logloss:0.03425	validation_1-logloss:0.25933
[88]	validation_0-logloss:0.03361	validation_1-logloss:0.25932
[89]	validation_0-logloss:0.03311	validation_1-logloss:0.26002
[90]	validation_0-logloss:0.03260	validation_1-logloss:0.25936
[91]	validation_0-logloss:0.03202	validation_1-logloss:0.25886
[92]	validation_0-logloss:0.03152	validation_1-logloss:0.25918
[93]	validation_0-logloss:0.03107	validation_1-logloss:0.25865
[94]	validation_0-logloss:0.03049	validation_1-logloss:0.25951
[95]	validation_0-logloss:0.03007	validation_1-logloss:0.26091
[96]	validation_0-logloss:0.02963	validation_1-logloss:0.26014
[97]	validation_0-logloss:0.02913	validation_1-logloss:0.25974
[98]	validation_0-logloss:0.02866	validation_1-logloss:0.25937
[99]	validation_0-logloss:0.02829	validation_1-logloss:0.25893
[100]	validation_0-logloss:0.02789	validation_1-logloss:0.25928
[101]	validation_0-logloss:0.02751	validation_1-logloss:0.25955
[102]	validation_0-logloss:0.02714	validation_1-logloss:0.25901
[103]	validation_0-logloss:0.02668	validation_1-logloss:0.25991
[104]	validation_0-logloss:0.02634	validation_1-logloss:0.25950
[105]	validation_0-logloss:0.02594	validation_1-logloss:0.25924
[106]	validation_0-logloss:0.02556	validation_1-logloss:0.25901
[107]	validation_0-logloss:0.02522	validation_1-logloss:0.25738
[108]	validation_0-logloss:0.02492	validation_1-logloss:0.25702
[109]	validation_0-logloss:0.02453	validation_1-logloss:0.25789
[110]	validation_0-logloss:0.02418	validation_1-logloss:0.25770
[111]	validation_0-logloss:0.02384	validation_1-logloss:0.25842
[112]	validation_0-logloss:0.02356	validation_1-logloss:0.25810
[113]	validation_0-logloss:0.02322	validation_1-logloss:0.25848
[114]	validation_0-logloss:0.02290	validation_1-logloss:0.25833
[115]	validation_0-logloss:0.02260	validation_1-logloss:0.25820
[116]	validation_0-logloss:0.02229	validation_1-logloss:0.25905
[117]	validation_0-logloss:0.02204	validation_1-logloss:0.25878
[118]	validation_0-logloss:0.02176	validation_1-logloss:0.25728
[119]	validation_0-logloss:0.02149	validation_1-logloss:0.25722
[120]	validation_0-logloss:0.02119	validation_1-logloss:0.25764
[121]	validation_0-logloss:0.02095	validation_1-logloss:0.25761
[122]	validation_0-logloss:0.02067	validation_1-logloss:0.25832
[123]	validation_0-logloss:0.02045	validation_1-logloss:0.25808
[124]	validation_0-logloss:0.02023	validation_1-logloss:0.25855
[125]	validation_0-logloss:0.01998	validation_1-logloss:0.25714
[126]	validation_0-logloss:0.01973	validation_1-logloss:0.25587
[127]	validation_0-logloss:0.01946	validation_1-logloss:0.25640
[128]	validation_0-logloss:0.01927	validation_1-logloss:0.25685
[129]	validation_0-logloss:0.01908	validation_1-logloss:0.25665
[130]	validation_0-logloss:0.01886	validation_1-logloss:0.25712
[131]	validation_0-logloss:0.01863	validation_1-logloss:0.25609
[132]	validation_0-logloss:0.01839	validation_1-logloss:0.25649
[133]	validation_0-logloss:0.01816	validation_1-logloss:0.25789
[134]	validation_0-logloss:0.01802	validation_1-logloss:0.25811
[135]	validation_0-logloss:0.01785	validation_1-logloss:0.25794
[136]	validation_0-logloss:0.01763	validation_1-logloss:0.25876
[137]	validation_0-logloss:0.01748	validation_1-logloss:0.25884
[138]	validation_0-logloss:0.01732	validation_1-logloss:0.25867
[139]	validation_0-logloss:0.01719	validation_1-logloss:0.25876
[140]	validation_0-logloss:0.01696	validation_1-logloss:0.25987
[141]	validation_0-logloss:0.01681	validation_1-logloss:0.25960
[142]	validation_0-logloss:0.01669	validation_1-logloss:0.25982
[143]	validation_0-logloss:0.01656	validation_1-logloss:0.25992
[144]	validation_0-logloss:0.01638	validation_1-logloss:0.26035
[145]	validation_0-logloss:0.01623	validation_1-logloss:0.26055
[146]	validation_0-logloss:0.01606	validation_1-logloss:0.26092
[147]	validation_0-logloss:0.01589	validation_1-logloss:0.26137
[148]	validation_0-logloss:0.01572	validation_1-logloss:0.25999
[149]	validation_0-logloss:0.01557	validation_1-logloss:0.26028
[150]	validation_0-logloss:0.01546	validation_1-logloss:0.26048
[151]	validation_0-logloss:0.01531	validation_1-logloss:0.26142
[152]	validation_0-logloss:0.01515	validation_1-logloss:0.26188
[153]	validation_0-logloss:0.01501	validation_1-logloss:0.26227
[154]	validation_0-logloss:0.01486	validation_1-logloss:0.26287
[155]	validation_0-logloss:0.01476	validation_1-logloss:0.26299
[156]	validation_0-logloss:0.01461	validation_1-logloss:0.26346
[157]	validation_0-logloss:0.01448	validation_1-logloss:0.26379
[158]	validation_0-logloss:0.01434	validation_1-logloss:0.26306
[159]	validation_0-logloss:0.01424	validation_1-logloss:0.26237
[160]	validation_0-logloss:0.01410	validation_1-logloss:0.26251
[161]	validation_0-logloss:0.01401	validation_1-logloss:0.26265
[162]	validation_0-logloss:0.01392	validation_1-logloss:0.26264
[163]	validation_0-logloss:0.01380	validation_1-logloss:0.26250
[164]	validation_0-logloss:0.01372	validation_1-logloss:0.26264
[165]	validation_0-logloss:0.01359	validation_1-logloss:0.26255
[166]	validation_0-logloss:0.01350	validation_1-logloss:0.26188
[167]	validation_0-logloss:0.01342	validation_1-logloss:0.26203
[168]	validation_0-logloss:0.01331	validation_1-logloss:0.26190
[169]	validation_0-logloss:0.01319	validation_1-logloss:0.26184
[170]	validation_0-logloss:0.01312	validation_1-logloss:0.26133
[171]	validation_0-logloss:0.01304	validation_1-logloss:0.26148
[172]	validation_0-logloss:0.01297	validation_1-logloss:0.26157
[173]	validation_0-logloss:0.01285	validation_1-logloss:0.26253
[174]	validation_0-logloss:0.01278	validation_1-logloss:0.26229
[175]	validation_0-logloss:0.01267	validation_1-logloss:0.26086
get_clf_eval(y_test , ws50_preds, ws50_pred_proba)

# Confusion matrix
# [[34  3]
#  [ 2 75]]
# Accuracy: 0.9561, Precision: 0.9615, Recall: 0.9740,    F1: 0.9677, AUC: 0.9933

 

 

Retrain / predict / evaluate with early stopping set to 10

 

Setting early stopping too low can hurt performance further

When to use it: when hyperparameters need to be tuned quickly -> then tune in more depth afterward

# Retrain with early_stopping_rounds set to 10.
xgb_wrapper.fit(X_tr, y_tr, early_stopping_rounds=10, 
                eval_metric="logloss", eval_set=evals, verbose=True)

ws10_preds = xgb_wrapper.predict(X_test)
ws10_pred_proba = xgb_wrapper.predict_proba(X_test)[:, 1]
get_clf_eval(y_test, ws10_preds, ws10_pred_proba)

from xgboost import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(figsize=(10, 12))
# The scikit-learn wrapper object can be passed directly to plot_importance as well.
plot_importance(xgb_wrapper, ax=ax)

 

 

 

xgb_wrapper.evals_result( )

  • evals_result() gives access to the training-data loss and validation-data loss recorded while fit() ran
  • results = xgb_wrapper.evals_result( )
    • results['validation_0']['logloss']: loss values on the training data
    • results['validation_1']['logloss']: loss values on the validation data

 

import matplotlib.pyplot as plt

results = xgb_wrapper.evals_result()

plt.plot(results["validation_0"]["logloss"], label="train")
plt.plot(results["validation_1"]["logloss"], label="validation")
plt.legend()
plt.show()

 

 

 


7. LightGBM

 

LightGBM is a gradient boosting framework that uses tree-based learning algorithms.

 

 

 

Advantages over XGBoost

  • Faster training and prediction time
  • Lower memory usage
  • Automatic handling of categorical features with optimal splits
    • (Categorical features are converted optimally and node splits are made accordingly, without needing One-Hot Encoding)

 

 

 

LightGBM tree-splitting strategy

 

Existing GBM-family algorithms (including XGBoost) use level-wise (balanced) tree splitting, which minimizes depth.

LightGBM uses leaf-wise tree splitting:

- By continuing to split leaf nodes in a particular direction, it can reduce prediction error further

- This allows faster execution and improved predictions

- Speed is its biggest advantage.

 

- As data volumes keep growing, getting results quickly takes more and more time; LightGBM can handle large datasets and uses little memory when running.

 

 

Implementing LightGBM in Python

 

 

 

LightGBM HyperParameter

 

 

 

 

The LightGBM Scikit-Learn wrapper uses the XGBoost Scikit-Learn wrapper's hyperparameter name when an equivalent exists, and otherwise uses the Python-wrapper LightGBM hyperparameter name.

 

 

  • num_leaves: the maximum number of leaf nodes
  • boost_from_average
    • When the label distribution is extremely imbalanced, leaving this True can severely degrade recall and ROC-AUC; in that case it is better to set boost_from_average=False
    • boost_from_average defaults to True in LightGBM 2.1.0 and later

 

 

Basic tuning approach: reducing model complexity

 

The basic tuning approach is to reduce model complexity by adjusting num_leaves (the maximum number of leaf nodes) together with min_child_samples (min_data_in_leaf, the minimum number of samples required for a leaf node) and max_depth.

 

  • Increasing num_leaves raises accuracy, but the tree becomes deeper and more prone to overfitting
  • Setting min_child_samples (min_data_in_leaf) higher keeps the tree from growing too deep
  • max_depth limits the depth explicitly; it is used together with the two parameters above to reduce overfitting

In addition, lowering learning_rate while increasing n_estimators is a basic boosting tuning approach (a sketch combining these follows below).
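
A minimal sketch (values are illustrative starting points, not tuned results) combining these complexity-reducing parameters in the scikit-learn wrapper:

from lightgbm import LGBMClassifier

lgbm_clf = LGBMClassifier(
    n_estimators=1000,        # more boosting rounds ...
    learning_rate=0.02,       # ... combined with a smaller learning rate
    num_leaves=32,            # cap the number of leaves (controls leaf-wise growth)
    min_child_samples=60,     # require more samples per leaf, which limits depth
    max_depth=8,              # explicit depth limit, used together with the two parameters above
    subsample=0.8,            # row sampling per tree
    colsample_bytree=0.8,     # feature sampling per tree
)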

 

 

 

Overview of LightGBM hyperparameter tuning

 

 

 

Scikit Learn Wrapper LightGBM HyperParameter

 

 

It follows the Scikit-Learn XGBoost hyperparameter names first, and if a name does not exist there, it uses the Python-wrapper LightGBM hyperparameter name.

 

 

 

 


8. Predicting Wisconsin Breast Cancer with LightGBM

 

 

Applying LightGBM - Wisconsin Breast Cancer Prediction

# Import LGBMClassifier from lightgbm, the LightGBM Python package.
from lightgbm import LGBMClassifier

import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

dataset = load_breast_cancer()

cancer_df = pd.DataFrame(data=dataset.data, columns=dataset.feature_names)
cancer_df['target']= dataset.target

cancer_df.head()

 

X_test, y_test: feature dataset and labels for the test data

X_train, y_train: feature dataset and labels for the training data

X_tr, y_tr: feature dataset and labels for the portion of X_train, y_train kept for training

X_val, y_val: feature dataset and labels for the portion of X_train, y_train split off for validation

 

X_features = cancer_df.iloc[:, :-1]
y_label = cancer_df.iloc[:, -1]

# Split the whole dataset: 80% for training, 20% for testing.
X_train, X_test, y_train, y_test = train_test_split(X_features, y_label,
                                         test_size=0.2, random_state=156)

# Split X_train, y_train again: 90% for training and 10% for validation.
X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train,
                                         test_size=0.1, random_state=156)

# As with XGBoost earlier, set n_estimators to 400.
lgbm_wrapper = LGBMClassifier(n_estimators=400, learning_rate=0.05)

# LightGBM supports early stopping in the same way as XGBoost.
evals = [(X_tr, y_tr), (X_val, y_val)]
lgbm_wrapper.fit(X_tr, y_tr, early_stopping_rounds=50, eval_metric="logloss", 
                 eval_set=evals, verbose=True)
preds = lgbm_wrapper.predict(X_test)
pred_proba = lgbm_wrapper.predict_proba(X_test)[:, 1]

 

[Parameters of the fit() function]

  • early_stopping_rounds: the number of consecutive rounds without improvement in the evaluation metric after which training stops
  • eval_metric: the evaluation metric computed at each round
  • eval_set: a separate validation set used for evaluation
    • [(X_tr, y_tr), (X_val, y_val)] = [(training set), (validation set)]

 

Early Stopping

The evaluation metric stops improving after round [61], so training stops at round [111] (since early_stopping_rounds=50).

[1]	training's binary_logloss: 0.625671	valid_1's binary_logloss: 0.628248
[2]	training's binary_logloss: 0.588173	valid_1's binary_logloss: 0.601106
[3]	training's binary_logloss: 0.554518	valid_1's binary_logloss: 0.577587
[4]	training's binary_logloss: 0.523972	valid_1's binary_logloss: 0.556324
[5]	training's binary_logloss: 0.49615	valid_1's binary_logloss: 0.537407
[6]	training's binary_logloss: 0.470108	valid_1's binary_logloss: 0.519401
[7]	training's binary_logloss: 0.446647	valid_1's binary_logloss: 0.502637
[8]	training's binary_logloss: 0.425055	valid_1's binary_logloss: 0.488311
[9]	training's binary_logloss: 0.405125	valid_1's binary_logloss: 0.474664
[10]	training's binary_logloss: 0.386526	valid_1's binary_logloss: 0.461267
[11]	training's binary_logloss: 0.367027	valid_1's binary_logloss: 0.444274
[12]	training's binary_logloss: 0.350713	valid_1's binary_logloss: 0.432755
[13]	training's binary_logloss: 0.334601	valid_1's binary_logloss: 0.421371
[14]	training's binary_logloss: 0.319854	valid_1's binary_logloss: 0.411418
[15]	training's binary_logloss: 0.306374	valid_1's binary_logloss: 0.402989
[16]	training's binary_logloss: 0.293116	valid_1's binary_logloss: 0.393973
[17]	training's binary_logloss: 0.280812	valid_1's binary_logloss: 0.384801
[18]	training's binary_logloss: 0.268352	valid_1's binary_logloss: 0.376191
[19]	training's binary_logloss: 0.256942	valid_1's binary_logloss: 0.368378
[20]	training's binary_logloss: 0.246443	valid_1's binary_logloss: 0.362062
[21]	training's binary_logloss: 0.236874	valid_1's binary_logloss: 0.355162
[22]	training's binary_logloss: 0.227501	valid_1's binary_logloss: 0.348933
[23]	training's binary_logloss: 0.218988	valid_1's binary_logloss: 0.342819
[24]	training's binary_logloss: 0.210621	valid_1's binary_logloss: 0.337386
[25]	training's binary_logloss: 0.202076	valid_1's binary_logloss: 0.331523
[26]	training's binary_logloss: 0.194199	valid_1's binary_logloss: 0.326349
[27]	training's binary_logloss: 0.187107	valid_1's binary_logloss: 0.322785
[28]	training's binary_logloss: 0.180535	valid_1's binary_logloss: 0.317877
[29]	training's binary_logloss: 0.173834	valid_1's binary_logloss: 0.313928
[30]	training's binary_logloss: 0.167198	valid_1's binary_logloss: 0.310105
[31]	training's binary_logloss: 0.161229	valid_1's binary_logloss: 0.307107
[32]	training's binary_logloss: 0.155494	valid_1's binary_logloss: 0.303837
[33]	training's binary_logloss: 0.149125	valid_1's binary_logloss: 0.300315
[34]	training's binary_logloss: 0.144045	valid_1's binary_logloss: 0.297816
[35]	training's binary_logloss: 0.139341	valid_1's binary_logloss: 0.295387
[36]	training's binary_logloss: 0.134625	valid_1's binary_logloss: 0.293063
[37]	training's binary_logloss: 0.129167	valid_1's binary_logloss: 0.289127
[38]	training's binary_logloss: 0.12472	valid_1's binary_logloss: 0.288697
[39]	training's binary_logloss: 0.11974	valid_1's binary_logloss: 0.28576
[40]	training's binary_logloss: 0.115054	valid_1's binary_logloss: 0.282853
[41]	training's binary_logloss: 0.110662	valid_1's binary_logloss: 0.279441
[42]	training's binary_logloss: 0.106358	valid_1's binary_logloss: 0.28113
[43]	training's binary_logloss: 0.102324	valid_1's binary_logloss: 0.279139
[44]	training's binary_logloss: 0.0985699	valid_1's binary_logloss: 0.276465
[45]	training's binary_logloss: 0.094858	valid_1's binary_logloss: 0.275946
[46]	training's binary_logloss: 0.0912486	valid_1's binary_logloss: 0.272819
[47]	training's binary_logloss: 0.0883115	valid_1's binary_logloss: 0.272306
[48]	training's binary_logloss: 0.0849963	valid_1's binary_logloss: 0.270452
[49]	training's binary_logloss: 0.0821742	valid_1's binary_logloss: 0.268671
[50]	training's binary_logloss: 0.0789991	valid_1's binary_logloss: 0.267587
[51]	training's binary_logloss: 0.0761072	valid_1's binary_logloss: 0.26626
[52]	training's binary_logloss: 0.0732567	valid_1's binary_logloss: 0.265542
[53]	training's binary_logloss: 0.0706388	valid_1's binary_logloss: 0.264547
[54]	training's binary_logloss: 0.0683911	valid_1's binary_logloss: 0.26502
[55]	training's binary_logloss: 0.0659347	valid_1's binary_logloss: 0.264388
[56]	training's binary_logloss: 0.0636873	valid_1's binary_logloss: 0.263128
[57]	training's binary_logloss: 0.0613354	valid_1's binary_logloss: 0.26231
[58]	training's binary_logloss: 0.0591944	valid_1's binary_logloss: 0.262011
[59]	training's binary_logloss: 0.057033	valid_1's binary_logloss: 0.261454
[60]	training's binary_logloss: 0.0550801	valid_1's binary_logloss: 0.260746
[61]	training's binary_logloss: 0.0532381	valid_1's binary_logloss: 0.260236
[62]	training's binary_logloss: 0.0514074	valid_1's binary_logloss: 0.261586
[63]	training's binary_logloss: 0.0494837	valid_1's binary_logloss: 0.261797
[64]	training's binary_logloss: 0.0477826	valid_1's binary_logloss: 0.262533
[65]	training's binary_logloss: 0.0460364	valid_1's binary_logloss: 0.263305
[66]	training's binary_logloss: 0.0444552	valid_1's binary_logloss: 0.264072
[67]	training's binary_logloss: 0.0427638	valid_1's binary_logloss: 0.266223
[68]	training's binary_logloss: 0.0412449	valid_1's binary_logloss: 0.266817
[69]	training's binary_logloss: 0.0398589	valid_1's binary_logloss: 0.267819
[70]	training's binary_logloss: 0.0383095	valid_1's binary_logloss: 0.267484
[71]	training's binary_logloss: 0.0368803	valid_1's binary_logloss: 0.270233
[72]	training's binary_logloss: 0.0355637	valid_1's binary_logloss: 0.268442
[73]	training's binary_logloss: 0.0341747	valid_1's binary_logloss: 0.26895
[74]	training's binary_logloss: 0.0328302	valid_1's binary_logloss: 0.266958
[75]	training's binary_logloss: 0.0317853	valid_1's binary_logloss: 0.268091
[76]	training's binary_logloss: 0.0305626	valid_1's binary_logloss: 0.266419
[77]	training's binary_logloss: 0.0295001	valid_1's binary_logloss: 0.268588
[78]	training's binary_logloss: 0.0284699	valid_1's binary_logloss: 0.270964
[79]	training's binary_logloss: 0.0273953	valid_1's binary_logloss: 0.270293
[80]	training's binary_logloss: 0.0264668	valid_1's binary_logloss: 0.270523
[81]	training's binary_logloss: 0.0254636	valid_1's binary_logloss: 0.270683
[82]	training's binary_logloss: 0.0245911	valid_1's binary_logloss: 0.273187
[83]	training's binary_logloss: 0.0236486	valid_1's binary_logloss: 0.275994
[84]	training's binary_logloss: 0.0228047	valid_1's binary_logloss: 0.274053
[85]	training's binary_logloss: 0.0221693	valid_1's binary_logloss: 0.273211
[86]	training's binary_logloss: 0.0213043	valid_1's binary_logloss: 0.272626
[87]	training's binary_logloss: 0.0203934	valid_1's binary_logloss: 0.27534
[88]	training's binary_logloss: 0.0195552	valid_1's binary_logloss: 0.276228
[89]	training's binary_logloss: 0.0188623	valid_1's binary_logloss: 0.27525
[90]	training's binary_logloss: 0.0183664	valid_1's binary_logloss: 0.276485
[91]	training's binary_logloss: 0.0176788	valid_1's binary_logloss: 0.277052
[92]	training's binary_logloss: 0.0170059	valid_1's binary_logloss: 0.277686
[93]	training's binary_logloss: 0.0164317	valid_1's binary_logloss: 0.275332
[94]	training's binary_logloss: 0.015878	valid_1's binary_logloss: 0.276236
[95]	training's binary_logloss: 0.0152959	valid_1's binary_logloss: 0.274538
[96]	training's binary_logloss: 0.0147216	valid_1's binary_logloss: 0.275244
[97]	training's binary_logloss: 0.0141758	valid_1's binary_logloss: 0.275829
[98]	training's binary_logloss: 0.0136551	valid_1's binary_logloss: 0.276654
[99]	training's binary_logloss: 0.0131585	valid_1's binary_logloss: 0.277859
[100]	training's binary_logloss: 0.0126961	valid_1's binary_logloss: 0.279265
[101]	training's binary_logloss: 0.0122421	valid_1's binary_logloss: 0.276695
[102]	training's binary_logloss: 0.0118067	valid_1's binary_logloss: 0.278488
[103]	training's binary_logloss: 0.0113994	valid_1's binary_logloss: 0.278932
[104]	training's binary_logloss: 0.0109799	valid_1's binary_logloss: 0.280997
[105]	training's binary_logloss: 0.0105953	valid_1's binary_logloss: 0.281454
[106]	training's binary_logloss: 0.0102381	valid_1's binary_logloss: 0.282058
[107]	training's binary_logloss: 0.00986714	valid_1's binary_logloss: 0.279275
[108]	training's binary_logloss: 0.00950998	valid_1's binary_logloss: 0.281427
[109]	training's binary_logloss: 0.00915965	valid_1's binary_logloss: 0.280752
[110]	training's binary_logloss: 0.00882581	valid_1's binary_logloss: 0.282152
[111]	training's binary_logloss: 0.00850714	valid_1's binary_logloss: 0.280894

 

 

When the dataset is small, LightGBM can perform slightly worse than XGBoost.

get_clf_eval(y_test, preds, pred_proba)

# Confusion matrix
# [[34  3]
#  [ 2 75]]
# Accuracy: 0.9561, Precision: 0.9615, Recall: 0.9740,    F1: 0.9677, AUC: 0.9877

 

# Visualize feature importance with plot_importance().
from lightgbm import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline

fig, ax = plt.subplots(figsize=(10, 12))
plot_importance(lgbm_wrapper, ax=ax)
plt.savefig('lightgbm_feature_importance.tif', format='tif', dpi=300, bbox_inches='tight')
plt.show()

 

 

 

Using plot_metric and plot_tree

 

from lightgbm import plot_metric, plot_tree

plot_metric(lgbm_wrapper, figsize=(6,4))
plot_tree(lgbm_wrapper, figsize=(6,4))

 

# When training is performed without early stopping
evals = [(X_tr, y_tr), (X_val, y_val)]
lgbm_wrapper.fit(X_tr, y_tr, eval_metric='logloss', eval_set=evals, verbose=True)
preds = lgbm_wrapper.predict(X_test)
pred_proba = lgbm_wrapper.predict_proba(X_test)[:,1]

plot_metric(lgbm_wrapper, figsize=(6,4))
plot_tree(lgbm_wrapper, figsize=(6,4))

 

 
