Rで機械学習:分類編

Rで機械学習による分類を行ういくつかの方法を説明します。

土井 翔平 (国立情報学研究所)
2019-05-27

Table of Contents


はじめに

機械学習の分野では(多項)ロジットのように、クラスを予測することを分類(classification)と呼びます。 一方で、最小二乗法のように連続値を予測することを回帰(regression)と呼びます。

今回は代表的な機械学習の分類手法である

  1. ロジスティック回帰
  2. 決定木
  3. ランダムフォレスト
  4. サポート・ベクター・マシン (SVM)
  5. ニューラルネット

について、Rで行う方法を紹介します。

必要なパッケージの読み込み

Rで機械学習を行う便利なパッケージとしてcaretmlrがあります(らしいです)。 今回はcaretを使ってみます。


library(tidyverse)

Registered S3 methods overwritten by 'ggplot2':
  method         from 
  [.quosures     rlang
  c.quosures     rlang
  print.quosures rlang

Registered S3 method overwritten by 'rvest':
  method            from
  read_xml.response xml2

── Attaching packages ──────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 1.2.1 ──

✔ ggplot2 3.1.1       ✔ purrr   0.3.2  
✔ tibble  2.1.1       ✔ dplyr   0.8.0.1
✔ tidyr   0.8.3       ✔ stringr 1.4.0  
✔ readr   1.3.1       ✔ forcats 0.4.0  

── Conflicts ─────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()

library(caret)

Loading required package: lattice

Attaching package: 'caret'

The following object is masked from 'package:purrr':

    lift

必要なデータの読み込み

例のごとく、東大・朝日共同調査の2014年衆院選・2016年参院選世論調査のデータを使います。


data <- read_csv("http://www.masaki.j.u-tokyo.ac.jp/utas/2014_2016UTASV20161004.csv", 
                 locale = locale(encoding = "shift-jis"), na = c("66", "99", "999"))

Parsed with column specification:
cols(
  .default = col_character(),
  ID = col_double(),
  PREFEC = col_double(),
  HRDIST = col_double(),
  W1Q1 = col_double(),
  W1Q2 = col_double(),
  W1Q3 = col_double(),
  W1Q4 = col_double(),
  W1Q5_1 = col_double(),
  W1Q5_2 = col_double(),
  W1Q5_3 = col_double(),
  W1Q6 = col_double(),
  W1Q7 = col_double(),
  W1Q8 = col_double(),
  W1Q9 = col_double(),
  W1Q10 = col_double(),
  W1Q11 = col_double(),
  W1Q12 = col_double(),
  W1Q13 = col_double(),
  W1Q14_1 = col_double(),
  W1Q14_2_1 = col_double()
  # ... with 56 more columns
)

See spec(...) for full column specifications.

Warning: 8 parsing failures.
 row    col expected actual                                                                file
1810 PREFEC a double     -- 'http://www.masaki.j.u-tokyo.ac.jp/utas/2014_2016UTASV20161004.csv'
1810 HRDIST a double     -- 'http://www.masaki.j.u-tokyo.ac.jp/utas/2014_2016UTASV20161004.csv'
1811 PREFEC a double     -- 'http://www.masaki.j.u-tokyo.ac.jp/utas/2014_2016UTASV20161004.csv'
1811 HRDIST a double     -- 'http://www.masaki.j.u-tokyo.ac.jp/utas/2014_2016UTASV20161004.csv'
1812 PREFEC a double     -- 'http://www.masaki.j.u-tokyo.ac.jp/utas/2014_2016UTASV20161004.csv'
.... ...... ........ ...... ...................................................................
See problems(...) for more details.

下ごしらえ

データセットの内、分析に使う変数を抜き出しておきます。 更に、数値ではなくカテゴリカルデータである

についてはas.factor()でカテゴリカルデータ(因子型)に変形しておきます。


data <- data %>% 
  select(vote = W1Q1,
         party = W1Q2,
         sex = W1F1,
         age = W1F2,
         educ = W1F3,
         job = W1F4,
         W1Q7, W1Q8, W1Q9, W1Q10, W1Q11, W1Q12, W1Q13, W1Q14_1,
         W1Q16_1, W1Q16_2, W1Q16_3, W1Q16_4, W1Q16_5, W1Q16_6, 
         W1Q16_7, W1Q16_8, W1Q16_9, W1Q16_10, W1Q16_11, 
         W1Q16_12, W1Q16_13, W1Q16_14, W1Q16_15, W1Q16_16, W1Q16_17,
         W1Q19_1)%>% 
  mutate(vote = vote - 1,
         vote = as.factor(vote),
         party = as.factor(party),
         sex = as.factor(sex),
         educ = as.factor(educ),
         job = as.factor(job))

さらに、投票するかどうかを予測するデータセットと


data_vote <- data %>% 
  select(-party) %>%
  rename(target = vote) %>% 
  drop_na()

投票先を予測するデータセットをそれぞれ分けておきます。


data_party <- data %>% 
  select(-vote) %>%
  rename(target = party) %>% 
  drop_na()

シード値の設定

いくつかの分析手法では乱数を用います。 その名の通り乱数は毎回違う値が出てくるので、分析結果も変わってきます。

そこで、乱数を発生させるときの基準となるシード値を設定することで、毎回同じ乱数が発生するようにします。


set.seed(334)

機械学習の分類*

余談までに、機械学習全般の代表的な手法を概説します。

ちなみに、分類や回帰のように目的変数という正解が存在して、それを予測することを教師付き(supervised)学習と呼びます。 計量経済学や統計的因果推論は予測が主目的ではないですが、応答変数があるという意味で教師付き学習でと手法を共有しています。

逆に正解が存在せず、機械にデータから見えない構造を抽出させることを教師なし(unsupervised)学習と言います。 よくあるのはk-meansや混合ガウス分布(Gaussian mixture)、潜在ディリクレ配分(latent dirichlet allocation)によるクラスタリングです。

強化学習(reinforcement learning)とは正解は無いものの、目的は存在し、その目的を実現するための行動を学習するものです。 例えば、AlphaGo Zeroでは実際の棋譜を正解として使うのではなく、PC同士に何十日も戦わせて勝利という目的に近づく行動を学習しました。

ロジスティック回帰

ロジスティック回帰の方法は以前、解説をしましたが、比較のためにもう一度やります。

caretではtrain()の中にformula、使用するデータ、分析手法を入力します。 まずは、年齢だけで予測してみます。


vote_logit <- train(
  target ~ age,
  data = data_vote,
  method = "glm",
  family = binomial()
)

その後、predict()に分析したモデルと予測したいデータセットを入れて予測を行います。 予測結果と実際の値(今回はtarget)との比較をconfusionMatrix()で行います。


confusionMatrix(predict(vote_logit, data_vote), data_vote$target)

Confusion Matrix and Statistics

          Reference
Prediction    0    1
         0    0    0
         1  389 1074
                                          
               Accuracy : 0.7341          
                 95% CI : (0.7107, 0.7566)
    No Information Rate : 0.7341          
    P-Value [Acc > NIR] : 0.5136          
                                          
                  Kappa : 0               
                                          
 Mcnemar's Test P-Value : <2e-16          
                                          
            Sensitivity : 0.0000          
            Specificity : 1.0000          
         Pos Pred Value :    NaN          
         Neg Pred Value : 0.7341          
             Prevalence : 0.2659          
         Detection Rate : 0.0000          
   Detection Prevalence : 0.0000          
      Balanced Accuracy : 0.5000          
                                          
       'Positive' Class : 0               
                                          

いろいろ出てきますが、Accuracyを見ると72%ほどの予測精度があることが分かります。

多項ロジットもほぼ同様に行います。


party_logit <- train(
  target ~ age,
  data = data_party,
  method = "multinom",
  trace = FALSE
)

confusionMatrix(predict(party_logit, data_party), data_party$target)

Confusion Matrix and Statistics

          Reference
Prediction   1   2   3   4   5   6   7   8   9  90
        1  363 204 169 131  21 125  17  16   5  17
        2    0   0   0   0   0   0   0   0   0   0
        3    0   0   0   0   0   0   0   0   0   0
        4    0   0   0   0   0   0   0   0   0   0
        5    0   0   0   0   0   0   0   0   0   0
        6    0   0   0   0   0   0   0   0   0   0
        7    0   0   0   0   0   0   0   0   0   0
        8    0   0   0   0   0   0   0   0   0   0
        9    0   0   0   0   0   0   0   0   0   0
        90   0   0   0   0   0   0   0   0   0   0

Overall Statistics
                                          
               Accuracy : 0.3399          
                 95% CI : (0.3115, 0.3692)
    No Information Rate : 0.3399          
    P-Value [Acc > NIR] : 0.5115          
                                          
                  Kappa : 0               
                                          
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
Sensitivity            1.0000    0.000   0.0000   0.0000  0.00000    0.000  0.00000  0.00000
Specificity            0.0000    1.000   1.0000   1.0000  1.00000    1.000  1.00000  1.00000
Pos Pred Value         0.3399      NaN      NaN      NaN      NaN      NaN      NaN      NaN
Neg Pred Value            NaN    0.809   0.8418   0.8773  0.98034    0.883  0.98408  0.98502
Prevalence             0.3399    0.191   0.1582   0.1227  0.01966    0.117  0.01592  0.01498
Detection Rate         0.3399    0.000   0.0000   0.0000  0.00000    0.000  0.00000  0.00000
Detection Prevalence   1.0000    0.000   0.0000   0.0000  0.00000    0.000  0.00000  0.00000
Balanced Accuracy      0.5000    0.500   0.5000   0.5000  0.50000    0.500  0.50000  0.50000
                     Class: 9 Class: 90
Sensitivity          0.000000   0.00000
Specificity          1.000000   1.00000
Pos Pred Value            NaN       NaN
Neg Pred Value       0.995318   0.98408
Prevalence           0.004682   0.01592
Detection Rate       0.000000   0.00000
Detection Prevalence 0.000000   0.00000
Balanced Accuracy    0.500000   0.50000

やはり、前回とほぼ同様の32%の予測精度があります。

予測精度のからくり

ところで、前回の記事で、これには「からくり」があると言いましたが、それは各分析結果の最初の表を見ると分かります。

表では縦軸に予測結果、横軸に実際の値があります。 それぞれのセルの中には該当するサンプルサイズが表示されています。 つまり、対角線上にある数が正解数を意味しています。

しかし、投票するかどうかでは全て1(=投票に行く)と予測しており、投票先でも全て1(=自民党に投票する)と予測しています。 そして、実際に投票に行った人が7割、自民党に投票した人が3割いるので、先程の予測精度になったということでした。

つまり、単純に正答率を予測精度として見る場合、その下限は0%ではないということです。

特徴量を増やす

予測精度を高めるシンプルな方法は特徴量を増やすことです。 年齢に加えて性別、学歴、職業、現状や政策に関する意見(問7から14,16(1)から(17)に対する答え)を入れて分析してみます。


vote_logit <- train(
  target ~ .,
  data = data_vote,
  method = "glm",
  family = binomial()
)

confusionMatrix(predict(vote_logit, data_vote), data_vote$target)

Confusion Matrix and Statistics

          Reference
Prediction    0    1
         0  107   66
         1  282 1008
                                          
               Accuracy : 0.7621          
                 95% CI : (0.7395, 0.7837)
    No Information Rate : 0.7341          
    P-Value [Acc > NIR] : 0.007762        
                                          
                  Kappa : 0.2596          
                                          
 Mcnemar's Test P-Value : < 2.2e-16       
                                          
            Sensitivity : 0.27506         
            Specificity : 0.93855         
         Pos Pred Value : 0.61850         
         Neg Pred Value : 0.78140         
             Prevalence : 0.26589         
         Detection Rate : 0.07314         
   Detection Prevalence : 0.11825         
      Balanced Accuracy : 0.60681         
                                          
       'Positive' Class : 0               
                                          

party_logit <- train(
  target ~ .,
  data = data_party,
  method = "multinom",
  trace = FALSE
)

confusionMatrix(predict(party_logit, data_party), data_party$target)

Confusion Matrix and Statistics

          Reference
Prediction   1   2   3   4   5   6   7   8   9  90
        1  309  42  55  79  10  25   3   3   2   8
        2   14  95  39   8   1  40  10   9   1   1
        3   21  24  50   8   0  12   3   0   0   1
        4   13  11   9  30   0   2   0   0   0   1
        5    0   0   0   0   9   0   0   0   0   0
        6    6  30  16   6   1  45   1   3   0   3
        7    0   2   0   0   0   0   0   0   0   0
        8    0   0   0   0   0   0   0   1   0   0
        9    0   0   0   0   0   0   0   0   2   0
        90   0   0   0   0   0   1   0   0   0   3

Overall Statistics
                                          
               Accuracy : 0.5094          
                 95% CI : (0.4789, 0.5398)
    No Information Rate : 0.3399          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.3482          
                                          
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7  Class: 8
Sensitivity            0.8512  0.46569  0.29586  0.22901 0.428571  0.36000 0.000000 0.0625000
Specificity            0.6780  0.85764  0.92325  0.96158 1.000000  0.93001 0.998097 1.0000000
Pos Pred Value         0.5765  0.43578  0.42017  0.45455 1.000000  0.40541 0.000000 1.0000000
Neg Pred Value         0.8985  0.87176  0.87460  0.89920 0.988669  0.91641 0.984053 0.9859419
Prevalence             0.3399  0.19101  0.15824  0.12266 0.019663  0.11704 0.015918 0.0149813
Detection Rate         0.2893  0.08895  0.04682  0.02809 0.008427  0.04213 0.000000 0.0009363
Detection Prevalence   0.5019  0.20412  0.11142  0.06180 0.008427  0.10393 0.001873 0.0009363
Balanced Accuracy      0.7646  0.66166  0.60955  0.59529 0.714286  0.64501 0.499049 0.5312500
                     Class: 9 Class: 90
Sensitivity          0.400000  0.176471
Specificity          1.000000  0.999049
Pos Pred Value       1.000000  0.750000
Neg Pred Value       0.997186  0.986842
Prevalence           0.004682  0.015918
Detection Rate       0.001873  0.002809
Detection Prevalence 0.001873  0.003745
Balanced Accuracy    0.700000  0.587760

投票行動については若干の、投票先については大幅な改善が見られました。

決定木

もう一つのアプローチは分析手法を変えることです。 caretの便利な点は様々な分類器をまとめて利用することができる点です。

まずは、シンプルな分類器である決定木(decision tree)について見てみます。 決定木とは以下の図(Wikipediaの決定木のページより)のような決定木を作成します。 それぞれのノードでは変数の値によって分岐し、効率的に目的変数を分類していくことが目的です。

決定木
決定木

caretでは以下のように行います。


vote_tree <- train(
  target ~ .,
  data = data_vote,
  method = "rpart"
)

confusionMatrix(predict(vote_tree, data_vote), data_vote$target)

Confusion Matrix and Statistics

          Reference
Prediction    0    1
         0    0    0
         1  389 1074
                                          
               Accuracy : 0.7341          
                 95% CI : (0.7107, 0.7566)
    No Information Rate : 0.7341          
    P-Value [Acc > NIR] : 0.5136          
                                          
                  Kappa : 0               
                                          
 Mcnemar's Test P-Value : <2e-16          
                                          
            Sensitivity : 0.0000          
            Specificity : 1.0000          
         Pos Pred Value :    NaN          
         Neg Pred Value : 0.7341          
             Prevalence : 0.2659          
         Detection Rate : 0.0000          
   Detection Prevalence : 0.0000          
      Balanced Accuracy : 0.5000          
                                          
       'Positive' Class : 0               
                                          

party_tree <- train(
  target ~ .,
  data = data_party,
  method = "rpart"
)

confusionMatrix(predict(party_tree, data_party), data_party$target)

Confusion Matrix and Statistics

          Reference
Prediction   1   2   3   4   5   6   7   8   9  90
        1  313  65  98 100  19  36   6   2   5   8
        2   50 139  71  31   2  89  11  14   0   9
        3    0   0   0   0   0   0   0   0   0   0
        4    0   0   0   0   0   0   0   0   0   0
        5    0   0   0   0   0   0   0   0   0   0
        6    0   0   0   0   0   0   0   0   0   0
        7    0   0   0   0   0   0   0   0   0   0
        8    0   0   0   0   0   0   0   0   0   0
        9    0   0   0   0   0   0   0   0   0   0
        90   0   0   0   0   0   0   0   0   0   0

Overall Statistics
                                          
               Accuracy : 0.4232          
                 95% CI : (0.3934, 0.4535)
    No Information Rate : 0.3399          
    P-Value [Acc > NIR] : 9.235e-09       
                                          
                  Kappa : 0.1968          
                                          
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
Sensitivity            0.8623   0.6814   0.0000   0.0000  0.00000    0.000  0.00000  0.00000
Specificity            0.5191   0.6794   1.0000   1.0000  1.00000    1.000  1.00000  1.00000
Pos Pred Value         0.4801   0.3341      NaN      NaN      NaN      NaN      NaN      NaN
Neg Pred Value         0.8798   0.9003   0.8418   0.8773  0.98034    0.883  0.98408  0.98502
Prevalence             0.3399   0.1910   0.1582   0.1227  0.01966    0.117  0.01592  0.01498
Detection Rate         0.2931   0.1301   0.0000   0.0000  0.00000    0.000  0.00000  0.00000
Detection Prevalence   0.6105   0.3895   0.0000   0.0000  0.00000    0.000  0.00000  0.00000
Balanced Accuracy      0.6907   0.6804   0.5000   0.5000  0.50000    0.500  0.50000  0.50000
                     Class: 9 Class: 90
Sensitivity          0.000000   0.00000
Specificity          1.000000   1.00000
Pos Pred Value            NaN       NaN
Neg Pred Value       0.995318   0.98408
Prevalence           0.004682   0.01592
Detection Rate       0.000000   0.00000
Detection Prevalence 0.000000   0.00000
Balanced Accuracy    0.500000   0.50000

どちらもロジットよりも精度は高くないようです。

ランダムフォレスト

ランダムフォレストとは決定木を複数作り(つまり、森を作り)、その多数決で最終的な予測を行う分類器です。


vote_rf <- train(
  target ~ .,
  data = data_vote,
  method = "rf"
)

confusionMatrix(predict(vote_rf, data_vote), data_vote$target)

Confusion Matrix and Statistics

          Reference
Prediction    0    1
         0  385    0
         1    4 1074
                                         
               Accuracy : 0.9973         
                 95% CI : (0.993, 0.9993)
    No Information Rate : 0.7341         
    P-Value [Acc > NIR] : <2e-16         
                                         
                  Kappa : 0.993          
                                         
 Mcnemar's Test P-Value : 0.1336         
                                         
            Sensitivity : 0.9897         
            Specificity : 1.0000         
         Pos Pred Value : 1.0000         
         Neg Pred Value : 0.9963         
             Prevalence : 0.2659         
         Detection Rate : 0.2632         
   Detection Prevalence : 0.2632         
      Balanced Accuracy : 0.9949         
                                         
       'Positive' Class : 0              
                                         

party_rf <- train(
  target ~ .,
  data = data_party,
  method = "rf"
)

confusionMatrix(predict(party_rf, data_party), data_party$target)

Confusion Matrix and Statistics

          Reference
Prediction   1   2   3   4   5   6   7   8   9  90
        1  363   0   0   0   0   0   0   0   0   0
        2    0 204   0   0   0   0   0   0   0   0
        3    0   0 169   0   0   0   0   0   0   0
        4    0   0   0 131   0   0   0   0   0   0
        5    0   0   0   0  21   0   0   0   0   0
        6    0   0   0   0   0 125   0   1   0   0
        7    0   0   0   0   0   0  17   0   0   0
        8    0   0   0   0   0   0   0  15   0   0
        9    0   0   0   0   0   0   0   0   5   0
        90   0   0   0   0   0   0   0   0   0  17

Overall Statistics
                                     
               Accuracy : 0.9991     
                 95% CI : (0.9948, 1)
    No Information Rate : 0.3399     
    P-Value [Acc > NIR] : < 2.2e-16  
                                     
                  Kappa : 0.9988     
                                     
 Mcnemar's Test P-Value : NA         

Statistics by Class:

                     Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
Sensitivity            1.0000    1.000   1.0000   1.0000  1.00000   1.0000  1.00000  0.93750
Specificity            1.0000    1.000   1.0000   1.0000  1.00000   0.9989  1.00000  1.00000
Pos Pred Value         1.0000    1.000   1.0000   1.0000  1.00000   0.9921  1.00000  1.00000
Neg Pred Value         1.0000    1.000   1.0000   1.0000  1.00000   1.0000  1.00000  0.99905
Prevalence             0.3399    0.191   0.1582   0.1227  0.01966   0.1170  0.01592  0.01498
Detection Rate         0.3399    0.191   0.1582   0.1227  0.01966   0.1170  0.01592  0.01404
Detection Prevalence   0.3399    0.191   0.1582   0.1227  0.01966   0.1180  0.01592  0.01404
Balanced Accuracy      1.0000    1.000   1.0000   1.0000  1.00000   0.9995  1.00000  0.96875
                     Class: 9 Class: 90
Sensitivity          1.000000   1.00000
Specificity          1.000000   1.00000
Pos Pred Value       1.000000   1.00000
Neg Pred Value       1.000000   1.00000
Prevalence           0.004682   0.01592
Detection Rate       0.004682   0.01592
Detection Prevalence 0.004682   0.01592
Balanced Accuracy    1.000000   1.00000

無事、どちらも正答率がほぼ100%を実現することができました。 決定木を何本も生やしていくので、並列化を行わないとそこそこ時間がかかりますが、優秀な決定木のようです。

残念ながら、ここにもからくりはあるのですが、ひとまず今回は様々な分類器を紹介することを目的に次に進みます。

SVM

SVMも分類機の一つです。 再びWikipediaのSVMのページからの引用ですが、サポートベクターマシンとは右図のようにグループの間の境界線を見つける手法になります。

SVM
SVM

これだけでは至ってシンプルなのですが、SVMが有能である所以は、カーネルトリックを使うことで、左図のように曲がった境界線も見つけることができる点にあります。

境界線を曲げるにはカーネル関数を選択する必要がありますが、ここでは(おそらく)一般的に使われているradial kernelを使います。


vote_svm <- train(
  target ~ .,
  data = data_vote,
  method = "svmRadial",
  preProcess = c("center", "scale")
)

confusionMatrix(predict(vote_svm, data_vote), data_vote$target)

Confusion Matrix and Statistics

          Reference
Prediction    0    1
         0   14    0
         1  375 1074
                                          
               Accuracy : 0.7437          
                 95% CI : (0.7205, 0.7659)
    No Information Rate : 0.7341          
    P-Value [Acc > NIR] : 0.2127          
                                          
                  Kappa : 0.052           
                                          
 Mcnemar's Test P-Value : <2e-16          
                                          
            Sensitivity : 0.035990        
            Specificity : 1.000000        
         Pos Pred Value : 1.000000        
         Neg Pred Value : 0.741201        
             Prevalence : 0.265892        
         Detection Rate : 0.009569        
   Detection Prevalence : 0.009569        
      Balanced Accuracy : 0.517995        
                                          
       'Positive' Class : 0               
                                          

party_svm <- train(
  target ~ .,
  data = data_party,
  method = "svmRadial",
  preProcess = c("center", "scale")
)

confusionMatrix(predict(party_svm, data_party), data_party$target)

Confusion Matrix and Statistics

          Reference
Prediction   1   2   3   4   5   6   7   8   9  90
        1  349  78 111 114  19  38   4   4   5  12
        2   13 122  49  15   2  86  10  12   0   4
        3    1   4   9   2   0   1   3   0   0   1
        4    0   0   0   0   0   0   0   0   0   0
        5    0   0   0   0   0   0   0   0   0   0
        6    0   0   0   0   0   0   0   0   0   0
        7    0   0   0   0   0   0   0   0   0   0
        8    0   0   0   0   0   0   0   0   0   0
        9    0   0   0   0   0   0   0   0   0   0
        90   0   0   0   0   0   0   0   0   0   0

Overall Statistics
                                          
               Accuracy : 0.4494          
                 95% CI : (0.4193, 0.4798)
    No Information Rate : 0.3399          
    P-Value [Acc > NIR] : 8.263e-14       
                                          
                  Kappa : 0.2216          
                                          
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
Sensitivity            0.9614   0.5980 0.053254   0.0000  0.00000    0.000  0.00000  0.00000
Specificity            0.4539   0.7789 0.986652   1.0000  1.00000    1.000  1.00000  1.00000
Pos Pred Value         0.4755   0.3898 0.428571      NaN      NaN      NaN      NaN      NaN
Neg Pred Value         0.9581   0.8914 0.847182   0.8773  0.98034    0.883  0.98408  0.98502
Prevalence             0.3399   0.1910 0.158240   0.1227  0.01966    0.117  0.01592  0.01498
Detection Rate         0.3268   0.1142 0.008427   0.0000  0.00000    0.000  0.00000  0.00000
Detection Prevalence   0.6873   0.2931 0.019663   0.0000  0.00000    0.000  0.00000  0.00000
Balanced Accuracy      0.7077   0.6885 0.519953   0.5000  0.50000    0.500  0.50000  0.50000
                     Class: 9 Class: 90
Sensitivity          0.000000   0.00000
Specificity          1.000000   1.00000
Pos Pred Value            NaN       NaN
Neg Pred Value       0.995318   0.98408
Prevalence           0.004682   0.01592
Detection Rate       0.000000   0.00000
Detection Prevalence 0.000000   0.00000
Balanced Accuracy    0.500000   0.50000

残念ながら(?)、今回は決定木と同じくらいの性能しか出ませんでした。

ニューラルネット

最後に、ニューラルネットについて紹介します。 ニューラルネットの話になるとしばしば右図のようなグラフを見るのですが(Stack Overflowのとあるページより)、これが何かを理解するために左図のロジットに戻ってみます。

ニューラルネット
ニューラルネット

左図のお気持ちとしては特徴量inputをたくさん入れて(+1は切片のことです)、目的変数がとあるクラスに入る確率を求めているということです。 右図では、入力層Layer L1と出力層Layer L3の間に中間層Layer L2が入っています。 つまり、特徴量をそのまま分類器に突っ込むのではなく、中間層でゴニョゴニョしてから分類機に入れているのです。

ニューラルネットもcaretで簡単に実装できます。


vote_nnet <- train(
  target ~ .,
  data = data_vote,
  method = "nnet",
  preProcess = c("center", "scale"),
  trace = FALSE
)

confusionMatrix(predict(vote_nnet, data_vote), data_vote$target)

Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 230 127
         1 159 947
                                          
               Accuracy : 0.8045          
                 95% CI : (0.7832, 0.8245)
    No Information Rate : 0.7341          
    P-Value [Acc > NIR] : 1.938e-10       
                                          
                  Kappa : 0.4858          
                                          
 Mcnemar's Test P-Value : 0.06679         
                                          
            Sensitivity : 0.5913          
            Specificity : 0.8818          
         Pos Pred Value : 0.6443          
         Neg Pred Value : 0.8562          
             Prevalence : 0.2659          
         Detection Rate : 0.1572          
   Detection Prevalence : 0.2440          
      Balanced Accuracy : 0.7365          
                                          
       'Positive' Class : 0               
                                          

party_nnet <- train(
  target ~ .,
  data = data_party,
  method = "nnet",
  preProcess = c("center", "scale"),
  trace = FALSE
)

confusionMatrix(predict(party_nnet, data_party), data_party$target)

Confusion Matrix and Statistics

          Reference
Prediction   1   2   3   4   5   6   7   8   9  90
        1  328  61  84 110  19  25   3   2   5   8
        2   35 143  85  21   2 100  14  14   0   9
        3    0   0   0   0   0   0   0   0   0   0
        4    0   0   0   0   0   0   0   0   0   0
        5    0   0   0   0   0   0   0   0   0   0
        6    0   0   0   0   0   0   0   0   0   0
        7    0   0   0   0   0   0   0   0   0   0
        8    0   0   0   0   0   0   0   0   0   0
        9    0   0   0   0   0   0   0   0   0   0
        90   0   0   0   0   0   0   0   0   0   0

Overall Statistics
                                         
               Accuracy : 0.441          
                 95% CI : (0.411, 0.4714)
    No Information Rate : 0.3399         
    P-Value [Acc > NIR] : 4.765e-12      
                                         
                  Kappa : 0.2226         
                                         
 Mcnemar's Test P-Value : NA             

Statistics by Class:

                     Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6 Class: 7 Class: 8
Sensitivity            0.9036   0.7010   0.0000   0.0000  0.00000    0.000  0.00000  0.00000
Specificity            0.5504   0.6759   1.0000   1.0000  1.00000    1.000  1.00000  1.00000
Pos Pred Value         0.5085   0.3381      NaN      NaN      NaN      NaN      NaN      NaN
Neg Pred Value         0.9173   0.9054   0.8418   0.8773  0.98034    0.883  0.98408  0.98502
Prevalence             0.3399   0.1910   0.1582   0.1227  0.01966    0.117  0.01592  0.01498
Detection Rate         0.3071   0.1339   0.0000   0.0000  0.00000    0.000  0.00000  0.00000
Detection Prevalence   0.6039   0.3961   0.0000   0.0000  0.00000    0.000  0.00000  0.00000
Balanced Accuracy      0.7270   0.6885   0.5000   0.5000  0.50000    0.500  0.50000  0.50000
                     Class: 9 Class: 90
Sensitivity          0.000000   0.00000
Specificity          1.000000   1.00000
Pos Pred Value            NaN       NaN
Neg Pred Value       0.995318   0.98408
Prevalence           0.004682   0.01592
Detection Rate       0.000000   0.00000
Detection Prevalence 0.000000   0.00000
Balanced Accuracy    0.500000   0.50000

SVMと大体同じという感じでした。