C++0x で多クラス分類器を実装してみた

C++ の次期標準,C++0x (C++11) の標準案が ISO に全会一致で承認されて一ヶ月半ほど経つので C++0x でプログラムを書いてみることにした.折角なので,以前から実装しようと思っていたサポートクラスに基づく多クラス Passive Aggressive アルゴリズム

を実装してみる.
コンパイラには現時点で C++0x に最も対応していると期待できる GCC 4.7 (SVN 先端; 20110927) を利用.GCCC++0x の対応状況は以下を参照.

あまり時間をかけるつもりもなかったので,

をさらっと眺めて,ダウンロード可能な C++0x の Working Draft の最新版

GCC によるライブラリの実装を参照しつつ,C++0x の新機能をなるべく多く網羅するよう意識して(ただし新機能を使うことで効率が落ちないよう留意して)コードを書いた.理解が不十分なまま,やや無理やり新機能を使うようにしたので,不自然な使い方になっている箇所も多々あると思う(後で少しずつ直す予定).右辺値参照については,後で以下を読むことにしようかな.

// spa.cc: support-class passive aggressive; GNU GPL version 2 copyright@ny23
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <array>
#include <vector>
#include <string>
#include <memory>
#include <random>
#include <chrono>
#include <numeric>
#include <algorithm>
#include <unordered_map>

// type alias
typedef std::string label_t;
typedef std::pair <size_t, double>  l_t;               // loss
typedef std::vector <std::pair <size_t, double>> fv_t; // feature vector

struct ex_t final {
  size_t y;
  fv_t   fv;
  double l2 = 0;
  ex_t& operator= (const ex_t&) = delete;
  ex_t (const ex_t&)            = delete;
  ex_t (size_t y_, const fv_t & fv_) : y (y_), fv (fv_)
  { for (const auto &fn : fv) l2 += fn.second * fn.second; }
};

auto main (int argc, char ** argv) -> int {
  static_assert (__GNUC__ == 4 && __GNUC_MINOR__ >= 7, "tested with GCC 4.7");
  if (argc < 6) {
    std::fprintf (stderr, "Usage PA|PA1|PA2 %s train test c iter\n", argv[0]);
    std::exit (1);
  }
  // read command-line
  constexpr std::array <const char *, 3> algo_ss = { {"PA", "PA1", "PA2"} };
  enum struct algo_t : size_t { PA, PA1, PA2, };
  size_t found = 0;
  while (found < algo_ss.size () && std::strcmp (algo_ss[found], argv[1]) != 0)
    ++found;
  if (found == algo_ss.size ())
    { std::fprintf (stderr, "unknown optimizer: %s\n", argv[1]); std::exit (1); }
  const algo_t algo = static_cast <algo_t> (found);
  const char * train (argv[2]), * test (argv[3]);
  const double c    = std::strtod (argv[4], nullptr);
  const size_t iter = std::strtol (argv[5], nullptr, 10);
  //
  std::unordered_map <label_t, size_t>      l2i;  // label -> id
  std::vector <std::pair <label_t, size_t>> l2i_;
  // read data
  auto start = std::chrono::steady_clock::now ();
  std::vector <std::unique_ptr <const ex_t>> ex;
  size_t fmax = 0;
  size_t read = 0;
  char * line = nullptr;
  FILE * fp   = std::fopen (train, "r");
  std::fprintf (stderr, "read: ");
  for (fv_t fv; (line = fgetln (fp, &read)) != nullptr; fv.clear ()) {
    char * p (line), * const p_end (line + read - 1);
    while (p != p_end && ! std::isspace (*p)) ++p; *p = '\0';
    const auto &&itb = l2i.insert ({line, l2i.size ()}); // no emplace; gcc 4.7
    const size_t   y = itb.first->second;
    if (itb.second) l2i_.emplace_back (line, y);
    fv.clear ();
    while (p != p_end) {
      const size_t fi = std::strtol (++p, &p, 10); // *p -> ':'
      const double v  = std::strtod (++p, &p);     // *p -> ' ' || '\t' || p_end
      fv.emplace_back (fi, v);
    }
    fmax = std::max (fmax, fv.back ().first);
    ex.emplace_back (new ex_t {y, fv});
  }
  std::fclose (fp);
  ex.shrink_to_fit ();
  // shuffle
  std::shuffle (ex.begin (), ex.end (), std::mt19937 ());
  auto d = std::chrono::steady_clock::now () - start;
  std::fprintf (stderr, "%.3fs\n", std::chrono::duration <double> (d).count ());
  // training
  start = std::chrono::steady_clock::now ();
  std::fprintf (stderr, "train: ");
  const size_t nc = l2i.size (); // number of classes
  std::vector <size_t> cs (nc); std::iota (cs.begin (), cs.end (), 0);
  std::vector <double> ms (nc), ws (nc * (fmax + 1));
#ifdef USE_AVERAGING
  std::vector <double> wsa (nc * (fmax + 1));
#endif
  std::vector <l_t>    ls;
  size_t nex = 0;
  for (size_t i = 0; i < iter; ++i) {
    for (const auto &it : ex) {
      ls.clear ();
      for (const size_t j : cs)
        if (j != it->y) ls.emplace_back (j, 0.0);
      // compute margin
      std::fill (ms.begin (), ms.end (), 0.0);
      for (const auto &fn : it->fv)
        for (size_t j = 0; j < nc; ++j) // range-based for seems slow here
          ms[j] += ws[nc * fn.first + j] * fn.second;
      // compute loss
      for (l_t &lj : ls)
        lj.second = std::max (0.0, 1 - (ms[it->y] - ms[lj.first]));
      // examine support class
      std::sort (ls.begin (), ls.end (),
                 [] (const l_t& x, const l_t&y) { return x.second > y.second; });
      size_t k = 0;
      double l = 0; // summation of loss
      for (bool is_sc = true; k < ls.size (); l += ls[k].second, ++k) {
        switch (algo) {
          case algo_t::PA:
            is_sc &= l < (k + 1) * ls[k].second; break;
          case algo_t::PA1:
            is_sc &= l < std::min (k * ls[k].second + c * it->l2, (k + 1) * ls[k].second); break;
          case algo_t::PA2:
            is_sc &= l < ((k + 1) * it->l2 + k / (2 * c)) / (it->l2 + 1 / (2 * c)) * ls[k].second; break;
        }
        if (! is_sc) { ls.resize (k); break; }
      } // k = |S|; support class 0 -> k - 1
      // update weights
      double penalty = 0;
      switch (algo) {
        case algo_t::PA:
          penalty = l / (k + 1); break;
        case algo_t::PA1:
          penalty = std::max (l / k - c * it->l2 / k, l / (k + 1)); break;
        case algo_t::PA2:
          penalty = (it->l2 + 1 / (2 * c)) / ((k + 1) * it->l2 + (k / (2 * c))) * l; break;
      }
      for (const l_t &lj : ls) { // for each support class
        const double t = std::max (0.0, lj.second - penalty) / it->l2;
        for (const auto &fn : it->fv) {
          ws[nc * fn.first + it->y]     += t * fn.second;
          ws[nc * fn.first + lj.first]  -= t * fn.second;
#ifdef USE_AVERAGING
          wsa[nc * fn.first + it->y]    += nex * t * fn.second;
          wsa[nc * fn.first + lj.first] -= nex * t * fn.second;
#endif
        }
      }
      ++nex;
    }
    std::fprintf (stderr, ".");
  }
#ifdef USE_AVERAGING
  for (size_t fi = 0; fi < ws.size (); ++fi) ws[fi] -= wsa[fi] / (nex + 1);
#endif
  d = std::chrono::steady_clock::now () - start;
  std::fprintf (stderr, "%.3fs\n", std::chrono::duration <double> (d).count ());
  // instant testing
  std::fprintf (stderr, "test: ");
  start = std::chrono::steady_clock::now ();
  fp = std::fopen (test, "r");
  std::vector <std::vector <size_t> > res (nc);
  std::fill (res.begin (), res.end (), std::vector <size_t> (nc));
  size_t corr (0), total (0);
  while ((line = fgetln (fp, &read)) != nullptr) {
    char * p (line), * const p_end (line + read - 1);
    while (p != p_end && ! std::isspace (*p)) ++p; *p = '\0';
    const auto it = l2i.find (line);
    if (it == l2i.end ())
      { std::fprintf (stderr, "unseen label: %s\n", line); std::exit (1); }
    const size_t y = it->second;
    std::fill (ms.begin (), ms.end (), 0.0);
    while (p != p_end) {
      const size_t fi = std::strtol (++p, &p, 10);
      if (fi > fmax) break;
      const double v  = std::strtod (++p, &p);
      for (const size_t j : cs) ms[j] += ws[nc * fi + j] * v;
    }
    const size_t y_ = std::max_element (ms.cbegin (), ms.cend ()) - ms.cbegin ();
    ++res[y_][y];
    if (y_ == y) ++corr; ++total;
  }
  std::fclose (fp);
  d = std::chrono::steady_clock::now () - start;
  std::fprintf (stderr, "%.3fs\n", std::chrono::duration <double> (d).count ());
  // display rsults; dirty as hell
  std::sort (l2i_.begin (), l2i_.end ());
  std::fprintf (stderr, " class\t");
  for (auto i : l2i_) std::fprintf (stderr, " %7s", i.first.c_str ());
  std::fprintf (stderr, "         precision\n\n");
  for (auto i_ : l2i_) { // for each predicted labels
    std::fprintf (stderr, "%6s\t", i_.first.c_str ());
    const size_t y_ = i_.second;
    for (auto i : l2i_) std::fprintf (stderr, " %7ld", res[y_][i.second]);
    const size_t sum = std::accumulate (&res[y_][0], &res[y_][nc], 0);
    std::fprintf (stderr, " %9ld (%.3f)\n", sum, res[y_][y_] * 1.0 / sum);
  }
  std::fprintf (stderr, "\n\t");
  for (size_t i = 0; i < nc; ++i)
    for (size_t j = i + 1; j < nc; ++j) std::swap (res[i][j], res[j][i]);
  for (auto i : l2i_) {
    const size_t y = i.second;
    const size_t sum = std::accumulate (&res[y][0], &res[y][nc], 0L);
    std::fprintf (stderr, " %7ld", sum);
  }
  std::fprintf (stderr, " %9ld\n", total);
  std::fprintf (stderr, "recall\t ");
  for (auto i : l2i_) {
    const size_t y = i.second;
    const size_t sum = std::accumulate (&res[y][0], &res[y][nc], 0L);
    std::fprintf (stderr, " (%.3f)", res[y][y] * 1.0 / sum);
  }
  std::fprintf (stderr, "          (%.3f)\n", corr * 1.0 / total);
  return 0;
}

見ての通りの auto / range-based for 祭り.出力のところのコードがごちゃごちゃしているが,アルゴリズム本体の実装はとても簡単.Passive Aggressive を C++ で実装したときは(PA-I だけ実装して)ちょうど 100 行ぐらいだったので,そんなに変わりない感じかな.
実行してみる.まずは covtype データセット

# MacBook Air (Mid 2011); Intel Core i7 1.7 Ghz CPU, 4 GB Memory

> wget -O - http://www.csie.ntu.edu.tw/\~cjlin/libsvmtools/datasets/multiclass/covtype.scale01.bz2 2> /dev/null | bzip2 -dc | shuf | ruby -pe 'gsub(/\s+$/) {"\n"}' >! covtype         
> wc -l covtype      
581012 covtype
> tail -50000 covtype >! covtype.test && tail -100000 covtype | head -50000 >! covtype.dev && head -481012 covtype >! covtype.train

# support-class passive aggressive
> g++ -std=c++0x -Wall -O2 spa.cc -o spa
> run spa PA1 covtype.train covtype.test 0.01 20
read: 0.843s
train: ....................3.458s
test: 0.073s
 class	       1       2       3       4       5       6       7         precision

     1	   11717    3802       0       0       1       0     747     16267 (0.720)
     2	    6025   20220     407       0     786     491      14     27943 (0.724)
     3	       8     299    2697     172      23     833       1      4033 (0.669)
     4	       0       0      21      53       0       6       0        80 (0.662)
     5	       0       1       0       0       0       0       0         1 (0.000)
     6	       5      26      47       8       5     127       0       218 (0.583)
     7	     451      33       0       0       0       0     974      1458 (0.668)

	   18206   24381    3172     233     815    1457    1736     50000
recall	  (0.644) (0.829) (0.850) (0.227) (0.000) (0.087) (0.561)          (0.716)
elapsed (real): 4.709s; RSS=123.3M

# liblinear 1.8
> run train -c 5.0 -s 4 covtype.train m
............*......................*.............................................*.....................................................................*..................................................................*......**.*
optimization finished, #iter = 2212
Objective value = -1522866.266603
nSV = 769065
elapsed (real): 92.705s; RSS=159.8M

> run predict covtype.test m out 
Accuracy = 72.156% (36078/50000)
elapsed (real): 0.142s; RSS=0.5M

次に mnist データセット

# MacBook Air (Mid 2011); Intel Core i7 1.7 Ghz CPU, 4 GB Memory
> wget -O - http://www.csie.ntu.edu.tw/\~cjlin/libsvmtools/datasets/multiclass/mnist.scale.bz2 2> /dev/null | bzip2 -dc | shuf | ruby -pe 'gsub(/\s+$/) {"\n"}' >! mnist.train
> wget -O - http://www.csie.ntu.edu.tw/\~cjlin/libsvmtools/datasets/multiclass/mnist.scale.t.bz2 2> /dev/null | bzip2 -dc | shuf | ruby -pe 'gsub(/\s+$/) {"\n"}' >! mnist.test

# support-class passive aggressive
> run spa PA1 mnist.train mnist.test 0.0005 20
read: 1.170s
train: ....................2.323s
test: 0.186s
 class	       0       1       2       3       4       5       6       7       8       9         precision

     0	     962       0       7       1       2      10       9       3       4      10      1008 (0.954)
     1	       0    1117       7       0       2       2       3       5       7       6      1149 (0.972)
     2	       0       3     932      19       2       3       4      24       5       1       993 (0.939)
     3	       2       2      16     937       2      38       2       9      21      11      1040 (0.901)
     4	       0       0       9       1     921      13       9       8       9      38      1008 (0.914)
     5	       2       1       4      12       0     755      12       1      21       7       815 (0.926)
     6	      10       4      11       2      11      17     915       0      10       0       980 (0.934)
     7	       2       2       9       8       3      11       1     948       8      16      1008 (0.940)
     8	       2       6      34      25      13      39       3       7     885      16      1030 (0.859)
     9	       0       0       3       5      26       4       0      23       4     904       969 (0.933)

	     980    1135    1032    1010     982     892     958    1028     974    1009     10000
recall	  (0.982) (0.984) (0.903) (0.928) (0.938) (0.846) (0.955) (0.922) (0.909) (0.896)          (0.928)
elapsed (real): 3.792s; RSS=156.7M

# liblinear 1.8
> run train -c 0.05 -s 4 mnist.train m
..*...*.....*...*...*.*
optimization finished, #iter = 174
Objective value = -591.214076
nSV = 35455
elapsed (real): 5.439s; RSS=148.9M

> run predict mnist.test m out                                          
Accuracy = 92.96% (9296/10000)
elapsed (real): 0.334s; RSS=0.6M

covtype の方ではパラメタチューニングの関係で,liblinear の学習がかなり遅くなっているが,基本的には support-class PA の方が速いと言って良さそうだ.精度は同程度.クラス数が多い時には結果の表示がすごいことになるが,速度に関しては,十分実用的に使えるものになっている.
C++0x でプログラムを書いた印象としては.

  • 型推論 (auto), range-based for,初期化リストなどにより,実装コスト(とコード量)はかなり削減でき,プログラムの印象は大きく変わる.さらに,lambda 式(匿名関数)により,関数オブジェクト用のクラスを戻り書きする必要がなくなり,小さなプログラムであれば流れるように書くことが可能になる.auto による型推論は強力だが,何でもかんでも推論に任せるのはやや抵抗を感じる.typedef による alias を説明文字列として使った方が良い場面もあるように思う.また,匿名関数では関数名をコメント代わりに使えなくなるので,乱用すると可読性が下がりそう.Lisp (Scheme) でプログラムを覚えた身としては乱用したくなるのだけど.
  • 右辺値参照や placement insert (emplace*) により,高価なオブジェクトのコピーを回避し易くなり,実行効率が顕著に改善する.emplace* は今まで何故なかったのが不思議なところ.コンテナに insert する際にいちいち std::make_pair だの *::value_type だのしなくて済むのは素敵だ.前述の型推論とも相まって,型を明示的に書く機会は変数宣言部を除けば殆ど無くなりそう.
  • 新たに追加されたライブラリの中では,std::unordered_map(ハッシュ), std::mt19937(乱数), std::regex正規表現),std::chrono(時間計測),std::unique_ptr(メモリ管理; std::shared_ptr はオーバヘッドが結構あるようだ)は使う機会が多そうだ.ただし,正規表現は 最新版の GCC でもまともにサポートされていないので試していない.その他現時点で GCC でサポートされていない機能の中では,template typedef や委譲コンストラクタなどは使う機会が出てきそうだ.

使いやすくなることは間違いないのだけど,C++0x は変更点が大きく普及にはかなり時間がかかりそうな気がした.それまで現役でプログラムを書いているかどうか怪しいところだなぁ.
[追記] オンラインで訓練例をシャッフルする場合,訓練例へのポインタを保持したコンテナを std::shuffle すると訓練時間が大きく増加する.一方,訓練例自体を保持したコンテナを std::shuffle すると訓練時間の増加はある程度抑えられるが,メモリを多く消費する(コンテナの一要素辺りのサイズが増える関係で,倍々で拡張するときに無駄が出るため).偏りのない入力がくると仮定して,オンラインでシャッフルしないのが正解かな.