失敗したモデルの例#

2値分類の例#

telco_churn_train データにおいて、 target_column: gender と設定した場合の性別の2値分類を考えてみよう。実は、これで生成されたモデルはうまくいっていない。そのことをどこで確認できるかを解説していく。

WF の記述#

_export:
  ml:
    input_database: ml_datasets
    output_database: ml_results  

+gluon_train:
  ml_train>:
    notebook: gluon_train
    input_table: ${ml.input_database}.telco_churn_train
    target_column: gender
    model_name: gender_model
    time_limit: 10*60

+gluon_predict:
  ml_predict>:
    notebook: gluon_predict
    model_name: gender_model
    input_table: ${ml.input_database}.telco_churn_test
    output_table: ${ml.output_database}.telco_churn_predicted

gender の分布#

_images/3-11-1-1.png

Fig. 56 トレーニングデータの gender の分布。どちらか一方に偏りはなく、Oversampling は不要であることがわかる。#

Best Model#

_images/3-11-2-1.png

Fig. 57 best_model。eval_metric: roc_auc のため、この指標が1に近い方がよくフィットしたモデルで、0.5 に近いとうまくいっていないモデルとなる。今回は best_model であっても、0.5代の値になってしまっており、うまくいっていないことがうかがえる。#

Predicted Probabilities#

_images/3-11-3-1.png

Fig. 58 予測値(予測確率)の分布。一見綺麗な分布になっていそうだが、X軸に注目。すべての予測値が ±0.5 付近に集中している。予測確率における 0.5 という値は、確信を持ってどちらかに分類できなかったことを意味する値であるので、model がうまくいっていないことが判明する。#

Normarized Confusion Matrix#

_images/3-11-6-1.png

Fig. 59 混同行列。今回は “Male” が Positive サンプルとなっている。#

予測したラベルが Negative (偽)

予測したラベルが Positive (真)

真のラベルが Negative (偽)

TN (True Negative)

FP (False Positive)

真のラベルが Positive (真)

FN (False Negative)

TP (True Positive)

  • TP(0.94): 真のラベルが “Male” であるものを正しく “Male” と予測できた割合。非常に高い割合で正しく予測できていることになる。

  • FP(0.82): 真のラベルが “Female” であるものを誤って “Male” と予測してしまった割合。非常に高い割合で誤った予測をしていることになる。

  • TN(0.18): 真のラベルが “Female” であるものを正しく “Female” と予測できた割合。ほとんど正しく予測できていることになる。

  • FN(0.06): 真のラベルが “Male” であるものを誤って “Female” と予測してしまった割合。このケースはほとんど起きていないことがわかる。

この混同行列からわかることは、

  • “Male” だけの予測に関しては、精度高く予測できているように見える。

  • 一方 “Female” の予測に関してはほとんど間違ってしまっている。

このことから導かれる結論は、このモデルはほとんどのサンプルを “Male” と予測するだけのものになってしまっていることである。

Note

ほとんどのサンプルを常に一方のラベルで予測してしまうモデルは、特にトレーニングデータのラベルに偏りがあるときに生じる。大多数の方の label を常に予測するようなモデルが生成されてしまいがちで、それを回避するために Oversampling を行うのである。

ROC 曲線 と ROC-AUC#

_images/3-11-4-1.png

Fig. 60 ROC 曲線。(左下から右上に向かう)数直線に近いカーブとなってしまっており、故に AUC も 0.5 に近くなってしまっている。”Male” を “Male” と予測することの正答率が高くても、”Male” における ROC は高いとは限らないことに注意しよう。詳しくはROC 曲線と ROC-AUCで説明している。#

Precision-Recall 曲線 と PR_AUC#

_images/3-11-5-1.png

Fig. 61 Precision-Recall 曲線。(左上から右下に向かう)数直線に近いカーブとなってしまっており、故に AUC も 0.5 に近くなってしまっている。これがなぜダメかはPrecision-Recall 曲線 と PR_AUCで説明している。#

結論#

うまくいっていないモデルは、このようにさまざまなモデルの評価指標から判断することができる。gluon_train を実行したあとは、必ずこれらの指標をチェックする習慣をつけておくようにしよう。

多値分類の例#

online_retail_ltv_train データにおいて、 ignore_columns で重要な特徴量を除いた場合の cltv の数量予測(回帰)を考えてみよう。もちろん、これで生成されたモデルはうまくいっていない。

WF の記述#

_export:
  ml:
    input_database: ml_datasets
    output_database: ml_results  

+gluon_train:
  ml_train>:
    notebook: gluon_train
    input_table: ${ml.input_database}.online_retail_ltv_train
    target_column: cltv
    model_name: ltv_model
    ignore_columns: purchase_amount,frequency,recency, avg_basket_value, order_time_gap, avg_backet_size, cnt_returns, has_returned
    time_limit: 10*60    

+gluon_predict:
  ml_predict>:
    notebook: gluon_predict
    model_name: ltv_model
    input_table: ${ml.input_database}.online_retail_ltv_test
    output_table: ${ml.output_database}.online_retail_ltv_predicted

Best Model#

_images/3-11-7-1.png

Fig. 62 best_model。eval_metric: root_mean_squared_error (二乗平均平方根誤差、RMSEと呼ばれる)のため、この指標が0に近い方がよい。本来は必ず0以上の値であるが、leaderboard のために符号が反転されている。(leaderboard ではどんな eval_metrics が設定されても、score_val が大きい方が良いという共通の認識を持たせるために符号を調整している。)#

Distribution of Prediction Results#

_images/3-11-8-1.png

Fig. 63 予測値の分布。予測値の平均と、実際の観測値の平均に大きな乖離があることが判明した。#

Prediction Distribution#

グラフの解釈の仕方は Distributions の該当する項目で解説している。

_images/3-11-9-1.png

Fig. 64 先ほどの状況をさらに明確にするのが観測値vs予測値のプロットである。一見、数直線上にあるプロットが多いように見えるが、実はY軸とX軸のスケールが50倍も違っているのだ。実は、ほとんどのプロットは数直線上から大きく乖離しているのだ。#

Residual Distribution#

グラフの解釈の仕方は Distributions の該当する項目で解説している。

_images/3-11-10-1.png

Fig. 65 Residual Distribution でも、X軸のスケールが大きすぎる(すなわち一部の残差が大きすぎる)ため、正規分布とは程遠い分布となっている。#

Residuals Distribution#

グラフの解釈の仕方は Distributions の該当する項目で解説している。

_images/3-11-11-1.png

Fig. 66 残差のプロットの方でも、±2 に収まっているプロットが少ないことがわかる。#