Model inspection

This step inspects the models trained in the preceding Training step.

Best Model

Note

Because of a bug in AutoGluon, the Best Model is derived twice to work around it. To follow the Best Model that is ultimately selected, read on from the code below.

```python
# Best model as reported right after training (before the workaround)
print("Best model (before): ", predictor.get_model_best())

# Workaround: models_to_keep='best' has a bug where an invalid
# predictor._trainer.model_best is used, so the best model name is passed explicitly.
# https://github.com/awslabs/autogluon/blob/0.3.1/tabular/src/autogluon/tabular/predictor/predictor.py#L1933
predictor.delete_models(models_to_keep=predictor.get_model_best(), dry_run=False)
# Delete auxiliary files that are not needed for prediction to reduce disk usage
predictor.save_space()

# Re-derive the best model from the models that remain after the deletion
best_model = predictor.get_model_best()
print("Best model (after delete models): ", best_model)

# ATML-114: Workaround for AttributeError: 'NoneType' object has no attribute 'name'
# on https://github.com/autogluon/autogluon/blob/0.3.1/tabular/src/autogluon/tabular/trainer/abstract_trainer.py#L1564
if predictor._trainer.model_best is None:
    predictor._trainer.model_best = best_model

# Where the model is stored. td_site, td_stage, td_account_id, td_user_id and
# model_name are assumed to be defined in earlier notebook cells.
model_bucket = f"ml-model-{td_site}-{td_stage}"
model_object_key = f"{td_account_id}/{td_user_id}/{model_name}"
```

This prints which of the candidate models is the Best Model. In this run it is the ensemble model WeightedEnsemble_L2.

Best model (before):  WeightedEnsemble_L2
WARNING: Deleting model LightGBMXT_BAG_L1/T0. All files under ag_models/models/LightGBMXT_BAG_L1/T0/ will be removed.
WARNING: Deleting model RandomForestGini_BAG_L1/T0. All files under ag_models/models/RandomForestGini_BAG_L1/T0/ will be removed.
... (omitted) ...
WARNING: Deleting model ExtraTreesEntr_BAG_L1/T0. All files under ag_models/models/ExtraTreesEntr_BAG_L1/T0/ will be removed.
Best model (after delete models):  WeightedEnsemble_L2
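The deletion log above is consistent with `delete_models` keeping the models passed in `models_to_keep` together with every model they depend on, and deleting the rest. The stdlib sketch below illustrates that dependency closure. It is not AutoGluon's implementation: `DEPENDENCIES`, `dependency_closure` and `models_to_delete` are hypothetical names, the ancestor list of WeightedEnsemble_L2 is copied from its `ancestors` column in the leaderboard below, and the three extra models are the ones named in the (partly omitted) deletion log.

```python
# Hypothetical sketch: which models are deleted when only the best model and
# everything it depends on are kept. Not AutoGluon's implementation.

# Ancestors of WeightedEnsemble_L2, as listed in the leaderboard.
BASE_MODELS = [
    "NeuralNetMXNet_BAG_L1/T0", "NeuralNetFastAI_BAG_L1/T0", "XGBoost_BAG_L1/T0",
    "CatBoost_BAG_L1/T5", "LightGBMLarge_BAG_L1", "LightGBM_BAG_L1/T0",
    "RandomForestEntr_BAG_L1/T0", "CatBoost_BAG_L1/T6", "CatBoost_BAG_L1/T0",
]
# Models that appear in the deletion log.
UNUSED_MODELS = ["LightGBMXT_BAG_L1/T0", "RandomForestGini_BAG_L1/T0", "ExtraTreesEntr_BAG_L1/T0"]

# Direct dependencies of each model; level-1 models depend on nothing.
DEPENDENCIES = {"WeightedEnsemble_L2": BASE_MODELS}
for name in BASE_MODELS + UNUSED_MODELS:
    DEPENDENCIES.setdefault(name, [])


def dependency_closure(models, dependencies):
    """Return the given models plus all of their transitive dependencies."""
    keep = set()
    stack = list(models)
    while stack:
        model = stack.pop()
        if model not in keep:
            keep.add(model)
            stack.extend(dependencies.get(model, []))
    return keep


def models_to_delete(best_model, dependencies):
    """Models that are neither the best model nor one of its dependencies."""
    keep = dependency_closure([best_model], dependencies)
    return sorted(m for m in dependencies if m not in keep)


print(models_to_delete("WeightedEnsemble_L2", DEPENDENCIES))
```

With this map, only the three models from the log are selected for deletion; the best model and its nine ancestors are kept.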
Fig. 24 Visualization of the Best Model (shown in orange).
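In the leaderboard below, WeightedEnsemble_L2 has `child_model_type` GreedyWeightedEnsembleModel and `child_hyperparameters` {'ensemble_size': 100}, which appears to correspond to Caruana-style greedy ensemble selection: in each round, add (with replacement) the base model that most improves the ensemble's validation score, then use the selection counts as weights. The sketch below is a generic stdlib version of that technique on made-up validation predictions. For simplicity it scores with mean squared error rather than the predictor's eval_metric, and `greedy_ensemble_selection` is a hypothetical name, not AutoGluon's API.

```python
def greedy_ensemble_selection(val_preds, y_val, ensemble_size):
    """Caruana-style ensemble selection with replacement (illustrative sketch).

    val_preds: dict mapping model name -> validation predictions.
    y_val: true validation targets.
    ensemble_size: number of greedy selection rounds.
    Returns a dict mapping each selected model to its weight (weights sum to 1).
    """
    def mse(pred):
        return sum((p - y) ** 2 for p, y in zip(pred, y_val)) / len(y_val)

    names = sorted(val_preds)
    running_sum = [0.0] * len(y_val)  # sum of the predictions picked so far
    counts = {name: 0 for name in names}

    for step in range(1, ensemble_size + 1):
        # Try adding each model once more; keep the one with the lowest error.
        best_name, best_err = None, float("inf")
        for name in names:
            averaged = [(s + p) / step for s, p in zip(running_sum, val_preds[name])]
            err = mse(averaged)
            if err < best_err:
                best_name, best_err = name, err
        counts[best_name] += 1
        running_sum = [s + p for s, p in zip(running_sum, val_preds[best_name])]

    return {name: c / ensemble_size for name, c in counts.items() if c > 0}


y_val = [1.0, 0.0, 1.0, 0.0]
val_preds = {
    "model_a": [0.8, 0.4, 0.8, 0.4],
    "model_b": [0.6, 0.0, 0.6, 0.0],
}
print(greedy_ensemble_selection(val_preds, y_val, ensemble_size=2))
# {'model_a': 0.5, 'model_b': 0.5}
```

In the first round model_b alone has the lower error; in the second round averaging in model_a beats adding model_b again, so each ends up with half the weight even though model_a is worse on its own.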

When time_limit is set long enough, multi-layer stack ensembling is performed. The best_model below is an example in which a three-layer (L3) ensemble was built.

Fig. 25 Example in which a three-layer ensemble was built and the WeightedEnsemble_L3 model was selected.
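The leaderboard below shows the stacker models with the hyperparameter 'use_orig_features': True and the weighted ensemble with 'use_orig_features': False. As we understand AutoGluon's stacking, a model at the next stack level receives the lower-level (bagged, out-of-fold) predictions as extra input columns, with or without the original features depending on that flag. The sketch below only illustrates how such next-level inputs can be assembled; `build_stack_inputs`, the column names and the toy numbers are hypothetical, and this is not AutoGluon's internal code.

```python
def build_stack_inputs(orig_rows, lower_level_preds, use_orig_features=True):
    """Assemble the input table for the next stack level (hypothetical helper).

    orig_rows: list of original feature rows (one list per sample).
    lower_level_preds: dict mapping a lower-level model name to its
        out-of-fold predictions, one value per sample.
    use_orig_features: if True, keep the original features next to the
        prediction columns; if False, use only the prediction columns.
    Returns (column_names, rows).
    """
    model_names = sorted(lower_level_preds)
    for name in model_names:
        if len(lower_level_preds[name]) != len(orig_rows):
            raise ValueError(f"{name}: expected {len(orig_rows)} predictions")

    n_orig = len(orig_rows[0]) if (use_orig_features and orig_rows) else 0
    columns = [f"x{i}" for i in range(n_orig)] + model_names

    rows = []
    for i, features in enumerate(orig_rows):
        pred_columns = [lower_level_preds[name][i] for name in model_names]
        base = list(features) if use_orig_features else []
        rows.append(base + pred_columns)
    return columns, rows


# Toy data: two samples, two original features, two level-1 models.
X = [[0.5, 3], [1.5, 7]]
oof = {"model_a": [0.2, 0.9], "model_b": [0.3, 0.8]}
print(build_stack_inputs(X, oof))                           # like a level-2 stacker
print(build_stack_inputs(X, oof, use_orig_features=False))  # like a weighted ensemble
```

Repeating this step once per level gives the L1, L2, L3 layering seen in model names such as CatBoost_BAG_L2/T1 and WeightedEnsemble_L3.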

leaderboard

This outputs the list of models considered during training. Depending on the option settings, this information can also be output as a leaderboard table.

# highlight_col is a styling helper assumed to be defined earlier in the notebook
leaderboard = predictor.leaderboard(extra_info=True, silent=True)
if len(leaderboard) > 0 and leaderboard['num_features'].max() > 100:
    leaderboard.drop('features', axis=1, inplace=True)  # avoid showing overly long feature lists
leaderboard.style.apply(highlight_col, axis=None)

#papermill_description=Store prediction models to s3

| | model | score_val | pred_time_val | fit_time | pred_time_val_marginal | fit_time_marginal | stack_level | can_infer | fit_order | num_features | num_models | num_models_w_ancestors | memory_size | memory_size_w_ancestors | memory_size_min | memory_size_min_w_ancestors | num_ancestors | num_descendants | model_type | child_model_type | hyperparameters | hyperparameters_fit | ag_args_fit | features | child_hyperparameters | child_hyperparameters_fit | child_ag_args_fit | ancestors | descendants |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | WeightedEnsemble_L2 | 0.847262 | 6.62134 | 2135.623987 | 0.001967 | 2.658805 | 2 | TRUE | 10 | 9 | 1 | 50 | 5360 | 33893194 | 5360 | 4224773 | 9 | 0 | WeightedEnsembleModel | GreedyWeightedEnsembleModel | {'use_orig_features': False, 'max_base_models': 25, 'max_base_models_per_type': 5, 'save_bag_folds': True} | {} | {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'drop_unique': False} | ['NeuralNetMXNet_BAG_L1/T0', 'NeuralNetFastAI_BAG_L1/T0', 'XGBoost_BAG_L1/T0', 'CatBoost_BAG_L1/T5', 'LightGBMLarge_BAG_L1', 'LightGBM_BAG_L1/T0', 'RandomForestEntr_BAG_L1/T0', 'CatBoost_BAG_L1/T6', 'CatBoost_BAG_L1/T0'] | {'ensemble_size': 100} | {'ensemble_size': 76} | {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'drop_unique': False} | ['NeuralNetMXNet_BAG_L1/T0', 'NeuralNetFastAI_BAG_L1/T0', 'XGBoost_BAG_L1/T0', 'CatBoost_BAG_L1/T5', 'LightGBMLarge_BAG_L1', 'LightGBM_BAG_L1/T0', 'RandomForestEntr_BAG_L1/T0', 'CatBoost_BAG_L1/T6', 'CatBoost_BAG_L1/T0'] | [] |
| 1 | CatBoost_BAG_L1/T5 | 0.846937 | 0.106822 | 13.246835 | 0.106822 | 13.246835 | 1 | TRUE | 4 | 19 | 6 | 6 | 2213963 | 2213963 | 446708 | 446708 | 0 | 1 | StackerEnsembleModel | CatBoostModel | {'use_orig_features': True, 'max_base_models': 25, 'max_base_models_per_type': 5, 'save_bag_folds': True} | {} | {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'drop_unique': False} | ['streamingtv', 'techsupport', 'seniorcitizen', 'streamingmovies', 'gender', 'paymentmethod', 'totalcharges', 'onlinebackup', 'internetservice', 'paperlessbilling', 'multiplelines', 'contract', 'phoneservice', 'partner', 'dependents', 'onlinesecurity', 'deviceprotection', 'monthlycharges', 'tenure'] | {'iterations': 10000, 'learning_rate': 0.07728803796449603, 'random_seed': 0, 'allow_writing_files': False, 'eval_metric': 'Logloss', 'depth': 6, 'l2_leaf_reg': 1.308698577584075} | {'iterations': 197} | {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'ignored_type_group_special': None, 'ignored_type_group_raw': ['object'], 'get_features_kwargs': None, 'get_features_kwargs_extra': None} | [] | ['WeightedEnsemble_L2'] |
| 2 | CatBoost_BAG_L1/T6 | 0.846671 | 0.111467 | 12.598098 | 0.111467 | 12.598098 | 1 | TRUE | 5 | 19 | 6 | 6 | 1983823 | 1983823 | 446708 | 446708 | 0 | 1 | StackerEnsembleModel | CatBoostModel | {'use_orig_features': True, 'max_base_models': 25, 'max_base_models_per_type': 5, 'save_bag_folds': True} | {} | {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'drop_unique': False} | ['streamingtv', 'techsupport', 'seniorcitizen', 'streamingmovies', 'gender', 'paymentmethod', 'totalcharges', 'onlinebackup', 'internetservice', 'paperlessbilling', 'multiplelines', 'contract', 'phoneservice', 'partner', 'dependents', 'onlinesecurity', 'deviceprotection', 'monthlycharges', 'tenure'] | {'iterations': 10000, 'learning_rate': 0.08077624862012926, 'random_seed': 0, 'allow_writing_files': False, 'eval_metric': 'Logloss', 'depth': 5, 'l2_leaf_reg': 2.0224538666853173} | {'iterations': 185} | {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'ignored_type_group_special': None, 'ignored_type_group_raw': ['object'], 'get_features_kwargs': None, 'get_features_kwargs_extra': None} | [] | ['WeightedEnsemble_L2'] |
| 3 | CatBoost_BAG_L1/T0 | 0.846469 | 0.112227 | 13.08717 | 0.112227 | 13.08717 | 1 | TRUE | 3 | 19 | 6 | 6 | 2102763 | 2102763 | 446708 | 446708 | 0 | 1 | StackerEnsembleModel | CatBoostModel | {'use_orig_features': True, 'max_base_models': 25, 'max_base_models_per_type': 5, 'save_bag_folds': True} | {} | {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'drop_unique': False} | ['streamingtv', 'techsupport', 'seniorcitizen', 'streamingmovies', 'gender', 'paymentmethod', 'totalcharges', 'onlinebackup', 'internetservice', 'paperlessbilling', 'multiplelines', 'contract', 'phoneservice', 'partner', 'dependents', 'onlinesecurity', 'deviceprotection', 'monthlycharges', 'tenure'] | {'iterations': 10000, 'learning_rate': 0.05, 'random_seed': 0, 'allow_writing_files': False, 'eval_metric': 'Logloss', 'depth': 6, 'l2_leaf_reg': 3.0} | {'iterations': 192} | {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'ignored_type_group_special': None, 'ignored_type_group_raw': ['object'], 'get_features_kwargs': None, 'get_features_kwargs_extra': None} | [] | ['WeightedEnsemble_L2'] |
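Two relationships can be read off this table. The model reported by `get_model_best()` (WeightedEnsemble_L2) has the highest score_val, and for level-1 models fit_time equals fit_time_marginal, while WeightedEnsemble_L2's fit_time (2135.623987) is far larger than its fit_time_marginal (2.658805), which fits the reading that fit_time also counts all ancestors. The sketch below encodes both, assuming selection is simply the maximum score_val among models that can infer (AutoGluon's tie-breaking is not reproduced). `LeaderboardRow`, `pick_best` and `end_to_end_fit_time` are hypothetical names; the score_val values are copied from the table above and the timing rows are toy data.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class LeaderboardRow:
    """One leaderboard row, reduced to the columns used below."""
    model: str
    score_val: float
    fit_time_marginal: float = 0.0
    can_infer: bool = True
    ancestors: List[str] = field(default_factory=list)


def pick_best(rows):
    """Model with the highest score_val among rows that can infer.

    Ties go to the first row in the list (a simplification).
    """
    candidates = [r for r in rows if r.can_infer]
    return max(candidates, key=lambda r: r.score_val).model


def end_to_end_fit_time(rows, model):
    """Fit time of `model` alone plus the fit time of each of its ancestors."""
    by_name = {r.model: r for r in rows}
    row = by_name[model]
    return row.fit_time_marginal + sum(by_name[a].fit_time_marginal for a in row.ancestors)


# score_val values copied from the leaderboard above.
LEADERBOARD = [
    LeaderboardRow("WeightedEnsemble_L2", 0.847262),
    LeaderboardRow("CatBoost_BAG_L1/T5", 0.846937),
    LeaderboardRow("CatBoost_BAG_L1/T6", 0.846671),
    LeaderboardRow("CatBoost_BAG_L1/T0", 0.846469),
]
print(pick_best(LEADERBOARD))  # WeightedEnsemble_L2

# Toy rows (hypothetical names and times) for the time bookkeeping.
TOY = [
    LeaderboardRow("base_a", 0.80, fit_time_marginal=10.0),
    LeaderboardRow("base_b", 0.81, fit_time_marginal=5.0),
    LeaderboardRow("ensemble", 0.83, fit_time_marginal=0.5, ancestors=["base_a", "base_b"]),
]
print(end_to_end_fit_time(TOY, "ensemble"))  # 15.5
```

The same bookkeeping would apply to pred_time_val versus pred_time_val_marginal, since an ensemble must run its base models before it can predict.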

leaderboard with extra metrics

Classification

Models are evaluated with the metric specified in eval_metric, but this step also outputs a table in which scores under several other metrics are added to the leaderboard.

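# Recompute leaderboard with extra metrics
# get_extra_metrics, flip_negative_scores and highlight_col are notebook helpers assumed to be defined earlier.
# If train_data is the data the models were fit on, score_test is measured on seen data and is likely optimistic.
print(f"predictor.eval_metric={predictor.eval_metric.name}")
extra_metrics = get_extra_metrics(predictor)
print(f"extra_metrics: {extra_metrics}")

leaderboard = predictor.leaderboard(train_data, extra_metrics=extra_metrics, silent=True)
flip_negative_scores(leaderboard, predictor)
leaderboard.style.apply(highlight_col, axis=None)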
predictor.eval_metric=roc_auc
extra_metrics: ['roc_auc', 'log_loss', 'balanced_accuracy', 'f1', 'accuracy']

| | model | score_test | roc_auc | log_loss | balanced_accuracy | f1 | accuracy | score_val | pred_time_test | pred_time_val | fit_time | pred_time_test_marginal | pred_time_val_marginal | fit_time_marginal | stack_level | can_infer | fit_order |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | LightGBMLarge_BAG_L1 | 0.969999 | 0.969999 | 0.272964 | 0.838886 | 0.791595 | 0.90142 | 0.828583 | 0.139757 | 0.090179 | 228.192318 | 0.139757 | 0.090179 | 228.192318 | 1 | TRUE | 20 |
| 1 | ExtraTreesGini_BAG_L1/T0 | 0.918672 | 0.918672 | 0.345563 | 0.760022 | 0.663194 | 0.842596 | 0.836334 | 0.134413 | 0.203219 | 0.645045 | 0.134413 | 0.203219 | 0.645045 | 1 | TRUE | 15 |
| 2 | ExtraTreesEntr_BAG_L1/T0 | 0.91664 | 0.91664 | 0.347596 | 0.76187 | 0.665516 | 0.842799 | 0.835924 | 0.119714 | 0.210507 | 0.656981 | 0.119714 | 0.210507 | 0.656981 | 1 | TRUE | 16 |
| 3 | XGBoost_BAG_L1/T0 | 0.903982 | 0.903982 | 0.366327 | 0.760819 | 0.661846 | 0.838742 | 0.831362 | 0.221817 | 0.234169 | 56.309953 | 0.221817 | 0.234169 | 56.309953 | 1 | TRUE | 18 |
| 4 | LightGBM_BAG_L1/T0 | 0.901289 | 0.901289 | 0.37698 | 0.740124 | 0.635987 | 0.839554 | 0.821932 | 0.100931 | 0.203394 | 42.258624 | 0.100931 | 0.203394 | 42.258624 | 1 | TRUE | 2 |
| 5 | RandomForestGini_BAG_L1/T0 | 0.888764 | 0.888764 | 0.375412 | 0.733232 | 0.62109 | 0.827992 | 0.841808 | 0.110949 | 0.18576 | 0.618246 | 0.110949 | 0.18576 | 0.618246 | 1 | TRUE | 3 |
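Because eval_metric is roc_auc here, score_test is identical to the roc_auc column; the other columns are standard binary-classification metrics computed from the same predictions. The stdlib sketch below shows what each extra column measures on a made-up set of labels and positive-class probabilities. It is only an illustration: the function names are hypothetical, the 0.5 decision threshold is an assumption, and the notebook's `get_extra_metrics` and AutoGluon's own scorers may differ in details such as probability clipping.

```python
import math


def roc_auc(y_true, proba):
    """Share of (positive, negative) pairs where the positive scores higher; ties count 1/2."""
    pos = [p for y, p in zip(y_true, proba) if y == 1]
    neg = [p for y, p in zip(y_true, proba) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def log_loss(y_true, proba, eps=1e-15):
    """Mean negative log-likelihood of the true labels (lower is better)."""
    total = 0.0
    for y, p in zip(y_true, proba):
        p = min(max(p, eps), 1 - eps)  # clip so log() never sees 0
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(y_true)


def classification_metrics(y_true, proba, threshold=0.5):
    """Extra metrics for 0/1 labels and positive-class probabilities.

    Degenerate inputs (e.g. no predicted positives) are not handled.
    """
    y_pred = [1 if p >= threshold else 0 for p in proba]
    pairs = list(zip(y_true, y_pred))
    tp, tn = pairs.count((1, 1)), pairs.count((0, 0))
    fp, fn = pairs.count((0, 1)), pairs.count((1, 0))
    recall = tp / (tp + fn)       # true positive rate
    specificity = tn / (tn + fp)  # true negative rate
    precision = tp / (tp + fp)
    return {
        "roc_auc": roc_auc(y_true, proba),
        "log_loss": log_loss(y_true, proba),
        "balanced_accuracy": (recall + specificity) / 2,
        "f1": 2 * precision * recall / (precision + recall),
        "accuracy": (tp + tn) / len(y_true),
    }


y_true = [1, 0, 1, 1, 0, 0]
proba = [0.9, 0.2, 0.4, 0.7, 0.6, 0.1]
print(classification_metrics(y_true, proba))
```

roc_auc and log_loss use the probabilities directly, while balanced_accuracy, f1 and accuracy depend on the thresholded labels, which is why a model can rank well on one group and less well on the other.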

Regression

Models are evaluated with the metric specified in eval_metric, but this step also outputs a table in which scores under several other metrics are added to the leaderboard.

Fig. 26 For regression, the extra_metrics values of the Best Model are visualized.

predictor.eval_metric=root_mean_squared_error
extra_metrics: ['mean_squared_error', 'root_mean_squared_error', mean_absolute_percentage_error, 'r2', 'mean_absolute_error']

| | model | score_test | mean_squared_error | root_mean_squared_error | mean_absolute_percentage_error | r2 | mean_absolute_error | score_val | pred_time_test | pred_time_val | fit_time | pred_time_test_marginal | pred_time_val_marginal | fit_time_marginal | stack_level | can_infer | fit_order |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | RandomForestMSE_BAG_L1/T0 | -0.552165 | 0.304887 | 0.552165 | 0.069127 | 0.818738 | 0.45023 | -0.668878 | 0.074119 | 0.112338 | 0.462881 | 0.074119 | 0.112338 | 0.462881 | 1 | TRUE | 3 |
| 1 | XGBoost_BAG_L2/T0 | -0.566566 | 0.320997 | 0.566566 | 0.070469 | 0.80916 | 0.453708 | -0.671653 | 6.546928 | 2.837787 | 396.600877 | 0.087601 | 0.026135 | 9.069155 | 2 | TRUE | 29 |
| 2 | RandomForestMSE_BAG_L2/T0 | -0.578908 | 0.335134 | 0.578908 | 0.071518 | 0.800755 | 0.464212 | -0.662229 | 6.533948 | 2.927716 | 388.252239 | 0.07462 | 0.116065 | 0.720517 | 2 | TRUE | 20 |
| 3 | WeightedEnsemble_L3 | -0.579636 | 0.335978 | 0.579636 | 0.072412 | 0.800254 | 0.467344 | -0.652238 | 8.427499 | 3.818912 | 500.734541 | 0.003424 | 0.00044 | 0.382673 | 3 | TRUE | 32 |
| 4 | CatBoost_BAG_L2/T1 | -0.581173 | 0.337762 | 0.581173 | 0.072397 | 0.799193 | 0.468304 | -0.657705 | 6.486438 | 2.836082 | 391.666523 | 0.027111 | 0.02443 | 4.134801 | 2 | TRUE | 22 |
| 5 | ExtraTreesMSE_BAG_L2/T0 | -0.582704 | 0.339545 | 0.582704 | 0.072091 | 0.798133 | 0.466932 | -0.657569 | 6.542868 | 2.935313 | 388.00333 | 0.083541 | 0.123661 | 0.471608 | 2 | TRUE | 27 |
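In this table score_test is the negative of root_mean_squared_error (for example -0.552165 against 0.552165), and score_val is negative as well: AutoGluon reports scores so that higher is always better, which means error metrics are negated. The extra-metric columns appear as positive errors, presumably because the notebook's `flip_negative_scores` helper reverses that sign (an assumption, since the helper is not shown). The stdlib sketch below computes the five regression metrics on made-up values and applies the sign convention; `regression_metrics` and `as_score` are hypothetical names.

```python
import math


def regression_metrics(y_true, y_pred):
    """Regression extra metrics (errors: lower is better; r2: higher is better)."""
    n = len(y_true)
    errors = [p - t for t, p in zip(y_true, y_pred)]
    sse = sum(e * e for e in errors)
    mean_y = sum(y_true) / n
    sst = sum((t - mean_y) ** 2 for t in y_true)
    mse = sse / n
    return {
        "mean_squared_error": mse,
        "root_mean_squared_error": math.sqrt(mse),
        # Assumes no true value is zero; there is no epsilon guard.
        "mean_absolute_percentage_error": sum(abs(e / t) for e, t in zip(errors, y_true)) / n,
        "r2": 1 - sse / sst,
        "mean_absolute_error": sum(abs(e) for e in errors) / n,
    }


def as_score(value, greater_is_better):
    """Higher-is-better convention: negate metrics where smaller is better."""
    return value if greater_is_better else -value


y_true = [2.0, 4.0, 6.0, 8.0]
y_pred = [2.5, 3.5, 6.0, 9.0]
metrics = regression_metrics(y_true, y_pred)
score = as_score(metrics["root_mean_squared_error"], greater_is_better=False)
print(metrics)
print(score)  # negative, like score_test when eval_metric is RMSE
```

The same relationship holds between columns in the table: mean_squared_error is the square of root_mean_squared_error (0.552165² ≈ 0.304887), so ranking models by either gives the same order.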