初心者でも決定木

スポンサーリンク
Python




機械学習のひとつである決定木を使ってみましょう。

【超重要】はじめてのデータ収集方法

環境構築は、【2020年最新】Anacondaのインストール方法

tree = DecisionTreeClassifier()
tree.fit(X, y)

基本的にはこのコードで決定木による学習が完了します。

1行目で決定木を使うと指定しています。

2行目で学習しています。

Xは説明変数(例:客数、商品単価、天気、商品ラインナップ、営業時間)

yは目的変数(例:売上)

たった2行で機械学習ができるんですか

結果を正解率でみてみましょう。

print("Accuracy on training set: {:.3f}".format(tree.score(X, y)))

Accuracy on training set: 1.000

これは、100%の正解率ということですよね

すごい、さすが機械学習

機械学習が強力なのは確かですが、100%はあり得ません。

これは、学習に用いているデータで、結果を評価しているためです。

試験問題を使って、勉強しているようなものです(カンニングですね)

過学習、オーバーフィッティングといわれています。

これを避けるために、一般的には訓練用のデータとテスト用(結果の評価用)のデータを分ける必要があります。

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state = 0)

X(説明変数)をX_train(訓練用)とX_test(テスト用)に分けています。

from sklearn.tree import DecisionTreeClassifier
tree = DecisionTreeClassifier()
tree.fit(X_train, y_train)

この1行目は決定木を使うために必要なコードです。決定木の計算機を使えるようにするイメージですね。

print("Accuracy on training set: {:.3f}".format(tree.score(X_train, y_train)))   # 訓練セットの精度
print("Accuracy on test set: {:.3f}".format(tree.score(X_test, y_test)))         # テストセットの精度

Accuracy on training set: 1.000
Accuracy on test set: 0.974

【基礎】分類モデルの評価方法

こんどは、訓練に使用したデータは正解率100%、訓練に使用していないテストデータは97.4%でした。

やっぱりすごい正解率ですね。

決定木の特徴として、結果が分かりやすいことがあげられます。名前の通り樹上図のような図を描くことができます。

from pydotplus import graph_from_dot_data
from sklearn.tree import export_graphviz
from IPython.display import Image

dot_data = export_graphviz(tree,
                           filled = True,
                           rounded = True,
                           class_names = iris.target_names,
                           feature_names = iris.feature_names,
                           out_file = None)

# 決定木のプロットを出力
graph = graph_from_dot_data(dot_data)
graph.write_png('tree.jpg')  
Image(graph.create_png())


こんな図がでてきます。

英語で分かりにくいですが、ここで、このデータについて少し説明します。

これは、機械学習を学ぶときに一番最初に出会うデータである、アイリスという花の種類とガクの長さ、ガクの幅、花弁の長さ、花弁の幅のデータです。

目的変数

  • setosa
  • versicolor
  • virginica
    上の図の value の位置に対応しています。たとえば、左下の紫色の枠に、value =[0, 0, 3] とあります。これは、viginicaが3つあることということです。valueの下にclassがあり、花の種類に対応しています。

説明変数

  • ガクの長さ(sepal length)
  • ガクの幅(sepal width)
  • 花弁の長さ(petal length)
  • 花弁の幅(petal width)

では、図の説明に入ります。

一番上の枠の一番上に「petal length(cm)<=2.35」とあります。これは花弁の長さが2.35㎝以下かどうかを示しています。以下ならTrueの方向へ、そうでないならFalseの方へ進みます。それを分岐がなくなるまで繰り返します。

数値で分けているので、非常に理解しやすいですね。また、新たなデータで予測したいときに、判断しやすいです。

以下の図では、視覚的にも分かりやすくできます。

from dtreeviz.trees import dtreeviz

viz = dtreeviz(tree,
               X_train,
               y_train,
               target_name='variety',
               feature_names=iris.feature_names,
               class_names=iris. target_names.tolist(),
               orientation='LR'
               )
display(viz)

画質が悪いですが、さきほどの図と同じ内容です。

分岐点がグラフで、x軸にある▲が分岐ポイントです。

今回提示したコードだけでは、今回の結果は出せません、

from sklearn.datasets import load_iris
iris = load_iris()
import warnings
warnings.filterwarnings('ignore')
X = iris.data
y = iris.target

最初にこのコードを書く必要があります。

非常に分かりやすく、今後の判断材料になる機械学習の方法だと思います。



コメント

タイトルとURLをコピーしました