Spark RDD – Getting to the bottom records…

Apache Spark is the leading computation engine for Big Data and Analytics and has revolutionized the way we handle big data. Its core abstraction, the Resilient Distributed Dataset (RDD), is a powerful and robust tool that makes distributed computation tasks easy.

There are numerous advantages to Apache Spark, making it a top choice for many organizations. A minor annoyance I have come across lately is extracting the bottom records of an RDD (i.e. the last records in the collection). Getting the top N records of an RDD is straightforward, but getting the bottom records is surprisingly awkward. This can be frustrating for those of us who are used to the simplicity of Pandas' tail().
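
For reference, the top of an RDD really is a one-liner (a quick sketch; rdd can be any RDD, and top() assumes an implicit ordering exists for the record type):

rdd.take(10)   /* first 10 records in partition order */
rdd.top(10)    /* 10 largest records by the implicit ordering */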

A useful workaround is to sort the RDD so that the records you want end up on top, and then take the top N records. This resolves most cases. If your data lacks a sort key, use the zipWithIndex() method to add an index to each row and sort by it. The code below shows how to get the bottom 10 records of an RDD using zipWithIndex():

rdd.zipWithIndex()
  .map { case (record, index) => (index, record) }
  .sortByKey(ascending = false)
  .map { case (_, record) => record }
  .take(10)
  .reverse
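
If your records do carry a natural sort key, you can skip zipWithIndex() entirely and sort on that key instead (a sketch, assuming a hypothetical pair RDD keyedRdd of type RDD[(Long, String)] keyed by, say, a timestamp):

keyedRdd.sortByKey(ascending = false).take(10).reverse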

The concern with sorting (whether by a sort key or by zipWithIndex()) is that it requires a substantial shuffle of your RDD. The following Scala code avoids that as much as possible by iterating over the RDD's partitions without a shuffle. The code is far from optimal, since it traverses the RDD several times to do its job, but at least you save yourself a resource-intensive shuffle (caching the RDD also helps!):

import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag

object RddBottomHelper {

  /* startRecord is the global index of a partition's first record,
     lastRecord is the global index just past its last record */
  case class PartitionInfo(id: Int, startRecord: Long, lastRecord: Long)
  case class RddInfo(size: Long, partitionsInfo: List[PartitionInfo])

  /* getting the size as a Long rather than the Int returned by size... */
  def getIteratorSize[T](iterator: Iterator[T]): Long = {
    iterator.map(_ => 1L).sum
  }

  /* compute the total record count and the record range covered by each partition */
  def getRddInfo[T](rdd: RDD[T]): RddInfo = {
    rdd.mapPartitionsWithIndex( (id, itr) => Iterator( (id, getIteratorSize(itr)) ) )
      .sortByKey().collect()
      .foldLeft(RddInfo(0L, List.empty))({
        case ( info, (id, size)) =>
          RddInfo(info.size + size, PartitionInfo(id, info.size, info.size + size) :: info.partitionsInfo )
      })
  }

  def bottom[T: ClassTag](rdd: RDD[T], numRecords: Int): Array[T] = {

    if (numRecords < 1) return Array[T]()

    rdd.cache()
    val rddInfo = getRddInfo(rdd)

    val endRecord = rddInfo.size - 1
    val startRecord = endRecord - numRecords + 1
    /* keep only the partitions that contain records at or after startRecord */
    val mapInfo = rddInfo.partitionsInfo.filter(x => x.lastRecord > startRecord).map(x => (x.id, x)).toMap

    rdd.mapPartitionsWithIndex( (id, itr) =>
      if (mapInfo.contains(id)) {
        /* skip the records that come before startRecord
           (a no-op for partitions that start after it) */
        val dropRecords = startRecord - mapInfo(id).startRecord
        itr.drop(math.max(0, dropRecords.toInt))
      }
      else Iterator.empty
    ).collect()
  }
}
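
Usage might look like this (a minimal sketch; sc is assumed to be an existing SparkContext and the data is a toy example):

val data = sc.parallelize(1 to 1000000, numSlices = 8)
val lastTen = RddBottomHelper.bottom(data, 10)
/* lastTen should contain 999991 through 1000000 – the last ten records in partition order */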

I hope this helps you to get the most out of Spark!



* SWI Big Data Expert Michael Birch has provided an alternative to the code above which is easier to port to non-Scala Spark interfaces (e.g. PySpark). It is less efficient than the mapPartitions solution above, but faster than the sort options, as it does not require a shuffle of the RDD.

import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag

def bottom[T: ClassTag](rdd: RDD[T], numRecords: Long): Array[T] = {
  rdd.cache()
  val n: Long = rdd.count()
  val startIndex: Long = n - numRecords

  rdd.zipWithIndex()
    .filter { case (_, index) => index >= startIndex }
    .keys /* i.e. the records themselves */
    .collect()
}
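
Note that zipWithIndex() still triggers an extra Spark job to compute the per-partition offsets when the RDD has more than one partition, but it does not shuffle the data, which is where the savings over sortByKey() come from. Calling it looks like this (a quick sketch, reusing the toy RDD from the earlier example):

val lastTen = bottom(sc.parallelize(1 to 1000000, numSlices = 8), 10)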
