1. Understanding Boosting
A Boosting algorithm trains multiple weak learners sequentially, predicting at each step and assigning extra weight to the data (or trees) that were mispredicted, so that each round of learning corrects the errors of the previous one.
Representative Boosting implementations are AdaBoost (Adaptive Boosting) and gradient boosting (GBM, XGBoost, LightGBM).
Voting = combining classifiers built from different algorithms
Bagging = every classifier is based on the same type of algorithm, but each trains individually on a different data sample; in the end all classifiers vote to decide the prediction
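The three ensemble styles can be contrasted directly in scikit-learn; a minimal sketch with toy data and untuned settings (added here for illustration):
from sklearn.datasets import make_classification
from sklearn.ensemble import VotingClassifier, BaggingClassifier, AdaBoostClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, random_state=0)

# Voting: different algorithms combined into one classifier
voting = VotingClassifier(estimators=[('lr', LogisticRegression(max_iter=1000)),
                                      ('dt', DecisionTreeClassifier())])
# Bagging: one algorithm (a decision tree by default), different bootstrap samples
bagging = BaggingClassifier(n_estimators=50, random_state=0)
# Boosting: one algorithm trained sequentially, reweighting what it got wrong
boosting = AdaBoostClassifier(n_estimators=50, random_state=0)

for clf in (voting, bagging, boosting):
    print(type(clf).__name__, clf.fit(X, y).score(X, y))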
Types of Boosting
- AdaBoost
- Gradient Boosting
- GBM (Gradient Boosting Machine)
- XGBoost
- LightGBM
Problems with GBM
- Because the weak learners train sequentially, GBM takes a long time to run
- Random Forest, by contrast, trains in parallel and therefore takes less time
- Overfitting is likely to occur
2. AdaBoost
AdaBoost's learning / prediction process (a minimal sketch follows below)
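The lecture illustrates this process with a figure; as a stand-in, a hedged scikit-learn sketch using AdaBoostClassifier with depth-1 trees (stumps), the classic weak learner. Data and parameter values are illustrative only:
from sklearn.datasets import make_classification
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, random_state=0)
# `estimator` in scikit-learn >= 1.2; older versions call it `base_estimator`.
ada = AdaBoostClassifier(estimator=DecisionTreeClassifier(max_depth=1),  # weak learner (stump)
                         n_estimators=100,   # number of sequential weak learners
                         learning_rate=0.5,  # shrinks each learner's contribution
                         random_state=0)
ada.fit(X, y)
print('train accuracy:', ada.score(X, y))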
3. GBM (Gradient Boosting Machine)
GBM (Gradient Boosting Machine) is similar to AdaBoost,
but it obtains its optimized result by updating the weights with gradient descent (Gradient Descent).
Error = actual value (target $y$) - predicted value ($y_{pred}$).
If we denote the actual classification result by $y$, the features by $X_{1}, X_{2}, \ldots, X_{n}$, and the prediction function based on these features by $F(x)$, then the error function is $h(x) = y - F(x)$.
Gradient descent repeatedly updates the weights in the direction that minimizes this error function $h(x) = y - F(x)$.
Gradient descent derives the weight updates through repeated iterations so as to minimize the error, and is one of the most important techniques in machine learning.
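To make the idea concrete, a toy sketch (an illustration, not the lecture's code): for squared-error loss the residual $y - F(x)$ is exactly the negative gradient, so each new tree is fit to the current residuals and $F(x)$ is nudged toward $y$ step by step:
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = rng.uniform(0, 6, size=(200, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=200)

learning_rate = 0.1
F = np.full_like(y, y.mean())  # initial prediction F_0(x)
for _ in range(100):
    residual = y - F                                    # h(x) = y - F(x)
    tree = DecisionTreeRegressor(max_depth=2).fit(X, residual)
    F += learning_rate * tree.predict(X)                # step F(x) toward y

print('MSE after boosting:', np.mean((y - F) ** 2))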
Key Scikit-Learn GBM hyperparameters and tuning
Scikit-Learn provides the GradientBoostingClassifier class for GBM classification.
- loss: the loss function used in gradient descent. Unless there is a particular reason, keep the default 'deviance' (renamed to 'log_loss' in recent scikit-learn versions).
- learning_rate: the learning rate GBM applies at each step of training
  - The coefficient applied as the weak learners sequentially correct the error values
  - A value between 0 and 1 can be set; the default is 0.1
  - Smaller values tend to improve performance but take longer
  - If the value is too small:
    - Each update is small, so the model is more likely to find the minimum error and reach higher predictive performance
    - But the weak learners require sequential iterations, so training takes a long time; set too small, the minimum error may still not be found even after all weak learner iterations complete
  - If the value is too large:
    - The model may overshoot and miss the minimum error, so predictive performance is more likely to drop
    - On the other hand, training runs fast
- n_estimators: the number of weak learners
  - Because the weak learners correct errors sequentially, more of them can improve predictive performance up to a point
  - But more learners mean longer training time
  - The default is 100
- subsample: the fraction of the training data each weak learner samples
  - The default is 1, meaning each learner trains on the entire training set
  - 0.5 means each learner trains on half of the training data
  - If overfitting is a concern, set subsample to a value below 1
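A minimal sketch of how these parameters are passed (values are illustrative starting points, not tuned results; loss is left at its default):
from sklearn.ensemble import GradientBoostingClassifier

gb_tuned = GradientBoostingClassifier(learning_rate=0.05,  # smaller: better fit possible, slower
                                      n_estimators=500,    # more weak learners, longer training
                                      subsample=0.8,       # < 1 samples part of the data per learner
                                      random_state=0)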
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score
import time
import warnings
warnings.filterwarnings("ignore")
# get_human_dataset() is the course's helper (defined elsewhere in these notes)
# that loads the human activity recognition dataset and returns the train/test splits.
X_train, X_test, y_train, y_test = get_human_dataset()
# Record the start time to measure how long GBM takes.
start_time = time.time()
gb_clf = GradientBoostingClassifier(random_state=0)
gb_clf.fit(X_train, y_train)
gb_pred = gb_clf.predict(X_test)
gb_accuracy = accuracy_score(y_test, gb_pred)
print("GBM accuracy: {0:.4f}".format(gb_accuracy))
print("GBM run time: {0:.1f} seconds".format(time.time() - start_time))
# GBM accuracy: 0.9393
# GBM run time: 963.5 seconds
### The following is an example, not covered in the lecture, of tuning GBM hyperparameters with GridSearchCV.
### After scikit-learn was upgraded to 1.x, GBM training speed actually degraded noticeably.
### The code below takes a long time to run, so treat it as reference only.
from sklearn.model_selection import GridSearchCV
params = {
    'n_estimators': [100, 500],
    'learning_rate': [0.05, 0.1]
}
grid_cv = GridSearchCV(gb_clf, param_grid=params, cv=2, verbose=1)
grid_cv.fit(X_train, y_train)
print('Best hyperparameters:\n', grid_cv.best_params_)
print('Best CV accuracy: {0:.4f}'.format(grid_cv.best_score_))
# Predict with the best estimator found by GridSearchCV.
gb_pred = grid_cv.best_estimator_.predict(X_test)
gb_accuracy = accuracy_score(y_test, gb_pred)
print('GBM accuracy: {0:.4f}'.format(gb_accuracy))
Why GridSearchCV is a poor fit for XGBoost and LightGBM
- Gradient-boosting-based algorithms have many hyperparameters to tune, each with a wide range, so the number of possible combinations is huge
- With that many combinations, if the data is also large, GridSearchCV hyperparameter tuning demands an extremely long time
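One common workaround, added here for illustration rather than taken from the lecture, is RandomizedSearchCV, which samples a fixed number of combinations instead of exhaustively trying them all, capping the tuning cost:
from scipy.stats import randint, uniform
from sklearn.model_selection import RandomizedSearchCV

param_dist = {
    'n_estimators': randint(100, 1000),
    'learning_rate': uniform(0.01, 0.2),  # samples from [0.01, 0.21)
}
rand_cv = RandomizedSearchCV(gb_clf, param_distributions=param_dist,
                             n_iter=10, cv=2, random_state=0, verbose=1)
# rand_cv.fit(X_train, y_train)  # same interface as GridSearchCV, bounded run time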
4. XGBoost (eXtreme Gradient Boosting)
XGBoost emerged as an evolved form of GBM that addresses the problems of standard GBM.
Key advantages of XGBoost
- Excellent predictive performance
- Faster run time than GBM (LightGBM and some other ensemble methods are faster still)
- CPU parallelism and GPU support
- Various performance-enhancing features
  - Built-in regularization (prevents overfitting)
  - Tree pruning (re-validating and trimming a tree's nodes)
- Various convenience features
  - Early stopping
  - Built-in cross validation (a brief sketch follows this list)
  - Native handling of missing values
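For instance, the built-in cross validation is exposed through the native xgb.cv() API; a hedged sketch on the breast cancer data (parameter values illustrative, not tuned):
import xgboost as xgb
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
dtrain_cv = xgb.DMatrix(data=data.data, label=data.target)
params_cv = {'max_depth': 3, 'eta': 0.1, 'objective': 'binary:logistic'}
cv_results = xgb.cv(params_cv, dtrain_cv, num_boost_round=200, nfold=3,
                    metrics='logloss', early_stopping_rounds=20, seed=0)
print(cv_results.tail(3))  # per-round train/test logloss mean and std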
Python implementations of XGBoost
[API] Python wrapper XGB vs. Scikit-Learn wrapper XGB
They take hyperparameters at different points:
- The Python wrapper receives them in train()
- The Scikit-Learn wrapper receives them at object creation and in fit()
Their training APIs return differently:
- The Python wrapper returns the trained model object, which must be captured
- With the Scikit-Learn wrapper, calling the fit() method alone is enough
Their prediction APIs return different values:
- The Python wrapper's predict() returns the predicted probability for each label
- The Scikit-Learn wrapper's predict() returns the predicted label values
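A side-by-side sketch of the three differences above, on a self-contained toy split (illustrative values; variable names are chosen here so as not to clash with the section 5 code):
import xgboost as xgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

data = load_breast_cancer()
X_tr_, X_te_, y_tr_, y_te_ = train_test_split(data.data, data.target, random_state=0)

# Python wrapper: hyperparameters go into train(), which returns the trained model.
booster = xgb.train(params={'max_depth': 3, 'objective': 'binary:logistic'},
                    dtrain=xgb.DMatrix(X_tr_, label=y_tr_),
                    num_boost_round=100)
probs = booster.predict(xgb.DMatrix(X_te_))  # per-sample probabilities, not labels

# Scikit-Learn wrapper: hyperparameters at construction; fit() trains in place.
clf = XGBClassifier(max_depth=3, n_estimators=100).fit(X_tr_, y_tr_)
labels = clf.predict(X_te_)                  # predicted label values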
[Hyperparameters] Python wrapper XGB vs. Scikit-Learn wrapper XGB
When overfitting is severe:
Regularization: training that keeps chasing lower error is prone to overfitting, so a regularization penalty is imposed to keep the model from focusing too narrowly on the error (a sketch of the main knobs follows).
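A hedged sketch of XGBoost's main regularization knobs, using the Scikit-Learn wrapper's parameter names (values illustrative, not tuned):
from xgboost import XGBClassifier

xgb_reg = XGBClassifier(n_estimators=400,
                        reg_alpha=0.1,       # L1 penalty on leaf weights (alpha)
                        reg_lambda=1.0,      # L2 penalty on leaf weights (lambda)
                        gamma=0.5,           # min loss reduction required to split a node
                        min_child_weight=3,  # min sum of instance weight in a child
                        max_depth=3)         # shallow trees also act as regularization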
XGBoost's Early Stopping
If the loss function has not improved for a specified number of rounds, XGBoost can stop without completing the full number of scheduled rounds.
- This shortens training time, and is especially useful during hyperparameter tuning
- If the allowed rounds are cut too short, training may stop before predictive performance has been optimized, so use with care
- Key parameters for configuring early stopping:
  - early_stopping_rounds: the maximum number of rounds to continue without improvement in the evaluation metric
    - e.g., with early_stopping_rounds=10 out of 1000 scheduled rounds: if the metric stops improving after round 100 and still has not improved by round 110, training stops
  - eval_metric: the cost evaluation metric used at each round
    - e.g., eval_metric = 'logloss', 'auc', etc.
  - eval_set: a separate validation set (Validation Set) used for evaluation
    - Typically the cost reduction is evaluated repeatedly on this validation set
5. Predicting Wisconsin breast cancer with XGBoost: the native Python XGBoost API
X_test, y_test: feature dataset and labels for testing
X_train, y_train: feature dataset and labels for training
X_tr, y_tr: the training portion obtained by splitting X_train, y_train again
X_val, y_val: the validation portion obtained by splitting X_train, y_train again
Dataset Loading
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
# Load the xgboost package
import xgboost as xgb
from xgboost import plot_importance
import warnings
warnings.filterwarnings('ignore')
dataset = load_breast_cancer()
features= dataset.data
labels = dataset.target
cancer_df = pd.DataFrame(data=features, columns=dataset.feature_names)
cancer_df['target']= labels
cancer_df.head(3)
print(dataset.target_names)
print(cancer_df['target'].value_counts())
# ['malignant' 'benign']
# 1 357
# 0 212
# Name: target, dtype: int64
Split the train data again into train and validation data for use with early stopping
# Extract the feature DataFrame and the label Series from cancer_df.
# The last column is the label, so slice the feature DataFrame with :-1, from the first column up to the second-to-last.
X_features = cancer_df.iloc[:, :-1]
y_label = cancer_df.iloc[:, -1]
# Use 80% of the data for training and hold out 20% for testing
X_train, X_test, y_train, y_test=train_test_split(X_features, y_label, test_size=0.2, random_state=156 )
# Split the X_train, y_train made above again: 90% for training, 10% for validation
X_tr, X_val, y_tr, y_val= train_test_split(X_train, y_train, test_size=0.1, random_state=156 )
print(X_train.shape , X_test.shape)
print(X_tr.shape, X_val.shape)
# (455, 30) (114, 30)
# (409, 30) (46, 30)
Convert the train and test datasets to DMatrix
- A DMatrix can be created from a numpy array or a DataFrame
- xgb.DMatrix(data=feature_dataset, label=labels)
- Create the train, validation, and test DMatrix objects
# If an old XGBoost version cannot build a DMatrix from a DataFrame, convert to numpy with X_train.values.
# Create the train, validation, and test DMatrix objects.
dtr = xgb.DMatrix(data=X_tr, label=y_tr)
dval = xgb.DMatrix(data=X_val, label=y_val)
dtest = xgb.DMatrix(data=X_test , label=y_test)
Setting the hyperparameters
- eta: learning rate
- objective: loss function
- eval_metric: the cost evaluation metric used at each round
params = { 'max_depth':3,
'eta': 0.05,
'objective':'binary:logistic',
'eval_metric':'logloss'
}
num_rounds = 400
Pass the given hyperparameters and the early stopping parameter to the train() function and train
- Unlike Scikit-Learn estimator classes, native Python XGBoost takes them as parameters of the train() function
# Label the training dataset 'train' and the evaluation dataset 'eval'.
eval_list = [(dtr, 'train'), (dval, 'eval')]  # listing only eval_list = [(dval, 'eval')] also works.
# Pass the hyperparameters and the early stopping parameter to the train() function.
xgb_model = xgb.train(params=params, dtrain=dtr, num_boost_round=num_rounds,
                      early_stopping_rounds=50, evals=eval_list)
xgb_model = xgb.train(params, dtrain, num_boost_round, early_stopping_rounds, evals)
- params: the hyperparameters, passed as a dict
- dtrain: the training DMatrix, built with DMatrix(data, label)
- num_boost_round: how many boosting rounds to train
- early_stopping_rounds: the maximum number of rounds to continue without improvement in the evaluation metric
- evals: the separate validation set(s) (Validation Set) used for evaluation
  - either [(train DMatrix, 'train'), (validation DMatrix, 'eval')] or just [(validation DMatrix, 'eval')] is fine
  - the loss is reported each round in the form eval-<loss_function>:<value>
[0] train-logloss:0.65016 eval-logloss:0.66183
[1] train-logloss:0.61131 eval-logloss:0.63609
[2] train-logloss:0.57563 eval-logloss:0.61144
[3] train-logloss:0.54310 eval-logloss:0.59204
[4] train-logloss:0.51323 eval-logloss:0.57329
[5] train-logloss:0.48447 eval-logloss:0.55037
[6] train-logloss:0.45796 eval-logloss:0.52930
[7] train-logloss:0.43436 eval-logloss:0.51534
[8] train-logloss:0.41150 eval-logloss:0.49718
[9] train-logloss:0.39027 eval-logloss:0.48154
[10] train-logloss:0.37128 eval-logloss:0.46990
[11] train-logloss:0.35254 eval-logloss:0.45474
[12] train-logloss:0.33528 eval-logloss:0.44229
[13] train-logloss:0.31892 eval-logloss:0.42961
[14] train-logloss:0.30439 eval-logloss:0.42065
[15] train-logloss:0.29000 eval-logloss:0.40958
[16] train-logloss:0.27651 eval-logloss:0.39887
[17] train-logloss:0.26389 eval-logloss:0.39050
[18] train-logloss:0.25210 eval-logloss:0.38254
[19] train-logloss:0.24123 eval-logloss:0.37393
[20] train-logloss:0.23076 eval-logloss:0.36789
[21] train-logloss:0.22091 eval-logloss:0.36017
[22] train-logloss:0.21155 eval-logloss:0.35421
[23] train-logloss:0.20263 eval-logloss:0.34683
[24] train-logloss:0.19434 eval-logloss:0.34111
[25] train-logloss:0.18637 eval-logloss:0.33634
[26] train-logloss:0.17875 eval-logloss:0.33082
[27] train-logloss:0.17167 eval-logloss:0.32675
[28] train-logloss:0.16481 eval-logloss:0.32099
[29] train-logloss:0.15835 eval-logloss:0.31671
[30] train-logloss:0.15225 eval-logloss:0.31277
[31] train-logloss:0.14650 eval-logloss:0.30882
[32] train-logloss:0.14102 eval-logloss:0.30437
[33] train-logloss:0.13590 eval-logloss:0.30103
[34] train-logloss:0.13109 eval-logloss:0.29794
[35] train-logloss:0.12647 eval-logloss:0.29499
[36] train-logloss:0.12197 eval-logloss:0.29295
[37] train-logloss:0.11784 eval-logloss:0.29043
[38] train-logloss:0.11379 eval-logloss:0.28927
[39] train-logloss:0.10994 eval-logloss:0.28578
[40] train-logloss:0.10638 eval-logloss:0.28364
[41] train-logloss:0.10302 eval-logloss:0.28183
[42] train-logloss:0.09963 eval-logloss:0.28005
[43] train-logloss:0.09649 eval-logloss:0.27972
[44] train-logloss:0.09359 eval-logloss:0.27744
[45] train-logloss:0.09080 eval-logloss:0.27542
[46] train-logloss:0.08807 eval-logloss:0.27504
[47] train-logloss:0.08541 eval-logloss:0.27458
[48] train-logloss:0.08299 eval-logloss:0.27348
[49] train-logloss:0.08035 eval-logloss:0.27247
[50] train-logloss:0.07786 eval-logloss:0.27163
[51] train-logloss:0.07550 eval-logloss:0.27094
[52] train-logloss:0.07344 eval-logloss:0.26967
[53] train-logloss:0.07147 eval-logloss:0.27008
[54] train-logloss:0.06964 eval-logloss:0.26890
[55] train-logloss:0.06766 eval-logloss:0.26854
[56] train-logloss:0.06591 eval-logloss:0.26900
[57] train-logloss:0.06433 eval-logloss:0.26790
[58] train-logloss:0.06259 eval-logloss:0.26663
[59] train-logloss:0.06107 eval-logloss:0.26743
[60] train-logloss:0.05957 eval-logloss:0.26610
[61] train-logloss:0.05817 eval-logloss:0.26644
[62] train-logloss:0.05691 eval-logloss:0.26673
[63] train-logloss:0.05550 eval-logloss:0.26550
[64] train-logloss:0.05422 eval-logloss:0.26443
[65] train-logloss:0.05311 eval-logloss:0.26500
[66] train-logloss:0.05207 eval-logloss:0.26591
[67] train-logloss:0.05093 eval-logloss:0.26501
[68] train-logloss:0.04976 eval-logloss:0.26435
[69] train-logloss:0.04872 eval-logloss:0.26360
[70] train-logloss:0.04776 eval-logloss:0.26319
[71] train-logloss:0.04680 eval-logloss:0.26255
[72] train-logloss:0.04580 eval-logloss:0.26204
[73] train-logloss:0.04484 eval-logloss:0.26254
[74] train-logloss:0.04388 eval-logloss:0.26289
[75] train-logloss:0.04309 eval-logloss:0.26249
[76] train-logloss:0.04224 eval-logloss:0.26217
[77] train-logloss:0.04133 eval-logloss:0.26166
[78] train-logloss:0.04050 eval-logloss:0.26179
[79] train-logloss:0.03967 eval-logloss:0.26103
[80] train-logloss:0.03876 eval-logloss:0.26094
[81] train-logloss:0.03806 eval-logloss:0.26148
[82] train-logloss:0.03740 eval-logloss:0.26054
[83] train-logloss:0.03676 eval-logloss:0.25967
[84] train-logloss:0.03605 eval-logloss:0.25905
[85] train-logloss:0.03545 eval-logloss:0.26007
[86] train-logloss:0.03489 eval-logloss:0.25984
[87] train-logloss:0.03425 eval-logloss:0.25933
[88] train-logloss:0.03361 eval-logloss:0.25932
[89] train-logloss:0.03311 eval-logloss:0.26002
[90] train-logloss:0.03260 eval-logloss:0.25936
[91] train-logloss:0.03202 eval-logloss:0.25886
[92] train-logloss:0.03152 eval-logloss:0.25918
[93] train-logloss:0.03107 eval-logloss:0.25864
[94] train-logloss:0.03049 eval-logloss:0.25951
[95] train-logloss:0.03007 eval-logloss:0.26091
[96] train-logloss:0.02963 eval-logloss:0.26014
[97] train-logloss:0.02913 eval-logloss:0.25974
[98] train-logloss:0.02866 eval-logloss:0.25937
[99] train-logloss:0.02829 eval-logloss:0.25893
[100] train-logloss:0.02789 eval-logloss:0.25928
[101] train-logloss:0.02751 eval-logloss:0.25955
[102] train-logloss:0.02714 eval-logloss:0.25901
[103] train-logloss:0.02668 eval-logloss:0.25991
[104] train-logloss:0.02634 eval-logloss:0.25950
[105] train-logloss:0.02594 eval-logloss:0.25924
[106] train-logloss:0.02556 eval-logloss:0.25901
[107] train-logloss:0.02522 eval-logloss:0.25738
[108] train-logloss:0.02492 eval-logloss:0.25702
[109] train-logloss:0.02453 eval-logloss:0.25789
[110] train-logloss:0.02418 eval-logloss:0.25770
[111] train-logloss:0.02384 eval-logloss:0.25842
[112] train-logloss:0.02356 eval-logloss:0.25810
[113] train-logloss:0.02322 eval-logloss:0.25848
[114] train-logloss:0.02290 eval-logloss:0.25833
[115] train-logloss:0.02260 eval-logloss:0.25820
[116] train-logloss:0.02229 eval-logloss:0.25905
[117] train-logloss:0.02204 eval-logloss:0.25878
[118] train-logloss:0.02176 eval-logloss:0.25728
[119] train-logloss:0.02149 eval-logloss:0.25722
[120] train-logloss:0.02119 eval-logloss:0.25764
[121] train-logloss:0.02095 eval-logloss:0.25761
[122] train-logloss:0.02067 eval-logloss:0.25832
[123] train-logloss:0.02045 eval-logloss:0.25808
[124] train-logloss:0.02023 eval-logloss:0.25855
[125] train-logloss:0.01998 eval-logloss:0.25714
[126] train-logloss:0.01973 eval-logloss:0.25587
[127] train-logloss:0.01946 eval-logloss:0.25640
[128] train-logloss:0.01927 eval-logloss:0.25685
[129] train-logloss:0.01908 eval-logloss:0.25665
[130] train-logloss:0.01886 eval-logloss:0.25712
[131] train-logloss:0.01863 eval-logloss:0.25609
[132] train-logloss:0.01839 eval-logloss:0.25649
[133] train-logloss:0.01816 eval-logloss:0.25789
[134] train-logloss:0.01802 eval-logloss:0.25811
[135] train-logloss:0.01785 eval-logloss:0.25794
[136] train-logloss:0.01763 eval-logloss:0.25876
[137] train-logloss:0.01748 eval-logloss:0.25884
[138] train-logloss:0.01732 eval-logloss:0.25867
[139] train-logloss:0.01719 eval-logloss:0.25876
[140] train-logloss:0.01696 eval-logloss:0.25987
[141] train-logloss:0.01681 eval-logloss:0.25960
[142] train-logloss:0.01669 eval-logloss:0.25982
[143] train-logloss:0.01656 eval-logloss:0.25992
[144] train-logloss:0.01638 eval-logloss:0.26035
[145] train-logloss:0.01623 eval-logloss:0.26055
[146] train-logloss:0.01606 eval-logloss:0.26092
[147] train-logloss:0.01589 eval-logloss:0.26137
[148] train-logloss:0.01572 eval-logloss:0.25999
[149] train-logloss:0.01556 eval-logloss:0.26028
[150] train-logloss:0.01546 eval-logloss:0.26048
[151] train-logloss:0.01531 eval-logloss:0.26142
[152] train-logloss:0.01515 eval-logloss:0.26188
[153] train-logloss:0.01501 eval-logloss:0.26227
[154] train-logloss:0.01486 eval-logloss:0.26287
[155] train-logloss:0.01476 eval-logloss:0.26299
[156] train-logloss:0.01462 eval-logloss:0.26346
[157] train-logloss:0.01448 eval-logloss:0.26379
[158] train-logloss:0.01434 eval-logloss:0.26306
[159] train-logloss:0.01424 eval-logloss:0.26237
[160] train-logloss:0.01410 eval-logloss:0.26251
[161] train-logloss:0.01401 eval-logloss:0.26265
[162] train-logloss:0.01392 eval-logloss:0.26264
[163] train-logloss:0.01380 eval-logloss:0.26250
[164] train-logloss:0.01372 eval-logloss:0.26264
[165] train-logloss:0.01359 eval-logloss:0.26255
[166] train-logloss:0.01350 eval-logloss:0.26188
[167] train-logloss:0.01342 eval-logloss:0.26203
[168] train-logloss:0.01331 eval-logloss:0.26190
[169] train-logloss:0.01319 eval-logloss:0.26184
[170] train-logloss:0.01312 eval-logloss:0.26133
[171] train-logloss:0.01304 eval-logloss:0.26148
[172] train-logloss:0.01297 eval-logloss:0.26157
[173] train-logloss:0.01285 eval-logloss:0.26253
[174] train-logloss:0.01278 eval-logloss:0.26229
[175] train-logloss:0.01267 eval-logloss:0.26086
[176] train-logloss:0.01258 eval-logloss:0.26103
In this run, XGBoost stops without completing all scheduled rounds because the loss function did not improve for the specified number of rounds (50):
eval-logloss stops decreasing after round [126], so training ends at round [176] (126 + 50).
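Where training stopped can also be read back from the trained model; a brief check using the native Booster's attributes:
# best_iteration / best_score are set on the Booster when early stopping triggers.
print('best iteration:', xgb_model.best_iteration)  # round with the lowest eval-logloss
print('best score:', xgb_model.best_score)          # eval-logloss at that round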
Return prediction probabilities with predict() and convert them to predicted labels
xgb_model.predict(test DMatrix)
pred_probs = xgb_model.predict(dtest)
print('First 10 predict() results, shown as predicted probabilities')
print(np.round(pred_probs[:10],3))
# If the predicted probability is greater than 0.5, predict 1; otherwise 0. Store in the list preds.
preds = [ 1 if x > 0.5 else 0 for x in pred_probs ]
print('First 10 predicted labels:', preds[:10])
# First 10 predict() results, shown as predicted probabilities
# [0.845 0.008 0.68 0.081 0.975 0.999 0.998 0.998 0.996 0.001]
# First 10 predicted labels: [1, 0, 1, 0, 1, 1, 1, 1, 1, 0]
Evaluate the predictions with get_clf_eval()
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import f1_score, roc_auc_score
def get_clf_eval(y_test, pred=None, pred_proba=None):
    confusion = confusion_matrix(y_test, pred)
    accuracy = accuracy_score(y_test, pred)
    precision = precision_score(y_test, pred)
    recall = recall_score(y_test, pred)
    f1 = f1_score(y_test, pred)
    # ROC-AUC added
    roc_auc = roc_auc_score(y_test, pred_proba)
    print('Confusion matrix')
    print(confusion)
    # ROC-AUC added to the printout
    print('Accuracy: {0:.4f}, Precision: {1:.4f}, Recall: {2:.4f}, '
          'F1: {3:.4f}, AUC: {4:.4f}'.format(accuracy, precision, recall, f1, roc_auc))
get_clf_eval(y_test, preds, pred_probs)
# Confusion matrix
# [[34 3]
# [ 2 75]]
# Accuracy: 0.9561, Precision: 0.9615, Recall: 0.9740, F1: 0.9677, AUC: 0.9937
Visualizing feature importance
import matplotlib.pyplot as plt
%matplotlib inline
fig, ax = plt.subplots(figsize=(10, 12))
plot_importance(xgb_model, ax=ax)
plt.savefig('p239_xgb_feature_importance.tif', format='tif', dpi=300, bbox_inches='tight')
6. Predicting Wisconsin breast cancer with XGBoost: the Scikit-Learn wrapper XGBoost
Differences from the native Python API
- DMatrix is not used
- Hyperparameters go straight into the class constructor instead of into train()
Importing the Scikit-Learn wrapper class, then training and predicting
verbose
- When a function takes a verbose argument, it controls whether detailed progress information produced during execution is written to standard output
- Typically 0 prints nothing, 1 prints detailed output, and 2 prints only condensed information
# Import XGBClassifier, the Scikit-Learn wrapper class for XGBoost
from xgboost import XGBClassifier
# eval_metric is passed to the XGBClassifier constructor to suppress a warning message; omitting it does not affect execution.
xgb_wrapper = XGBClassifier(n_estimators=400, learning_rate=0.05, max_depth=3, eval_metric='logloss')
xgb_wrapper.fit(X_train, y_train, verbose=True)
w_preds = xgb_wrapper.predict(X_test)
w_pred_proba = xgb_wrapper.predict_proba(X_test)[:, 1]
get_clf_eval(y_test, w_preds, w_pred_proba)
# Confusion matrix
# [[34 3]
# [ 1 76]]
# Accuracy: 0.9649, Precision: 0.9620, Recall: 0.9870, F1: 0.9744, AUC: 0.9954
Retrain/predict/evaluate with early stopping set to 50
XGBClassifier.fit(X_tr, y_tr, early_stopping_rounds, eval_metric, eval_set, verbose)
fit() parameters:
- early_stopping_rounds: the maximum number of rounds to continue without improvement in the evaluation metric
- eval_metric: the cost evaluation metric used at each round
- eval_set: a separate validation set (Validation Set) used for evaluation
  - in the form [(X_tr, y_tr), (X_val, y_val)]
from xgboost import XGBClassifier
xgb_wrapper = XGBClassifier(n_estimators=400, learning_rate=0.05, max_depth=3)
evals = [(X_tr, y_tr), (X_val, y_val)]
xgb_wrapper.fit(X_tr, y_tr, early_stopping_rounds=50, eval_metric="logloss",
eval_set=evals, verbose=True)
ws50_preds = xgb_wrapper.predict(X_test)
ws50_pred_proba = xgb_wrapper.predict_proba(X_test)[:, 1]
[0] validation_0-logloss:0.65016 validation_1-logloss:0.66183
[1] validation_0-logloss:0.61131 validation_1-logloss:0.63609
[2] validation_0-logloss:0.57563 validation_1-logloss:0.61144
[3] validation_0-logloss:0.54310 validation_1-logloss:0.59204
[4] validation_0-logloss:0.51323 validation_1-logloss:0.57329
[5] validation_0-logloss:0.48447 validation_1-logloss:0.55037
[6] validation_0-logloss:0.45796 validation_1-logloss:0.52929
[7] validation_0-logloss:0.43436 validation_1-logloss:0.51534
[8] validation_0-logloss:0.41150 validation_1-logloss:0.49718
[9] validation_0-logloss:0.39027 validation_1-logloss:0.48154
[10] validation_0-logloss:0.37128 validation_1-logloss:0.46990
[11] validation_0-logloss:0.35254 validation_1-logloss:0.45474
[12] validation_0-logloss:0.33528 validation_1-logloss:0.44229
[13] validation_0-logloss:0.31893 validation_1-logloss:0.42961
[14] validation_0-logloss:0.30439 validation_1-logloss:0.42065
[15] validation_0-logloss:0.29000 validation_1-logloss:0.40958
[16] validation_0-logloss:0.27651 validation_1-logloss:0.39887
[17] validation_0-logloss:0.26389 validation_1-logloss:0.39050
[18] validation_0-logloss:0.25210 validation_1-logloss:0.38254
[19] validation_0-logloss:0.24123 validation_1-logloss:0.37393
[20] validation_0-logloss:0.23076 validation_1-logloss:0.36789
[21] validation_0-logloss:0.22091 validation_1-logloss:0.36017
[22] validation_0-logloss:0.21155 validation_1-logloss:0.35421
[23] validation_0-logloss:0.20263 validation_1-logloss:0.34683
[24] validation_0-logloss:0.19434 validation_1-logloss:0.34111
[25] validation_0-logloss:0.18637 validation_1-logloss:0.33634
[26] validation_0-logloss:0.17875 validation_1-logloss:0.33082
[27] validation_0-logloss:0.17167 validation_1-logloss:0.32675
[28] validation_0-logloss:0.16481 validation_1-logloss:0.32099
[29] validation_0-logloss:0.15835 validation_1-logloss:0.31671
[30] validation_0-logloss:0.15225 validation_1-logloss:0.31277
[31] validation_0-logloss:0.14650 validation_1-logloss:0.30882
[32] validation_0-logloss:0.14102 validation_1-logloss:0.30437
[33] validation_0-logloss:0.13590 validation_1-logloss:0.30103
[34] validation_0-logloss:0.13109 validation_1-logloss:0.29794
[35] validation_0-logloss:0.12647 validation_1-logloss:0.29499
[36] validation_0-logloss:0.12197 validation_1-logloss:0.29295
[37] validation_0-logloss:0.11784 validation_1-logloss:0.29043
[38] validation_0-logloss:0.11379 validation_1-logloss:0.28927
[39] validation_0-logloss:0.10994 validation_1-logloss:0.28578
[40] validation_0-logloss:0.10638 validation_1-logloss:0.28364
[41] validation_0-logloss:0.10302 validation_1-logloss:0.28183
[42] validation_0-logloss:0.09963 validation_1-logloss:0.28005
[43] validation_0-logloss:0.09649 validation_1-logloss:0.27972
[44] validation_0-logloss:0.09359 validation_1-logloss:0.27744
[45] validation_0-logloss:0.09080 validation_1-logloss:0.27542
[46] validation_0-logloss:0.08807 validation_1-logloss:0.27504
[47] validation_0-logloss:0.08541 validation_1-logloss:0.27458
[48] validation_0-logloss:0.08299 validation_1-logloss:0.27348
[49] validation_0-logloss:0.08035 validation_1-logloss:0.27247
[50] validation_0-logloss:0.07786 validation_1-logloss:0.27163
[51] validation_0-logloss:0.07550 validation_1-logloss:0.27094
[52] validation_0-logloss:0.07344 validation_1-logloss:0.26967
[53] validation_0-logloss:0.07147 validation_1-logloss:0.27008
[54] validation_0-logloss:0.06964 validation_1-logloss:0.26890
[55] validation_0-logloss:0.06766 validation_1-logloss:0.26854
[56] validation_0-logloss:0.06592 validation_1-logloss:0.26900
[57] validation_0-logloss:0.06433 validation_1-logloss:0.26790
[58] validation_0-logloss:0.06259 validation_1-logloss:0.26663
[59] validation_0-logloss:0.06107 validation_1-logloss:0.26743
[60] validation_0-logloss:0.05957 validation_1-logloss:0.26610
[61] validation_0-logloss:0.05817 validation_1-logloss:0.26644
[62] validation_0-logloss:0.05691 validation_1-logloss:0.26673
[63] validation_0-logloss:0.05550 validation_1-logloss:0.26550
[64] validation_0-logloss:0.05422 validation_1-logloss:0.26443
[65] validation_0-logloss:0.05311 validation_1-logloss:0.26500
[66] validation_0-logloss:0.05207 validation_1-logloss:0.26591
[67] validation_0-logloss:0.05093 validation_1-logloss:0.26501
[68] validation_0-logloss:0.04976 validation_1-logloss:0.26435
[69] validation_0-logloss:0.04872 validation_1-logloss:0.26360
[70] validation_0-logloss:0.04776 validation_1-logloss:0.26319
[71] validation_0-logloss:0.04680 validation_1-logloss:0.26255
[72] validation_0-logloss:0.04580 validation_1-logloss:0.26204
[73] validation_0-logloss:0.04484 validation_1-logloss:0.26254
[74] validation_0-logloss:0.04388 validation_1-logloss:0.26289
[75] validation_0-logloss:0.04309 validation_1-logloss:0.26249
[76] validation_0-logloss:0.04224 validation_1-logloss:0.26217
[77] validation_0-logloss:0.04133 validation_1-logloss:0.26166
[78] validation_0-logloss:0.04050 validation_1-logloss:0.26179
[79] validation_0-logloss:0.03967 validation_1-logloss:0.26103
[80] validation_0-logloss:0.03877 validation_1-logloss:0.26094
[81] validation_0-logloss:0.03806 validation_1-logloss:0.26148
[82] validation_0-logloss:0.03740 validation_1-logloss:0.26054
[83] validation_0-logloss:0.03676 validation_1-logloss:0.25967
[84] validation_0-logloss:0.03605 validation_1-logloss:0.25905
[85] validation_0-logloss:0.03545 validation_1-logloss:0.26007
[86] validation_0-logloss:0.03488 validation_1-logloss:0.25984
[87] validation_0-logloss:0.03425 validation_1-logloss:0.25933
[88] validation_0-logloss:0.03361 validation_1-logloss:0.25932
[89] validation_0-logloss:0.03311 validation_1-logloss:0.26002
[90] validation_0-logloss:0.03260 validation_1-logloss:0.25936
[91] validation_0-logloss:0.03202 validation_1-logloss:0.25886
[92] validation_0-logloss:0.03152 validation_1-logloss:0.25918
[93] validation_0-logloss:0.03107 validation_1-logloss:0.25865
[94] validation_0-logloss:0.03049 validation_1-logloss:0.25951
[95] validation_0-logloss:0.03007 validation_1-logloss:0.26091
[96] validation_0-logloss:0.02963 validation_1-logloss:0.26014
[97] validation_0-logloss:0.02913 validation_1-logloss:0.25974
[98] validation_0-logloss:0.02866 validation_1-logloss:0.25937
[99] validation_0-logloss:0.02829 validation_1-logloss:0.25893
[100] validation_0-logloss:0.02789 validation_1-logloss:0.25928
[101] validation_0-logloss:0.02751 validation_1-logloss:0.25955
[102] validation_0-logloss:0.02714 validation_1-logloss:0.25901
[103] validation_0-logloss:0.02668 validation_1-logloss:0.25991
[104] validation_0-logloss:0.02634 validation_1-logloss:0.25950
[105] validation_0-logloss:0.02594 validation_1-logloss:0.25924
[106] validation_0-logloss:0.02556 validation_1-logloss:0.25901
[107] validation_0-logloss:0.02522 validation_1-logloss:0.25738
[108] validation_0-logloss:0.02492 validation_1-logloss:0.25702
[109] validation_0-logloss:0.02453 validation_1-logloss:0.25789
[110] validation_0-logloss:0.02418 validation_1-logloss:0.25770
[111] validation_0-logloss:0.02384 validation_1-logloss:0.25842
[112] validation_0-logloss:0.02356 validation_1-logloss:0.25810
[113] validation_0-logloss:0.02322 validation_1-logloss:0.25848
[114] validation_0-logloss:0.02290 validation_1-logloss:0.25833
[115] validation_0-logloss:0.02260 validation_1-logloss:0.25820
[116] validation_0-logloss:0.02229 validation_1-logloss:0.25905
[117] validation_0-logloss:0.02204 validation_1-logloss:0.25878
[118] validation_0-logloss:0.02176 validation_1-logloss:0.25728
[119] validation_0-logloss:0.02149 validation_1-logloss:0.25722
[120] validation_0-logloss:0.02119 validation_1-logloss:0.25764
[121] validation_0-logloss:0.02095 validation_1-logloss:0.25761
[122] validation_0-logloss:0.02067 validation_1-logloss:0.25832
[123] validation_0-logloss:0.02045 validation_1-logloss:0.25808
[124] validation_0-logloss:0.02023 validation_1-logloss:0.25855
[125] validation_0-logloss:0.01998 validation_1-logloss:0.25714
[126] validation_0-logloss:0.01973 validation_1-logloss:0.25587
[127] validation_0-logloss:0.01946 validation_1-logloss:0.25640
[128] validation_0-logloss:0.01927 validation_1-logloss:0.25685
[129] validation_0-logloss:0.01908 validation_1-logloss:0.25665
[130] validation_0-logloss:0.01886 validation_1-logloss:0.25712
[131] validation_0-logloss:0.01863 validation_1-logloss:0.25609
[132] validation_0-logloss:0.01839 validation_1-logloss:0.25649
[133] validation_0-logloss:0.01816 validation_1-logloss:0.25789
[134] validation_0-logloss:0.01802 validation_1-logloss:0.25811
[135] validation_0-logloss:0.01785 validation_1-logloss:0.25794
[136] validation_0-logloss:0.01763 validation_1-logloss:0.25876
[137] validation_0-logloss:0.01748 validation_1-logloss:0.25884
[138] validation_0-logloss:0.01732 validation_1-logloss:0.25867
[139] validation_0-logloss:0.01719 validation_1-logloss:0.25876
[140] validation_0-logloss:0.01696 validation_1-logloss:0.25987
[141] validation_0-logloss:0.01681 validation_1-logloss:0.25960
[142] validation_0-logloss:0.01669 validation_1-logloss:0.25982
[143] validation_0-logloss:0.01656 validation_1-logloss:0.25992
[144] validation_0-logloss:0.01638 validation_1-logloss:0.26035
[145] validation_0-logloss:0.01623 validation_1-logloss:0.26055
[146] validation_0-logloss:0.01606 validation_1-logloss:0.26092
[147] validation_0-logloss:0.01589 validation_1-logloss:0.26137
[148] validation_0-logloss:0.01572 validation_1-logloss:0.25999
[149] validation_0-logloss:0.01557 validation_1-logloss:0.26028
[150] validation_0-logloss:0.01546 validation_1-logloss:0.26048
[151] validation_0-logloss:0.01531 validation_1-logloss:0.26142
[152] validation_0-logloss:0.01515 validation_1-logloss:0.26188
[153] validation_0-logloss:0.01501 validation_1-logloss:0.26227
[154] validation_0-logloss:0.01486 validation_1-logloss:0.26287
[155] validation_0-logloss:0.01476 validation_1-logloss:0.26299
[156] validation_0-logloss:0.01461 validation_1-logloss:0.26346
[157] validation_0-logloss:0.01448 validation_1-logloss:0.26379
[158] validation_0-logloss:0.01434 validation_1-logloss:0.26306
[159] validation_0-logloss:0.01424 validation_1-logloss:0.26237
[160] validation_0-logloss:0.01410 validation_1-logloss:0.26251
[161] validation_0-logloss:0.01401 validation_1-logloss:0.26265
[162] validation_0-logloss:0.01392 validation_1-logloss:0.26264
[163] validation_0-logloss:0.01380 validation_1-logloss:0.26250
[164] validation_0-logloss:0.01372 validation_1-logloss:0.26264
[165] validation_0-logloss:0.01359 validation_1-logloss:0.26255
[166] validation_0-logloss:0.01350 validation_1-logloss:0.26188
[167] validation_0-logloss:0.01342 validation_1-logloss:0.26203
[168] validation_0-logloss:0.01331 validation_1-logloss:0.26190
[169] validation_0-logloss:0.01319 validation_1-logloss:0.26184
[170] validation_0-logloss:0.01312 validation_1-logloss:0.26133
[171] validation_0-logloss:0.01304 validation_1-logloss:0.26148
[172] validation_0-logloss:0.01297 validation_1-logloss:0.26157
[173] validation_0-logloss:0.01285 validation_1-logloss:0.26253
[174] validation_0-logloss:0.01278 validation_1-logloss:0.26229
[175] validation_0-logloss:0.01267 validation_1-logloss:0.26086
get_clf_eval(y_test, ws50_preds, ws50_pred_proba)
# Confusion matrix
# [[34 3]
# [ 2 75]]
# Accuracy: 0.9561, Precision: 0.9615, Recall: 0.9740, F1: 0.9677, AUC: 0.9933
Retrain/predict/evaluate with early stopping set to 10
Setting early stopping too low can make performance worse.
When? When hyperparameters need to be tuned quickly; afterwards, tune again in more depth.
# Retrain with early_stopping_rounds set to 10.
xgb_wrapper.fit(X_tr, y_tr, early_stopping_rounds=10,
eval_metric="logloss", eval_set=evals,verbose=True)
ws10_preds = xgb_wrapper.predict(X_test)
ws10_pred_proba = xgb_wrapper.predict_proba(X_test)[:, 1]
get_clf_eval(y_test , ws10_preds, ws10_pred_proba)
from xgboost import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline
fig, ax = plt.subplots(figsize=(10, 12))
# Passing the Scikit-Learn wrapper class also works.
plot_importance(xgb_wrapper, ax=ax)
xgb_wrapper.evals_result()
- evals_result() gives access to the train-data loss and validation-data loss recorded during fit()
- results = xgb_wrapper.evals_result()
- results['validation_0']['logloss']: loss values on the training data
- results['validation_1']['logloss']: loss values on the validation data
import matplotlib.pyplot as plt
results = xgb_wrapper.evals_result()
plt.plot(results["validation_0"]["logloss"], label="train")
plt.plot(results["validation_1"]["logloss"], label="validation")
plt.legend()
plt.show()
7. LightGBM
A tree-based learning algorithm built on the gradient boosting framework
Advantages over XGBoost
- Faster training and prediction
- Lower memory usage
- Automatic conversion and optimal splitting of categorical features
  - (converts categorical features optimally and splits nodes on them without One-Hot Encoding; a brief sketch follows this list)
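A brief illustration of that categorical handling, on toy data added here for illustration: with a pandas 'category' dtype column, LightGBM splits on the raw categories directly, with no one-hot encoding:
import numpy as np
import pandas as pd
from lightgbm import LGBMClassifier

rng = np.random.RandomState(0)
df = pd.DataFrame({'city': pd.Series(rng.choice(['Seoul', 'Busan', 'Daegu'], 300),
                                     dtype='category'),  # categorical feature, unencoded
                   'age': rng.randint(20, 60, 300)})
target = rng.randint(0, 2, 300)

lgbm_cat = LGBMClassifier(n_estimators=50)
lgbm_cat.fit(df, target)  # 'category' dtype columns are detected and split natively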
LightGBM tree-splitting strategy
Conventional GBMs (including XGBoost) split trees level-wise (balanced), minimizing depth.
LightGBM splits trees leaf-wise:
- Repeatedly splitting the leaf nodes that most reduce the loss lowers prediction error
- This yields faster run times and improved predictions
- Speed is its biggest advantage.
- As datasets keep growing, getting results quickly takes more and more time; LightGBM can handle large data while occupying little memory at run time.
Python implementation of LightGBM
LightGBM hyperparameters
The LightGBM Scikit-Learn wrapper uses a hyperparameter name as-is whenever the XGBoost Scikit-Learn wrapper has a matching one,
and otherwise falls back to the Python wrapper LightGBM hyperparameter name.
- num_leaves: the maximum number of leaf nodes
- boost_from_average
  - When True and the label distribution is extremely imbalanced, recall and ROC-AUC can degrade badly; with extremely imbalanced labels it is better to set boost_from_average to False
  - boost_from_average defaults to True from LightGBM 2.1.0 onwards
Basic tuning approach: reducing model complexity
The basic tuning approach is to reduce model complexity by adjusting num_leaves (the maximum number of leaf nodes) together with min_child_samples (min_data_in_leaf, the minimum number of samples a node needs to become a leaf) and max_depth:
- Increasing num_leaves improves accuracy, but the trees grow deeper and overfit more easily
- Setting min_child_samples (min_data_in_leaf) higher keeps the trees from growing too deep
- max_depth limits depth explicitly; it is used together with the two parameters above to curb overfitting
Another basic boosting tuning approach is to lower learning_rate while raising n_estimators; a sketch combining these follows.
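A hedged sketch wiring this guidance into LGBMClassifier (parameter values are illustrative starting points, not tuned results):
from lightgbm import LGBMClassifier

lgbm_tuned = LGBMClassifier(
    num_leaves=32,          # main complexity control; fewer leaves -> simpler model
    min_child_samples=60,   # larger -> leaves need more data, curbs overfitting
    max_depth=8,            # explicit depth cap, used with the two parameters above
    learning_rate=0.02,     # lowered learning rate ...
    n_estimators=1000)      # ... compensated by more boosting rounds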
Overview of LightGBM hyperparameter tuning
Scikit-Learn wrapper LightGBM hyperparameters
follow the Scikit-Learn XGBoost hyperparameter names first;
if no matching name exists, the Python wrapper LightGBM hyperparameter name is used.
8. Predicting Wisconsin breast cancer with LightGBM
Applying LightGBM - Wisconsin Breast Cancer Prediction
# Import LGBMClassifier from lightgbm, LightGBM's Python package
from lightgbm import LGBMClassifier
import pandas as pd
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')
dataset = load_breast_cancer()
cancer_df = pd.DataFrame(data=dataset.data, columns=dataset.feature_names)
cancer_df['target']= dataset.target
cancer_df.head()
X_test, y_test: feature dataset and labels for testing
X_train, y_train: feature dataset and labels for training
X_tr, y_tr: the training portion obtained by splitting X_train, y_train again
X_val, y_val: the validation portion obtained by splitting X_train, y_train again
X_features = cancer_df.iloc[:, :-1]
y_label = cancer_df.iloc[:, -1]
# Use 80% of the data for training and hold out 20% for testing
X_train, X_test, y_train, y_test=train_test_split(X_features, y_label,
test_size=0.2, random_state=156 )
# Split the X_train, y_train made above again: 90% for training, 10% for validation
X_tr, X_val, y_tr, y_val= train_test_split(X_train, y_train,
test_size=0.1, random_state=156 )
# Set n_estimators to 400, the same as with XGBoost above.
lgbm_wrapper = LGBMClassifier(n_estimators=400, learning_rate=0.05)
# LightGBM supports early stopping just like XGBoost.
evals = [(X_tr, y_tr), (X_val, y_val)]
lgbm_wrapper.fit(X_tr, y_tr, early_stopping_rounds=50, eval_metric="logloss",
eval_set=evals, verbose=True)
preds = lgbm_wrapper.predict(X_test)
pred_proba = lgbm_wrapper.predict_proba(X_test)[:, 1]
[fit() parameters]
- early_stopping_rounds: the maximum number of rounds to continue without improvement in the evaluation metric
- eval_metric: the cost evaluation metric used at each round
- eval_set: a separate validation set (Validation Set) used for evaluation
  - in the form [(X_tr, y_tr), (X_val, y_val)] = [(train set), (validation set)]
Early Stopping
The metric stops improving after round [61]; with early_stopping_rounds=50, training therefore ends at round [111].
[1] training's binary_logloss: 0.625671 valid_1's binary_logloss: 0.628248
[2] training's binary_logloss: 0.588173 valid_1's binary_logloss: 0.601106
[3] training's binary_logloss: 0.554518 valid_1's binary_logloss: 0.577587
[4] training's binary_logloss: 0.523972 valid_1's binary_logloss: 0.556324
[5] training's binary_logloss: 0.49615 valid_1's binary_logloss: 0.537407
[6] training's binary_logloss: 0.470108 valid_1's binary_logloss: 0.519401
[7] training's binary_logloss: 0.446647 valid_1's binary_logloss: 0.502637
[8] training's binary_logloss: 0.425055 valid_1's binary_logloss: 0.488311
[9] training's binary_logloss: 0.405125 valid_1's binary_logloss: 0.474664
[10] training's binary_logloss: 0.386526 valid_1's binary_logloss: 0.461267
[11] training's binary_logloss: 0.367027 valid_1's binary_logloss: 0.444274
[12] training's binary_logloss: 0.350713 valid_1's binary_logloss: 0.432755
[13] training's binary_logloss: 0.334601 valid_1's binary_logloss: 0.421371
[14] training's binary_logloss: 0.319854 valid_1's binary_logloss: 0.411418
[15] training's binary_logloss: 0.306374 valid_1's binary_logloss: 0.402989
[16] training's binary_logloss: 0.293116 valid_1's binary_logloss: 0.393973
[17] training's binary_logloss: 0.280812 valid_1's binary_logloss: 0.384801
[18] training's binary_logloss: 0.268352 valid_1's binary_logloss: 0.376191
[19] training's binary_logloss: 0.256942 valid_1's binary_logloss: 0.368378
[20] training's binary_logloss: 0.246443 valid_1's binary_logloss: 0.362062
[21] training's binary_logloss: 0.236874 valid_1's binary_logloss: 0.355162
[22] training's binary_logloss: 0.227501 valid_1's binary_logloss: 0.348933
[23] training's binary_logloss: 0.218988 valid_1's binary_logloss: 0.342819
[24] training's binary_logloss: 0.210621 valid_1's binary_logloss: 0.337386
[25] training's binary_logloss: 0.202076 valid_1's binary_logloss: 0.331523
[26] training's binary_logloss: 0.194199 valid_1's binary_logloss: 0.326349
[27] training's binary_logloss: 0.187107 valid_1's binary_logloss: 0.322785
[28] training's binary_logloss: 0.180535 valid_1's binary_logloss: 0.317877
[29] training's binary_logloss: 0.173834 valid_1's binary_logloss: 0.313928
[30] training's binary_logloss: 0.167198 valid_1's binary_logloss: 0.310105
[31] training's binary_logloss: 0.161229 valid_1's binary_logloss: 0.307107
[32] training's binary_logloss: 0.155494 valid_1's binary_logloss: 0.303837
[33] training's binary_logloss: 0.149125 valid_1's binary_logloss: 0.300315
[34] training's binary_logloss: 0.144045 valid_1's binary_logloss: 0.297816
[35] training's binary_logloss: 0.139341 valid_1's binary_logloss: 0.295387
[36] training's binary_logloss: 0.134625 valid_1's binary_logloss: 0.293063
[37] training's binary_logloss: 0.129167 valid_1's binary_logloss: 0.289127
[38] training's binary_logloss: 0.12472 valid_1's binary_logloss: 0.288697
[39] training's binary_logloss: 0.11974 valid_1's binary_logloss: 0.28576
[40] training's binary_logloss: 0.115054 valid_1's binary_logloss: 0.282853
[41] training's binary_logloss: 0.110662 valid_1's binary_logloss: 0.279441
[42] training's binary_logloss: 0.106358 valid_1's binary_logloss: 0.28113
[43] training's binary_logloss: 0.102324 valid_1's binary_logloss: 0.279139
[44] training's binary_logloss: 0.0985699 valid_1's binary_logloss: 0.276465
[45] training's binary_logloss: 0.094858 valid_1's binary_logloss: 0.275946
[46] training's binary_logloss: 0.0912486 valid_1's binary_logloss: 0.272819
[47] training's binary_logloss: 0.0883115 valid_1's binary_logloss: 0.272306
[48] training's binary_logloss: 0.0849963 valid_1's binary_logloss: 0.270452
[49] training's binary_logloss: 0.0821742 valid_1's binary_logloss: 0.268671
[50] training's binary_logloss: 0.0789991 valid_1's binary_logloss: 0.267587
[51] training's binary_logloss: 0.0761072 valid_1's binary_logloss: 0.26626
[52] training's binary_logloss: 0.0732567 valid_1's binary_logloss: 0.265542
[53] training's binary_logloss: 0.0706388 valid_1's binary_logloss: 0.264547
[54] training's binary_logloss: 0.0683911 valid_1's binary_logloss: 0.26502
[55] training's binary_logloss: 0.0659347 valid_1's binary_logloss: 0.264388
[56] training's binary_logloss: 0.0636873 valid_1's binary_logloss: 0.263128
[57] training's binary_logloss: 0.0613354 valid_1's binary_logloss: 0.26231
[58] training's binary_logloss: 0.0591944 valid_1's binary_logloss: 0.262011
[59] training's binary_logloss: 0.057033 valid_1's binary_logloss: 0.261454
[60] training's binary_logloss: 0.0550801 valid_1's binary_logloss: 0.260746
[61] training's binary_logloss: 0.0532381 valid_1's binary_logloss: 0.260236
[62] training's binary_logloss: 0.0514074 valid_1's binary_logloss: 0.261586
[63] training's binary_logloss: 0.0494837 valid_1's binary_logloss: 0.261797
[64] training's binary_logloss: 0.0477826 valid_1's binary_logloss: 0.262533
[65] training's binary_logloss: 0.0460364 valid_1's binary_logloss: 0.263305
[66] training's binary_logloss: 0.0444552 valid_1's binary_logloss: 0.264072
[67] training's binary_logloss: 0.0427638 valid_1's binary_logloss: 0.266223
[68] training's binary_logloss: 0.0412449 valid_1's binary_logloss: 0.266817
[69] training's binary_logloss: 0.0398589 valid_1's binary_logloss: 0.267819
[70] training's binary_logloss: 0.0383095 valid_1's binary_logloss: 0.267484
[71] training's binary_logloss: 0.0368803 valid_1's binary_logloss: 0.270233
[72] training's binary_logloss: 0.0355637 valid_1's binary_logloss: 0.268442
[73] training's binary_logloss: 0.0341747 valid_1's binary_logloss: 0.26895
[74] training's binary_logloss: 0.0328302 valid_1's binary_logloss: 0.266958
[75] training's binary_logloss: 0.0317853 valid_1's binary_logloss: 0.268091
[76] training's binary_logloss: 0.0305626 valid_1's binary_logloss: 0.266419
[77] training's binary_logloss: 0.0295001 valid_1's binary_logloss: 0.268588
[78] training's binary_logloss: 0.0284699 valid_1's binary_logloss: 0.270964
[79] training's binary_logloss: 0.0273953 valid_1's binary_logloss: 0.270293
[80] training's binary_logloss: 0.0264668 valid_1's binary_logloss: 0.270523
[81] training's binary_logloss: 0.0254636 valid_1's binary_logloss: 0.270683
[82] training's binary_logloss: 0.0245911 valid_1's binary_logloss: 0.273187
[83] training's binary_logloss: 0.0236486 valid_1's binary_logloss: 0.275994
[84] training's binary_logloss: 0.0228047 valid_1's binary_logloss: 0.274053
[85] training's binary_logloss: 0.0221693 valid_1's binary_logloss: 0.273211
[86] training's binary_logloss: 0.0213043 valid_1's binary_logloss: 0.272626
[87] training's binary_logloss: 0.0203934 valid_1's binary_logloss: 0.27534
[88] training's binary_logloss: 0.0195552 valid_1's binary_logloss: 0.276228
[89] training's binary_logloss: 0.0188623 valid_1's binary_logloss: 0.27525
[90] training's binary_logloss: 0.0183664 valid_1's binary_logloss: 0.276485
[91] training's binary_logloss: 0.0176788 valid_1's binary_logloss: 0.277052
[92] training's binary_logloss: 0.0170059 valid_1's binary_logloss: 0.277686
[93] training's binary_logloss: 0.0164317 valid_1's binary_logloss: 0.275332
[94] training's binary_logloss: 0.015878 valid_1's binary_logloss: 0.276236
[95] training's binary_logloss: 0.0152959 valid_1's binary_logloss: 0.274538
[96] training's binary_logloss: 0.0147216 valid_1's binary_logloss: 0.275244
[97] training's binary_logloss: 0.0141758 valid_1's binary_logloss: 0.275829
[98] training's binary_logloss: 0.0136551 valid_1's binary_logloss: 0.276654
[99] training's binary_logloss: 0.0131585 valid_1's binary_logloss: 0.277859
[100] training's binary_logloss: 0.0126961 valid_1's binary_logloss: 0.279265
[101] training's binary_logloss: 0.0122421 valid_1's binary_logloss: 0.276695
[102] training's binary_logloss: 0.0118067 valid_1's binary_logloss: 0.278488
[103] training's binary_logloss: 0.0113994 valid_1's binary_logloss: 0.278932
[104] training's binary_logloss: 0.0109799 valid_1's binary_logloss: 0.280997
[105] training's binary_logloss: 0.0105953 valid_1's binary_logloss: 0.281454
[106] training's binary_logloss: 0.0102381 valid_1's binary_logloss: 0.282058
[107] training's binary_logloss: 0.00986714 valid_1's binary_logloss: 0.279275
[108] training's binary_logloss: 0.00950998 valid_1's binary_logloss: 0.281427
[109] training's binary_logloss: 0.00915965 valid_1's binary_logloss: 0.280752
[110] training's binary_logloss: 0.00882581 valid_1's binary_logloss: 0.282152
[111] training's binary_logloss: 0.00850714 valid_1's binary_logloss: 0.280894
With a small dataset, LightGBM's performance can lag slightly behind XGBoost's.
get_clf_eval(y_test, preds, pred_proba)
# Confusion matrix
# [[34 3]
# [ 2 75]]
# Accuracy: 0.9561, Precision: 0.9615, Recall: 0.9740, F1: 0.9677, AUC: 0.9877
# Visualize feature importance with plot_importance()
from lightgbm import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline
fig, ax = plt.subplots(figsize=(10, 12))
plot_importance(lgbm_wrapper, ax=ax)
plt.savefig('lightgbm_feature_importance.tif', format='tif', dpi=300, bbox_inches='tight')
plt.show()
Using plot_metric and plot_tree
from lightgbm import plot_metric, plot_tree
plot_metric(lgbm_wrapper, figsize=(6,4))
plot_tree(lgbm_wrapper, figsize=(6,4))
# Training without early stopping
evals = [(X_tr, y_tr), (X_val, y_val)]
lgbm_wrapper.fit(X_tr, y_tr, eval_metric='logloss', eval_set=evals, verbose=True)
preds = lgbm_wrapper.predict(X_test)
pred_proba = lgbm_wrapper.predict_proba(X_test)[:,1]
plot_metric(lgbm_wrapper, figsize=(6,4))
plot_tree(lgbm_wrapper, figsize=(6,4))