1. Data Preprocessing
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import warnings
warnings.filterwarnings("ignore")
cust_df = pd.read_csv("./train_santander.csv", encoding="latin-1")
print("dataset shape:", cust_df.shape)
cust_df.head(3)
cust_df.info()
# <class 'pandas.core.frame.DataFrame'>
# RangeIndex: 76020 entries, 0 to 76019
# Columns: 371 entries, ID to TARGET
# dtypes: float64(111), int64(260)
# memory usage: 215.2 MB
- In TARGET, 0 means a satisfied customer and 1 an unsatisfied customer
print(cust_df["TARGET"].value_counts())
unsatisfied_cnt = cust_df[cust_df["TARGET"] == 1].TARGET.count()
# unsatisfied_cnt = cust_df[cust_df["TARGET"] == 1].shape[0] also works
total_cnt = cust_df.TARGET.count()
# total_cnt = cust_df.shape[0] also works
print("unsatisfied 비율은 {0:.2f}".format((unsatisfied_cnt / total_cnt)))
# 0 73012
# 1 3008
# Name: TARGET, dtype: int64
# Unsatisfied ratio: 0.04
- Check the distribution of the data
cust_df.describe()
- Replace the garbage value -999999 with another value
cust_df['var3'].value_counts()
# 2 74165
# 8 138
# -999999 116
# 9 110
# 3 108
# ...
# 231 1
# 188 1
# 168 1
# 135 1
# 87 1
# Name: var3, Length: 208, dtype: int64
- Split the feature data set and the label; replace a feature value and drop a column
# Replace the var3 garbage value with 2 (the most frequent value) and drop the ID feature
cust_df["var3"] = cust_df["var3"].replace(-999999, 2)  # assignment form avoids pandas chained inplace issues
cust_df.drop(["ID"], axis=1, inplace=True)
# Separate the feature set and the label set. The label column sits at the very end
# of the DataFrame, so slice it off at column position -1.
X_features = cust_df.iloc[:, :-1]
y_labels = cust_df.iloc[:, -1]
print("피처 데이터 shape:{0}".format(X_features.shape))
# 피처 데이터 shape:(76020, 369)
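- A quick sanity check (a small addition, not in the original notebook) confirms the garbage value is really gone after the replacement:
assert (cust_df["var3"] == -999999).sum() == 0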
- Split into train and test sets with train_test_split()
- Adding stratify=y_labels also works: it keeps y_labels' class distribution identical when splitting into train and test labels (see the sketch after the output below)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X_features, y_labels, test_size=0.2, random_state=0
)
train_cnt = y_train.count()
test_cnt = y_test.count()
print("Train set shape: {0}, test set shape: {1}".format(X_train.shape, X_test.shape))
print("Train set label distribution:")
print(y_train.value_counts() / train_cnt)
print("\nTest set label distribution:")
print(y_test.value_counts() / test_cnt)
# Train set shape: (60816, 369), test set shape: (15204, 369)
# Train set label distribution:
# 0    0.960964
# 1    0.039036
# Name: TARGET, dtype: float64
# Test set label distribution:
# 0    0.9583
# 1    0.0417
# Name: TARGET, dtype: float64
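- The stratified variant mentioned above, as a minimal sketch (the _s-suffixed variable names are illustrative):
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
    X_features, y_labels, test_size=0.2, random_state=0, stratify=y_labels
)
# Both splits now carry the label ratio of y_labels (~0.96 / 0.04) exactly.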
# Split X_train, y_train again into training and validation sets.
X_tr, X_val, y_tr, y_val = train_test_split(
    X_train, y_train, test_size=0.3, random_state=0
)
2. XGBoost Model Training and Hyperparameter Tuning
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score
# Set n_estimators to 500 and learning_rate to 0.05; random_state is fixed so every run gives identical predictions.
xgb_clf = XGBClassifier(n_estimators=500, learning_rate=0.05, random_state=156)
# Train with AUC as the evaluation metric and the early stopping parameter set to 100.
xgb_clf.fit(
    X_tr,
    y_tr,
    early_stopping_rounds=100,
    eval_metric="auc",
    eval_set=[(X_tr, y_tr), (X_val, y_val)],
)
xgb_roc_score = roc_auc_score(y_test, xgb_clf.predict_proba(X_test)[:, 1])
print("ROC AUC: {0:.4f}".format(xgb_roc_score))
# [0]    validation_0-auc:0.82179    validation_1-auc:0.80068
# [1]    validation_0-auc:0.82347    validation_1-auc:0.80523
# [2]    validation_0-auc:0.83178    validation_1-auc:0.81097
# ... (output truncated) ...
# [135]  validation_0-auc:0.90310    validation_1-auc:0.83363   <- best validation AUC
# ... (output truncated) ...
# [234]  validation_0-auc:0.91414    validation_1-auc:0.83307
# [235]  validation_0-auc:0.91415    validation_1-auc:0.83305   <- no improvement for 100 rounds, training stops
# ROC AUC: 0.8429
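- Note: passing early_stopping_rounds and eval_metric to fit() follows the older XGBoost API; since XGBoost 1.6 this is deprecated, and in 2.0+ they are set on the estimator instead. A hedged sketch of the newer style (same settings as above):
xgb_clf = XGBClassifier(
    n_estimators=500,
    learning_rate=0.05,
    random_state=156,
    eval_metric="auc",  # moved from fit() to the constructor
    early_stopping_rounds=100,  # likewise moved to the constructor
)
xgb_clf.fit(X_tr, y_tr, eval_set=[(X_tr, y_tr), (X_val, y_val)])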
- Reviewing the XGBoost hyperparameters to tune
- colsample_bytree: fraction of features randomly sampled when building each tree
- min_child_weight: minimum sum of instance weights required in a child node; default=1, and larger values make the model more reluctant to split
- Create the search space
from hyperopt import hp
# max_depth from 5 to 15 in steps of 1, min_child_weight from 1 to 6 in steps of 1,
# colsample_bytree searched uniformly between 0.5 and 0.95, learning_rate uniformly between 0.01 and 0.2.
xgb_search_space = {
    "max_depth": hp.quniform("max_depth", 5, 15, 1),
    "min_child_weight": hp.quniform("min_child_weight", 1, 6, 1),
    "colsample_bytree": hp.uniform("colsample_bytree", 0.5, 0.95),
    "learning_rate": hp.uniform("learning_rate", 0.01, 0.2),
}
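- hp.quniform draws values on a grid but returns them as floats, so integer parameters are cast with int() in the objective function below; to preview what the space produces, hyperopt's stochastic sampler can be used (a quick illustration, not in the original):
from hyperopt.pyll.stochastic import sample
print(sample(xgb_search_space))
# e.g. {'colsample_bytree': 0.71..., 'learning_rate': 0.05..., 'max_depth': 9.0, 'min_child_weight': 3.0}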
- cross_val_score does not support early stopping, so create a KFold object and run the K-fold loop manually
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
# Define the objective function.
# fmin() will later pass in search_space values; cross-validate an XGBClassifier and return -1 * mean roc_auc.
def objective_func(search_space):
    xgb_clf = XGBClassifier(
        n_estimators=100,
        max_depth=int(search_space["max_depth"]),
        min_child_weight=int(search_space["min_child_weight"]),
        colsample_bytree=search_space["colsample_bytree"],
        learning_rate=search_space["learning_rate"],
    )
    # List collecting the roc_auc scores from the 3 folds
    roc_auc_list = []
    # Apply 3-fold cross validation
    kf = KFold(n_splits=3)
    # Split X_train again into training and validation data
    for tr_index, val_index in kf.split(X_train):
        # Build training/validation sets from the indices returned by kf.split(X_train)
        X_tr, y_tr = X_train.iloc[tr_index], y_train.iloc[tr_index]
        X_val, y_val = X_train.iloc[val_index], y_train.iloc[val_index]
        # Train the XGBClassifier on the extracted data with early stopping set to 30 rounds
        xgb_clf.fit(
            X_tr,
            y_tr,
            early_stopping_rounds=30,
            eval_metric="auc",
            eval_set=[(X_tr, y_tr), (X_val, y_val)],
        )
        # Get predicted probabilities for class 1, compute roc_auc, and store it for averaging
        score = roc_auc_score(y_val, xgb_clf.predict_proba(X_val)[:, 1])
        roc_auc_list.append(score)
    # Return the mean of the 3 roc_auc values multiplied by -1,
    # since HyperOpt searches for inputs that minimize the objective.
    return -1 * np.mean(roc_auc_list)
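- One caveat (an optional variation, not what the original uses): with only about 4% positive labels, plain KFold can yield folds with uneven class ratios; StratifiedKFold is a drop-in replacement that preserves the ratio in every fold:
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=3)
# split() also takes the labels here; the loop body stays the same as in objective_func.
for tr_index, val_index in skf.split(X_train, y_train):
    pass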
- Find the inputs that minimize the objective function
from hyperopt import fmin, tpe, Trials
trials = Trials()
# Call fmin(). It iterates max_evals times and extracts the inputs that minimize the objective.
best = fmin(
    fn=objective_func,
    space=xgb_search_space,
    algo=tpe.suggest,
    trials=trials,
    rstate=np.random.default_rng(seed=30),
    max_evals=50,  # maximum number of evaluations
)
print("best:", best)
- Bayesian optimization results
# 100%|██████████| 50/50 [54:28<00:00, 65.36s/trial, best loss: -0.8375277185394956]
# best: {'colsample_bytree': 0.6511149462012106, 'learning_rate': 0.16991129737205532, 'max_depth': 5.0, 'min_child_weight': 4.0}
- Retraining with the tuned hyperparameters: note that hp.quniform returned floats (max_depth 5.0, min_child_weight 4.0), so they are cast back to int below
# Increase n_estimators to 500, then train and predict with the best hyperparameters found above.
xgb_clf = XGBClassifier(
    n_estimators=500,
    learning_rate=round(best["learning_rate"], 5),
    max_depth=int(best["max_depth"]),
    min_child_weight=int(best["min_child_weight"]),
    colsample_bytree=round(best["colsample_bytree"], 5),
)
# Train with AUC as the evaluation metric and early stopping set to 100 rounds.
xgb_clf.fit(
    X_tr,
    y_tr,
    early_stopping_rounds=100,
    eval_metric="auc",
    eval_set=[(X_tr, y_tr), (X_val, y_val)],
)
xgb_roc_score = roc_auc_score(y_test, xgb_clf.predict_proba(X_test)[:, 1])
print("ROC AUC: {0:.4f}".format(xgb_roc_score))
# [0]    validation_0-auc:0.73335    validation_1-auc:0.71651
# [1]    validation_0-auc:0.75752    validation_1-auc:0.72910
# [2]    validation_0-auc:0.81737    validation_1-auc:0.79861
# ... (output truncated) ...
# [22]   validation_0-auc:0.86593    validation_1-auc:0.83577   <- best validation AUC
# ... (output truncated) ...
# [120]  validation_0-auc:0.90618    validation_1-auc:0.83153
# [121]  validation_0-auc:0.90628    validation_1-auc:0.83157   <- early stopping triggered
# ROC AUC: 0.8443
- Visualize feature importance with plot_importance
from xgboost import plot_importance
import matplotlib.pyplot as plt
%matplotlib inline
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
plot_importance(xgb_clf, ax=ax, max_num_features=20, height=0.4)
3. LightGBM Training and Hyperparameter Tuning
LightGBM Hyperparameters
- num_leaves: maximum number of leaf nodes
- min_child_samples (min_data_in_leaf): minimum number of samples a leaf node must hold
- max_depth: maximum depth of the tree
- The basic tuning approach is to reduce model complexity by adjusting min_child_samples (min_data_in_leaf) and max_depth together, with num_leaves at the center (see the sketch after this list)
- subsample: fraction of the data sampled to keep growing trees from overfitting
- colsample_bytree: used to randomly sample the features (columns) needed to build a tree
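- As a rule of thumb, keeping num_leaves well below 2**max_depth makes num_leaves, not depth, the binding constraint on complexity. A conservative configuration in that spirit (values are illustrative, not from this walkthrough):
from lightgbm import LGBMClassifier

conservative_lgbm = LGBMClassifier(
    n_estimators=500,
    num_leaves=32,          # main driver of model complexity
    max_depth=8,            # 32 << 2**8, so depth is not the limiting factor
    min_child_samples=100,  # a larger minimum leaf size resists overfitting
    subsample=0.8,          # row sampling
    subsample_freq=1,       # required for subsample to take effect in LightGBM
    colsample_bytree=0.8,   # feature sampling
)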
LightGBM Hyperparameter Tuning Overview
- Trying to tune too many hyperparameters can actually get in the way of finding the optimum
- Keep the number of tuned hyperparameters at a reasonable level
LightGBM in Practice
- Without hyperparameter tuning
from lightgbm import LGBMClassifier
lgbm_clf = LGBMClassifier(n_estimators=500)
eval_set = [(X_tr, y_tr), (X_val, y_val)]
lgbm_clf.fit(
    X_tr, y_tr, early_stopping_rounds=100, eval_metric="auc", eval_set=eval_set
)
lgbm_roc_score = roc_auc_score(y_test, lgbm_clf.predict_proba(X_test)[:, 1])
print("ROC AUC: {0:.4f}".format(lgbm_roc_score))
# ROC AUC: 0.8384
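- Note: as with XGBoost, this fit() signature follows the older LightGBM API; LightGBM 4.x removed early_stopping_rounds from fit() in favor of a callback. A hedged sketch of the newer style (same settings as above):
import lightgbm as lgb

lgbm_clf = LGBMClassifier(n_estimators=500)
lgbm_clf.fit(
    X_tr,
    y_tr,
    eval_metric="auc",
    eval_set=[(X_tr, y_tr), (X_val, y_val)],
    callbacks=[lgb.early_stopping(stopping_rounds=100)],  # replaces early_stopping_rounds
)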
- Set up the search space: control overfitting centered on num_leaves
lgbm_search_space = {
    "num_leaves": hp.quniform("num_leaves", 32, 64, 1),
    "max_depth": hp.quniform("max_depth", 100, 160, 1),
    "min_child_samples": hp.quniform("min_child_samples", 60, 100, 1),
    "subsample": hp.uniform("subsample", 0.7, 1),
    "learning_rate": hp.uniform("learning_rate", 0.01, 0.2),
}
- Define the objective function
def objective_func(search_space):
    lgbm_clf = LGBMClassifier(
        n_estimators=100,
        num_leaves=int(search_space["num_leaves"]),
        max_depth=int(search_space["max_depth"]),
        min_child_samples=int(search_space["min_child_samples"]),
        subsample=search_space["subsample"],
        learning_rate=search_space["learning_rate"],
    )
    # List collecting the roc_auc scores from the 3 folds
    roc_auc_list = []
    # Apply 3-fold cross validation
    kf = KFold(n_splits=3)
    # Split X_train again into training and validation data
    for tr_index, val_index in kf.split(X_train):
        # Build training/validation sets from the indices returned by kf.split(X_train)
        X_tr, y_tr = X_train.iloc[tr_index], y_train.iloc[tr_index]
        X_val, y_val = X_train.iloc[val_index], y_train.iloc[val_index]
        # Train the LGBMClassifier on the extracted data with early stopping set to 30 rounds
        lgbm_clf.fit(
            X_tr,
            y_tr,
            early_stopping_rounds=30,
            eval_metric="auc",
            eval_set=[(X_tr, y_tr), (X_val, y_val)],
        )
        # Get predicted probabilities for class 1, compute roc_auc, and store it for averaging
        score = roc_auc_score(y_val, lgbm_clf.predict_proba(X_val)[:, 1])
        roc_auc_list.append(score)
    # Return the mean of the 3 roc_auc values multiplied by -1,
    # since HyperOpt searches for inputs that minimize the objective.
    return -1 * np.mean(roc_auc_list)
- Find the inputs that minimize the objective function
from hyperopt import fmin, tpe, Trials
trials = Trials()
# Call fmin(). It iterates max_evals times and extracts the inputs that minimize the objective.
best = fmin(
    fn=objective_func,
    space=lgbm_search_space,
    algo=tpe.suggest,
    max_evals=50,  # maximum number of evaluations
    trials=trials,
    rstate=np.random.default_rng(seed=30),
)
print("best:", best)
# 100%|██████████| 50/50 [05:27<00:00, 6.55s/trial, best loss: -0.8357657786434084]
# best: {'learning_rate': 0.08592271133758617, 'max_depth': 121.0, 'min_child_samples': 69.0, 'num_leaves': 41.0, 'subsample': 0.9148958093027029}
- Retrain with the best hyperparameters
lgbm_clf = LGBMClassifier(
    n_estimators=500,
    num_leaves=int(best["num_leaves"]),
    max_depth=int(best["max_depth"]),
    min_child_samples=int(best["min_child_samples"]),
    subsample=round(best["subsample"], 5),
    learning_rate=round(best["learning_rate"], 5),
)
# Train with AUC as the evaluation metric and early stopping set to 100 rounds.
lgbm_clf.fit(
    X_tr,
    y_tr,
    early_stopping_rounds=100,
    eval_metric="auc",
    eval_set=[(X_tr, y_tr), (X_val, y_val)],
)
lgbm_roc_score = roc_auc_score(y_test, lgbm_clf.predict_proba(X_test)[:, 1])
print("ROC AUC: {0:.4f}".format(lgbm_roc_score))
# ROC AUC: 0.8446