Mersenne Twister でランダムに行をシャッフル (3)

Mersenne Twister でランダムに行をシャッフル (2) - ny23の日記 ではシャッフルの厳密性をスルーしてしまったが,以下にように簡単に示せる.n 個の要素からなる集合を,m 個の部分集合 (各集合 i の要素は k_i 個とする; 一時ファイルに保存) に分けて,部分集合をランダムに選んでその中からランダムに一要素を選ぶとすると,部分集合の選び方が \frac{n!}{k_{1}!k_{2}!\cdots k_{m}!}通り,各集合における要素の選び方が k_{i}! 通りなので,全部で n! 通りになって,これらが全て均等な確率で生じるので正しくシャッフルできることが分かる.というか,そもそも部分集合の作り方によらないので,最初に行を一時ファイルにランダムに保存する必要すらないのか(仮定あり.追記参照).

// mt-shuf__.cc
#include <unistd.h>
#include <sys/resource.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <algorithm>

// random number generator
#ifdef USE_MT
#include <tr1/random>
struct rand_ {
  std::tr1::variate_generator <std::tr1::mt19937,
                               std::tr1::uniform_int <size_t> > gen;
  rand_ () : gen (std::tr1::mt19937 (), std::tr1::uniform_int <size_t> ()) {};
  long operator () (long max) { return gen (max); }
};
#else
struct rand_ { long operator() (long max) const { return random () % max; } };
#endif

#ifndef BUF_SIZE       // default chunk size for shuffling (1M)
#define BUF_SIZE (1 << 20)
#endif

#ifndef OPEN_MAX       // default # file descriptors
#define OPEN_MAX 20
#endif

#ifndef MAX_LINE_LEN   // maximum length of lines
#define MAX_LINE_LEN 65536
#endif

#ifndef TMP_DIR
#define TMP_DIR "/tmp" // default directory to store temporary files
#endif

int main (int argc, char** argv) {
  if (argc < 2) {
    std::fprintf (stderr, "Usage: %s [-S size] [-T tmp_dir] in\n", argv[0]);
    std::exit (1);
  }
  // random number generator
  rand_ gen;
  // handle options
  size_t       buf_size = BUF_SIZE;
  const char * tmp_dir  = TMP_DIR;
  char       * in       = 0;
  for (int i = 1; i < argc; ++i) {
    if (std::strcmp (argv[i], "-S") == 0) {
      buf_size = 1;
      char &c = argv[++i][std::strlen (argv[i])-1];
      char * err;
      switch (c) {
        case 'G': buf_size <<= 10;
        case 'M': buf_size <<= 10;
        case 'K': buf_size <<= 10; c = '\0';
        default:  buf_size *= std::strtol (argv[i], &err, 10);
          if (*err)
            {
            std::fprintf (stderr, "invalid size in -S argument: %s\n", argv[i]);
            std::exit (1);
          }
      }
    } else if (std::strcmp (argv[i], "-T") == 0) {
      tmp_dir = argv[++i];
    } else {
      in = argv[i];
    }
  }
  if (! in) std::fprintf (stderr, "no file specified\n");
  char tmp_fn[std::strlen (tmp_dir) + 12];
  // divide & conquer
  FILE * fp = std::fopen (in, "rb");
  if (! fp) {
    std::fprintf (stderr, "no such file: %s\n", in);
    std::exit (1);
  }
  // set # temporary files
  struct rlimit rlim;
  const size_t ntmp_lim
    = (getrlimit (RLIMIT_NOFILE, &rlim) == 0 ? rlim.rlim_cur : OPEN_MAX) - 4;
  std::fseek (fp, 0, SEEK_END);
  const size_t size = std::ftell (fp);
  size_t ntmp = std::min (size / buf_size + 1, ntmp_lim);
  std::fseek (fp, 0, SEEK_SET);
  std::vector <FILE*> tmpfps;
  tmpfps.reserve (ntmp);
  for (size_t i = 0; i < ntmp; ++i) {
    std::sprintf (&tmp_fn[0], "%s/shufXXXXXX", tmp_dir);
    tmpfps.push_back (fdopen (mkstemp (tmp_fn), "w+"));
    unlink (tmp_fn);
  }
  std::fclose (fp);
  // divide
  fp = std::fopen (in, "r");
  buf_size = size / ntmp + 1;
  char * buf = new char[buf_size + MAX_LINE_LEN];
  std::vector <const char *> lns;
  for (size_t read (0), bin (0); bin < ntmp; read = 0, ++bin) {
    if ((read = std::fread (&buf[0], sizeof (char), buf_size, fp)) == 0)
      break;
    if (std::feof (fp)) {
      if (buf[read-1] != '\n') {
        ++read;
        buf[read-1] = '\n';
        std::fprintf (stderr, "WARNING: line feeder is appended\n");
      }
    } else {
      std::fgets (&buf[read], MAX_LINE_LEN, fp);
      read += std::strlen (&buf[read]);
    }
    for (char *p (&buf[0]), *end (&buf[read]); p != end; *p = '\0', ++p) {
      lns.push_back (p);
      while (p != end && *p != '\n') ++p;
      if (p == end || p - lns.back () >= MAX_LINE_LEN) {
        std::fprintf (stderr, "ERROR: Buffer Overflow; increase MAX_LINE_LEN\n");
        std::exit (1);
      }
    }
    // shuffle lines
    std::random_shuffle (lns.begin (), lns.end (), gen);
    FILE * writer = tmpfps[bin];
    for (std::vector <const char *>::iterator it = lns.begin ();
         it != lns.end (); ++it)
      std::fprintf (writer, "%s\n", *it);
    std::fseek (writer, 0, SEEK_SET);
    lns.clear ();
  }
  delete [] buf;
  // conquer
  char line[MAX_LINE_LEN];
  while (! tmpfps.empty ()) {
    size_t bin = gen (tmpfps.size ());
    FILE * reader = tmpfps[bin];
    if (std::fgets (&line[0], MAX_LINE_LEN, reader) != NULL) {
      std::fwrite (&line[0], sizeof (char), std::strlen (line), stdout);
    } else {
      std::fclose (reader);
      tmpfps.erase (tmpfps.begin () + bin, tmpfps.begin () + bin + 1);
    }
  }
  return 0;
};
// g++ -DUSE_MT -march=core2 -m64 -o mt-shuf__ mt-shuf__.cc

結果を見てみる.

N=#lines (D=size) | mt-shuf__| 
------------------------------
10^5 (    2.35MB) |    0.05s |
10^6 (   23.35MB) |    0.46s |
10^7 (  234.08MB) |    4.66s |
10^8 (2,341.42MB) |   59.99s |

一時ファイルへの Random Write / Sequential Read/Write が Sequential Write になった分,有意に速くなった.一時ファイルを Disk の連続領域に配置するとか,乱数を引く回数を mt-shuf と同じ n 回にするとか(64bit の乱数を 32bit/32bit に分けて前者を一時ファイルのシャッフルに,後者を一時ファイルの選択に使うとか,しかし乱数の保存領域が問題か.一時ファイルの各行の先頭に保存するとかすれば良いか),細かい工夫はまだありそうだけど,オンメモリの GNU shuf 程度には速いし,この辺で良いかなと思う.
[追記] id:s-yata 氏にコメントして気づきましたが,このコードは行の長さが局所的に大きく偏っていないことを仮定しています.なので,Mersenne Twister でランダムに行をシャッフル (2) - ny23の日記 の方がより頑健に動作します.