Статьи

Распознаватель цифр Kaggle: попытка случайного леса Mahout

Ранее я писал о  подходе K-средних,  который использовал  Джен  и я, пытаясь решить проблему  с Digit Recognizer Kaggle,  и, остановившись с точностью около 80%, мы решили попробовать один из алгоритмов, предложенных в разделе  уроков,  —  случайный лес !

Сначала мы использовали  библиотеку случайных лесов clojure,  но изо всех сил пытались построить случайный лес из данных обучающего набора за разумное время, поэтому мы переключились на  версию Mahout,  основанную на  статье о случайных лесах Лео Бреймана .

Есть  действительно хороший пример, объясняющий, как ансамбли работают в блоге Factual,  который мы нашли весьма полезным, чтобы помочь нам понять, как должны работать случайные леса.

Одним из самых мощных методов машинного обучения, к которому мы обращаемся, является ансамбль. Методы ансамбля создают удивительно сильные модели из набора слабых моделей, называемых базовыми учениками, и обычно требуют гораздо меньших настроек по сравнению с такими моделями, как машины опорных векторов.

Большинство методов ансамбля используют деревья решений в качестве базовых учеников, и многие методы ансамбля, такие как случайные леса и Adaboost, являются специфическими для ансамблей деревьев.

Мы смогли адаптировать  BreimanExample,  включенный в раздел примеров репозитория Mahout, чтобы делать то, что мы хотели.

Для начала мы написали следующий код для построения случайного леса:

public class MahoutKaggleDigitRecognizer {
  public static void main(String[] args) throws Exception {
    String descriptor = "L N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N ";
    String[] trainDataValues = fileAsStringArray("data/train.csv");
 
    Data data = DataLoader.loadData(DataLoader.generateDataset(descriptor, false, trainDataValues), trainDataValues);
 
    int numberOfTrees = 100;
    DecisionForest forest = buildForest(numberOfTrees, data);
  }
 
  private static DecisionForest buildForest(int numberOfTrees, Data data) {
    int m = (int) Math.floor(Maths.log(2, data.getDataset().nbAttributes()) + 1);
 
    DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
    treeBuilder.setM(m);
 
    return new SequentialBuilder(RandomUtils.getRandom(), treeBuilder, data.clone()).build(numberOfTrees);
  }
 
  private static String[] fileAsStringArray(String file) throws Exception {
    ArrayList<String> list = new ArrayList<String>();
 
    DataInputStream in = new DataInputStream(new FileInputStream(file));
    BufferedReader br = new BufferedReader(new InputStreamReader(in));
 
    String strLine;
    br.readLine(); // discard top one (header)
    while ((strLine = br.readLine()) != null) {
      list.add(strLine);
    }
 
    in.close();
    return list.toArray(new String[list.size()]);
  }
}

Файл обучающих данных выглядит примерно так:

label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8...,pixel783
1,0,0,0,0,0,0,...,0
0,0,0,0,0,0,0,...,0

Таким образом, в этом случае метка находится в первом столбце, который представлен как  L  в дескрипторе, а следующие 784 столбца являются числовым значением пикселей в изображении (следовательно, 784  N в дескрипторе).

Мы говорим ему создать случайный лес, который содержит 100 деревьев, и, поскольку у нас есть конечное число категорий, запись может быть классифицирована, поскольку мы передаем  false  как 2-й аргумент (регрессия) DataLoader.generateDataSet .

Значение  m  определяет, сколько атрибутов (в данном случае значений пикселей) используется для построения каждого дерева, и, предположительно,  log 2 (number_of_attributes) + 1  является оптимальным значением для этого!

Затем мы написали следующий код для прогнозирования меток набора тестовых данных:

public class MahoutKaggleDigitRecognizer {
  public static void main(String[] args) throws Exception {
    ...
    String[] testDataValues = testFileAsStringArray("data/test.csv");
    Data test = DataLoader.loadData(data.getDataset(), testDataValues);
    Random rng = RandomUtils.getRandom();
 
    for (int i = 0; i < test.size(); i++) {
    Instance oneSample = test.get(i);
 
    double classify = forest.classify(test.getDataset(), rng, oneSample);
    int label = data.getDataset().valueOf(0, String.valueOf((int) classify));
 
    System.out.println("Label: " + label);
  }
 
  private static String[] testFileAsStringArray(String file) throws Exception {
    ArrayList<String> list = new ArrayList<String>();
 
    DataInputStream in = new DataInputStream(new FileInputStream(file));
    BufferedReader br = new BufferedReader(new InputStreamReader(in));
 
    String strLine;
    br.readLine(); // discard top one (header)
    while ((strLine = br.readLine()) != null) {
      list.add("-," + strLine);
    }
 
    in.close();
    return list.toArray(new String[list.size()]);
  }
}

Было несколько вещей, которые нам показались запутанными при разработке, как это сделать:

  1. Формат тестовых данных должен быть идентичен формату обучающих данных, состоящих из метки, за которой следуют 784 числовых значения. Очевидно, что с тестовыми данными у нас нет метки, поэтому Mahout исключает нас, чтобы передать ‘-‘, куда метка пошла бы, иначе она выдаст исключение, которое объясняет ‘-‘ в   строке list.add .
  2. Сначала мы думали, что значение, возвращаемое  атрибутом forest.classify,  является предсказанием, но на самом деле это индекс, который нам необходимо найти в наборе данных.

Когда мы запустили этот алгоритм для набора тестовых данных с 10 деревьями, мы получили точность 83,8%, с 50 деревьями мы получили 84,4%, с 100 деревьями мы получили 96,28%, а с 200 деревьями мы получили 96,33%, и именно здесь в настоящее время достиг пика.

Время, затрачиваемое на построение лесов по мере увеличения количества деревьев, также начинает становиться проблемой, поэтому наш следующий шаг — либо найти способ распараллелить создание леса, либо выполнить какое-то  извлечение объектов  для попытаться улучшить точность.

Код на GitHub ,  если вы заинтересованы в играть с ним или есть какие — либо предложения о том, как улучшить его.