[performance] Apache Spark: map vs mapPartitions?

Imp. TIP :

Whenever you have heavyweight initialization that should be done once for many RDD elements rather than once per RDD element, and if this initialization, such as creation of objects from a third-party library, cannot be serialized (so that Spark can transmit it across the cluster to the worker nodes), use mapPartitions() instead of map(). mapPartitions() provides for the initialization to be done once per worker task/thread/partition instead of once per RDD data element for example : see below.

val newRd = myRdd.mapPartitions(partition => {
  val connection = new DbConnection /*creates a db connection per partition*/

  val newPartition = partition.map(record => {
    readMatchingFromDB(record, connection)
  }).toList // consumes the iterator, thus calls readMatchingFromDB 

  connection.close() // close dbconnection here
  newPartition.iterator // create a new iterator
})

Q2. does flatMap behave like map or like mapPartitions?

Yes. please see example 2 of flatmap.. its self explanatory.

Q1. What's the difference between an RDD's map and mapPartitions

map works the function being utilized at a per element level while mapPartitions exercises the function at the partition level.

Example Scenario : if we have 100K elements in a particular RDD partition then we will fire off the function being used by the mapping transformation 100K times when we use map.

Conversely, if we use mapPartitions then we will only call the particular function one time, but we will pass in all 100K records and get back all responses in one function call.

There will be performance gain since map works on a particular function so many times, especially if the function is doing something expensive each time that it wouldn't need to do if we passed in all the elements at once(in case of mappartitions).

map

Applies a transformation function on each item of the RDD and returns the result as a new RDD.

Listing Variants

def map[U: ClassTag](f: T => U): RDD[U]

Example :

val a = sc.parallelize(List("dog", "salmon", "salmon", "rat", "elephant"), 3)
 val b = a.map(_.length)
 val c = a.zip(b)
 c.collect
 res0: Array[(String, Int)] = Array((dog,3), (salmon,6), (salmon,6), (rat,3), (elephant,8)) 

mapPartitions

This is a specialized map that is called only once for each partition. The entire content of the respective partitions is available as a sequential stream of values via the input argument (Iterarator[T]). The custom function must return yet another Iterator[U]. The combined result iterators are automatically converted into a new RDD. Please note, that the tuples (3,4) and (6,7) are missing from the following result due to the partitioning we chose.

preservesPartitioning indicates whether the input function preserves the partitioner, which should be false unless this is a pair RDD and the input function doesn't modify the keys.

Listing Variants

def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U]

Example 1

val a = sc.parallelize(1 to 9, 3)
 def myfunc[T](iter: Iterator[T]) : Iterator[(T, T)] = {
   var res = List[(T, T)]()
   var pre = iter.next
   while (iter.hasNext)
   {
     val cur = iter.next;
     res .::= (pre, cur)
     pre = cur;
   }
   res.iterator
 }
 a.mapPartitions(myfunc).collect
 res0: Array[(Int, Int)] = Array((2,3), (1,2), (5,6), (4,5), (8,9), (7,8)) 

Example 2

val x = sc.parallelize(List(1, 2, 3, 4, 5, 6, 7, 8, 9,10), 3)
 def myfunc(iter: Iterator[Int]) : Iterator[Int] = {
   var res = List[Int]()
   while (iter.hasNext) {
     val cur = iter.next;
     res = res ::: List.fill(scala.util.Random.nextInt(10))(cur)
   }
   res.iterator
 }
 x.mapPartitions(myfunc).collect
 // some of the number are not outputted at all. This is because the random number generated for it is zero.
 res8: Array[Int] = Array(1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 7, 7, 7, 9, 9, 10) 

The above program can also be written using flatMap as follows.

Example 2 using flatmap

val x  = sc.parallelize(1 to 10, 3)
 x.flatMap(List.fill(scala.util.Random.nextInt(10))(_)).collect

 res1: Array[Int] = Array(1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10) 

Conclusion :

mapPartitions transformation is faster than map since it calls your function once/partition, not once/element..

Further reading : foreach Vs foreachPartitions When to use What?

Examples related to performance

Why is 2 * (i * i) faster than 2 * i * i in Java? What is the difference between spark.sql.shuffle.partitions and spark.default.parallelism? How to check if a key exists in Json Object and get its value Why does C++ code for testing the Collatz conjecture run faster than hand-written assembly? Most efficient way to map function over numpy array The most efficient way to remove first N elements in a list? Fastest way to get the first n elements of a List into an Array Why is "1000000000000000 in range(1000000000000001)" so fast in Python 3? pandas loc vs. iloc vs. at vs. iat? Android Recyclerview vs ListView with Viewholder

Examples related to scala

Intermediate language used in scalac? Why does calling sumr on a stream with 50 tuples not complete Select Specific Columns from Spark DataFrame Joining Spark dataframes on the key Provide schema while reading csv file as a dataframe how to filter out a null value from spark dataframe Fetching distinct values on a column using Spark DataFrame Can't push to the heroku Spark - Error "A master URL must be set in your configuration" when submitting an app Add jars to a Spark Job - spark-submit

Examples related to apache-spark

Select Specific Columns from Spark DataFrame Select columns in PySpark dataframe What is the difference between spark.sql.shuffle.partitions and spark.default.parallelism? How to find count of Null and Nan values for each column in a PySpark dataframe efficiently? Spark dataframe: collect () vs select () How does createOrReplaceTempView work in Spark? Spark difference between reduceByKey vs groupByKey vs aggregateByKey vs combineByKey Filter df when values matches part of a string in pyspark Filtering a pyspark dataframe using isin by exclusion Convert date from String to Date format in Dataframes

Examples related to rdd

How to create a DataFrame from a text file in Spark Spark - repartition() vs coalesce() Difference between DataFrame, Dataset, and RDD in Spark Spark specify multiple column conditions for dataframe join Spark read file from S3 using sc.textFile ("s3n://...) Spark: subtract two DataFrames How to convert rdd object to dataframe in spark What is the difference between cache and persist? Apache Spark: map vs mapPartitions?