List内に最も多く出現するオブジェクトを取得するメソッド

Twitterでこんな処理はどう書けばいいんだろう?というツイートがあって、Scalaの勉強がてらちょっと書いてみました。

仕様は以下の通り

  • 引数で任意の要素型のListを受け取り、そのList内に最も多く出現するオブジェクトを取得する。
  • 最も多く出現するオブジェクトが複数ある場合もあるので、戻り値は Set[要素型] とする。
  • 引数で渡されるListは空の場合もありうる。

肝は最大の個数の要素が複数あった場合にその全てを返すところですね。最初に見つかったやつだけでよければ groupBy と maxBy を使えば簡単に書けるので。

最初に書いたのはこんな感じです。

def most[A](seq: Seq[A]): Set[A] = seq.groupBy(identity).foldLeft((0, Set[A]())) {
  case ((c, _), (x, xs)) if (c < xs.length)  => (xs.length, Set(x))
  case ((c, s), (x, xs)) if (c == xs.length) => (c, s + x)
  case (r, _) => r
}._2

パターンマッチがネストしたデータ構造を指定できるのは恐ろしく便利ですね。

汎用性をあげるために引数は List じゃなくて Seq にしています。

その後に、groupBy した後にソートすればもう少し簡単にできるかな?と思って書いたのが次のコード

def most[A](seq: Seq[A]): Set[A] = seq.groupBy(identity).toSeq.sortBy(_._2.length * -1) match {
  case l@((_, c) :: _) => l.takeWhile(_._2.length == c.length).unzip._1.toSet
  case _ => Set()
}

短くはなったけど複雑になったような……微妙な違和感が残ります。

ここまで書いたところでJavaにも浮気してみました。

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

...

public <E> Set<E> most(final Iterable<E> iterable) {
    final Map<E, Integer> counters = new HashMap<E, Integer>();
    for (final E e: iterable) {
        counters.put(e, counters.containsKey(e) ? counters.get(e) + 1 : 1);
    }
    int max = 0;
    final Set<E> result = new HashSet<E>();
    for (final Entry<E, Integer> entry : counters.entrySet()) {
        if (entry.getValue() > max) {
            result.clear();
            result.add(entry.getKey());
            max = entry.getValue();
        } else if (entry.getValue() == max) {
            result.add(entry.getKey());
        }
    }
    return result;
}

素直に書いたらこんな感じになると思います。見事に手続き的ですね。

せっかくなのでguava-librariesを使ってJavaでももう少しFunctionalな感じで書いてみました。

import static java.util.Collections.emptySet;
import static com.google.common.collect.Collections2.transform;
import static com.google.common.collect.HashMultimap.create;
import static com.google.common.collect.Multimaps.index;
import static com.google.common.collect.Sets.newHashSet;

import java.util.Collection;
import java.util.Collections;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;

import com.google.common.base.Functions;
import com.google.common.collect.Multimap;

...

public <E> Set<E> most(final Iterable<E> iterable) {
    if (!iterable.iterator().hasNext()) return emptySet();
    final Multimap<Integer, E> multi = create();
    for (final Entry<E, Collection<E>> e : index(iterable, Functions.<E>identity()).asMap().entrySet()) {
        multi.put(e.getValue().size(), e.getKey());
    }
    return newHashSet(new TreeMap<Integer, Collection<E>>(multi.asMap()).lastEntry().getValue());
}

ほんとはワンライナーでも書いたんですが、そちらはあまりにもキモかったのでお蔵入りにしました;-p

このコードを書くときに、NavigableMapを使用して要素数でグルーピングすれば lastEntry() で取得できる、というのを思いついたので、Scalaに転用すれば maxBy が使えるなと思い書いてみました。

ちなみに maxBy は scala2.9 からなので、2.8で利用できるようにこっそり追加しました。

// このimplicitメソッドで、2.8でも Traversable が maxBy を使用できるようになる
implicit def wrapI[A](t: Traversable[A]) = new AnyRef {
  def maxBy[B](f: (A) => B)(implicit c: Ordering[B]): A = 
    t.reduceLeft {(x, y) => if (c.gteq(f(x), f(y))) x else y}
}

def most[A](seq: Seq[A]): Set[A] =
  if (seq.isEmpty) Set() else seq.groupBy(identity).groupBy(_._2.length).maxBy(_._1)._2.unzip._1.toSet

見事ワンライナーになりました。…なりましたが……はたして見やすくなったのか……。

個人的には、もしプロダクトコードに書くとするなら一番最初に書いた奴がいいかな、と思ってます。

しかしこれ以外にもまだまだ色々な書き方ができそうですね。奥が深い。