[apache-spark] Removing duplicates from rows based on specific columns in an RDD/Spark DataFrame

Let's say I have a rather large dataset in the following form:

data = sc.parallelize([('Foo',41,'US',3),

What I would like to do is remove duplicate rows based on the values of the first, third and fourth columns only.

Removing entirely duplicate rows is straightforward:

data = data.distinct()

and either row 5 or row 6 will be removed.

But how do I remove duplicate rows based on columns 1, 3 and 4 only? I.e. remove either one of these:


In Python (pandas), this could be done by specifying columns with .drop_duplicates(). How can I achieve the same in Spark/PySpark?

This question is related to apache-spark apache-spark-sql pyspark

The answer is

The program below shows how to drop duplicates based on whole rows or, if you prefer, based on only certain columns:

import org.apache.spark.sql.SparkSession

object DropDuplicates {
def main(args: Array[String]) {
val spark =

import spark.implicits._

// create an RDD of tuples with some data
val custs = Seq(
  (1, "Widget Co", 120000.00, 0.00, "AZ"),
  (2, "Acme Widgets", 410500.00, 500.00, "CA"),
  (3, "Widgetry", 410500.00, 200.00, "CA"),
  (4, "Widgets R Us", 410500.00, 0.0, "CA"),
  (3, "Widgetry", 410500.00, 200.00, "CA"),
  (5, "Ye Olde Widgete", 500.00, 0.0, "MA"),
  (6, "Widget Co", 12000.00, 10.00, "AZ")
)
val customerRows = spark.sparkContext.parallelize(custs, 4)

// convert RDD of tuples to DataFrame by supplying column names
val customerDF = customerRows.toDF("id", "name", "sales", "discount", "state")

println("*** Here's the whole DataFrame with duplicates")
customerDF.show()


// drop fully identical rows
val withoutDuplicates = customerDF.dropDuplicates()

println("*** Now without duplicates")
withoutDuplicates.show()

// drop rows that share the same name and state, even if other columns differ
val withoutPartials = customerDF.dropDuplicates(Seq("name", "state"))

println("*** Now without partial duplicates too")
withoutPartials.show()
}
}


From your question, it is unclear which columns you want to use to determine duplicates. The general idea behind the solution is to create a key from the values of the columns that identify duplicates. You can then use the reduceByKey or reduce operations to eliminate the duplicates.

Here is some code to get you started:

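def get_key(x):
    # use a tuple of columns 1, 3 and 4 as the key; unlike a concatenated
    # string, a tuple cannot make two different rows collide
    return (x[0], x[2], x[3])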

m = data.map(lambda x: (get_key(x),x))
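One thing to watch when choosing the key: joining the column values into a single string with no separator can make two different rows look identical. The plain-Python sketch below needs no Spark. The two sample rows and the helper names string_key and tuple_key are made up for illustration. It shows such a collision and how a tuple key avoids it.

```python
# Hypothetical rows in the question's shape: (name, number, country, count)
row_a = ('Foo', 41, 'US', 31)
row_b = ('Foo', 41, 'US3', 1)

def string_key(x):
    # joins columns 1, 3 and 4 into one string with no separator
    return "{0}{1}{2}".format(x[0], x[2], x[3])

def tuple_key(x):
    # keeps columns 1, 3 and 4 as separate tuple elements
    return (x[0], x[2], x[3])

# Both rows produce the string 'FooUS31', so they would be merged by mistake
collides = string_key(row_a) == string_key(row_b)
# The tuple keys stay different, so both rows survive
distinct = tuple_key(row_a) != tuple_key(row_b)
print(collides, distinct)
```

Tuples are hashable in Python, so they can be used directly as keys in a key-value RDD.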

Now you have a key-value RDD that is keyed by columns 1, 3 and 4. The next step would be either a reduceByKey or a groupByKey followed by a filter. This eliminates the duplicates.

# keep one row per key, then drop the keys to get plain rows back
r = m.reduceByKey(lambda x, y: x).values()
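To see what the map and reduceByKey steps leave behind, here is a plain-Python model of the same pipeline. It runs without Spark, the sample rows are invented for illustration, and the helper name dedupe_by_key is hypothetical. Unlike Spark, it processes the rows in order and so always keeps the first row for each key; on a cluster, which of the duplicate rows survives is not guaranteed.

```python
def dedupe_by_key(rows, key_indices):
    """Keep one row per key, like map(...) followed by reduceByKey(lambda x, y: x)."""
    kept = {}
    for row in rows:
        key = tuple(row[i] for i in key_indices)
        # setdefault stores the row only if this key has not been seen yet
        kept.setdefault(key, row)
    return list(kept.values())

# Hypothetical sample rows: (name, number, country, count)
rows = [
    ('Foo', 41, 'US', 3),
    ('Foo', 39, 'UK', 1),
    ('Bar', 57, 'CA', 2),
    ('Bar', 72, 'CA', 2),   # same name, country and count as the row above
]

# Columns 1, 3 and 4 of the question are indices 0, 2 and 3 here
result = dedupe_by_key(rows, key_indices=(0, 2, 3))
print(result)
```

The second 'Bar' row is dropped because it matches an earlier row on the three key columns, even though its second column differs.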

My DataFrame contains the value 4 twice, so dropDuplicates removes the repeated value.

scala> df.show
|    1|
|    4|
|    3|
|    5|
|    4|
|   18|

scala> val newdf=df.dropDuplicates

scala> newdf.show
|    1|
|    3|
|    5|
|    4|
|   18|

I know you have already accepted the other answer, but if you want to do this with a DataFrame, just use groupBy and agg. Assuming you already have a DataFrame (with columns named "col1", "col2", etc.), you could do:

myDF.groupBy($"col1", $"col3", $"col4").agg($"col1", max($"col2"), $"col3", $"col4")

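Note that in this case I chose the max of col2, but you could use avg, min, etc. The max function comes from org.apache.spark.sql.functions. From Spark 1.4 on, the grouping columns are kept in the result by default, so agg(max($"col2")) alone is enough there.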

I agree with David. To add to that: we may not want to group by every column that is not in the aggregate function, i.e. we may want to remove duplicates based purely on a subset of columns while retaining all columns of the original DataFrame. A better way to do this is the dropDuplicates DataFrame API, available since Spark 1.4.0, which accepts a list of column names. In PySpark that would be df.dropDuplicates(['col1', 'col3', 'col4']).

For reference, see: https://spark.apache.org/docs/1.4.0/api/scala/index.html#org.apache.spark.sql.DataFrame
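The difference between the two approaches shows up once there is more than one column outside the key. The plain-Python sketch below involves no Spark; the rows, the column layout and the helper names group_max and keep_one_row are made up for illustration. It models groupBy with max on every remaining column next to a keep-one-row-per-key strategy like dropDuplicates on a subset. Aggregating column by column can build a row that never existed in the input.

```python
# Hypothetical rows: (col1, col2, col3), with col1 as the de-duplication key
rows = [
    ('a', 1, 9),
    ('a', 5, 2),
]

def group_max(rows):
    """Model of grouping by col1 and taking max(col2) and max(col3)."""
    out = {}
    for k, c2, c3 in rows:
        if k in out:
            _, m2, m3 = out[k]
            out[k] = (k, max(m2, c2), max(m3, c3))
        else:
            out[k] = (k, c2, c3)
    return list(out.values())

def keep_one_row(rows):
    """Model of de-duplicating on col1 only: an existing row is kept whole."""
    out = {}
    for row in rows:
        # this model keeps the first row seen; Spark may keep either one
        out.setdefault(row[0], row)
    return list(out.values())

mixed = group_max(rows)     # col2 and col3 come from different input rows
whole = keep_one_row(rows)  # one of the original rows, unchanged
print(mixed, whole)
```

Here group_max returns ('a', 5, 9), which combines values from both input rows, while keep_one_row returns the input row ('a', 1, 9) as it was.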

I used the built-in function dropDuplicates(). The Scala code is given below:

val data = sc.parallelize(List(("Foo",41,"US",3),


Output:

+---+---+---+-----+
|  x|  y|  z|count|
+---+---+---+-----+
|Baz| 22| US|    6|
|Foo| 39| UK|    1|
|Foo| 41| US|    3|
|Bar| 57| CA|    2|
+---+---+---+-----+
