[Python] GridSearchCV 를 이용한 하이퍼파라미터 최적값 찾기
2023. 3. 19. 19:22
앞서 포스팅했던 머신러닝 모델들은 모두 하이퍼파라미터 라고 하는 모델의 성능과 특징을 결정하는 값이 존재한다.
모델마다 상이하기 때문에 각 라이브러리 공식문서를 통해 하이퍼파라미터를 확인해서 적용하면 된다.
우선 RandomForestClassifier 모델에 대해 최적의 하이퍼파라미터를 찾아보자.
sk-learn 에서 제공하는 GridSearchCV 함수를 통해 dict 형태로 parameter grid 를 입력해준다.
# GridSearchCV 이용하여 하이퍼파라미터 별 최적값 찾기
param_rfc = {"n_estimators" : [30,60,100,120,150],
"max_depth" : [1,3,5,7,9],
"min_samples_split" : [10,30,50,100,150,200],
"min_samples_leaf" : [10,30,50,100],
"random_state" : [42]
}
'''
n_jobs = -1로 설정 하시면 모든 코어사용이 가능. Default 1.
cv = 교차검증을 위한 fold 횟수.
refit : True 일 경우, 최적 하이퍼파라미터를 찾은 뒤 입력된 estimator 객체를
해당 하이퍼 파라미터로 재학습 시킨다. Default True.
'''
rfc = GridSearchCV(estimator = rfc,
param_grid = param_rfc,
scoring ='accuracy',
cv = StratifiedKFold(n_splits=3, shuffle = True, random_state=42),
refit=True, n_jobs=1, verbose=2)
rfc.fit(x_train, y_train)
print("="*60)
print('RFC 파라미터: ', gscv_rfc.best_params_)
print('RFC 예측 정확도: {:.4f}'.format(gscv_rfc.best_score_))
print("="*60)
여러 하이퍼파라미터에 대해 평가하고자 하는 value 를 입력하여 경우의 수 만큼 (교차검증 추가) 평가하여 accuracy 가 가장 높은 경우를 찾는 방식으로 생각하면 되겠다.
위와 같은 방식으로 찾은 뒤 조금 더 세세하게 (촘촘히) 수치를 조정하여 최적화를 하면 도움이 될 것 같다.
결과:
============================================================
RFC 파라미터: {'max_depth': 7, 'min_samples_leaf': 10,
'min_samples_split': 30, 'n_estimators': 100,
'random_state': 42}
RFC 예측 정확도: 0.8645
============================================================
Accuracy: 0.856
Precision: 0.874
Recall: 0.962
F1-score: 0.916
※ 주의 사항
refit=True와 함께 GridSearchCV를 사용하는 경우 재적합된 모델은 GridSearchCV 객체 자체의 직접적인 속성이 아니라 best_estimator_ 속성으로 GridSearchCV 객체 내부에 저장된다. 따라서 재구성된 RandomForestClassifier 모델의 feature_importances_ 속성에 액세스하려면 best_estimator_를 사용해야 한다.
feature_importances = grid_search.best_estimator_.feature_importances_
print("Feature importances:", feature_importances)
'Python' 카테고리의 다른 글
lightGBM (0) | 2023.07.02 |
---|---|
[Python] 2D IDW Interpolation (2차원 역거리 가중치 보간법) (1) | 2023.06.11 |
[Python] 데이터 전처리 - 스케일링 (0) | 2023.03.18 |
[Python] Random Forest 최적화 방법 (0) | 2023.03.13 |
머신 러닝 알고리즘 : 선형 데이터 (Regression) - 뉴럴네트워크(MLP) (0) | 2023.03.09 |