[sql] How to select the first row of each group?

I have a DataFrame generated as follow:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc))

The results look like:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

As you can see, the DataFrame is ordered by Hour in increasing order, then by TotalValue in descending order.

I would like to select the top row of each group, i.e.

  • from the group of Hour==0 select (0,cat26,30.9)
  • from the group of Hour==1 select (1,cat67,28.5)
  • from the group of Hour==2 select (2,cat56,39.6)
  • and so on

So the desired output would be:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

It might be handy to be able to select the top N rows of each group as well.

Any help is highly appreciated.

Tags: sql, scala, apache-spark, dataframe, apache-spark-sql

Answers:


Window functions:

Something like this should do the trick:

import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"Hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

This method will be inefficient in case of significant data skew.
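Since the question also asks for the top N rows of each group, a minimal sketch (reusing the window w defined above, and assuming N = 2) just relaxes the filter:

// top N (here N = 2) rows per Hour; keep "rn" if you also want the rank itself
val dfTopN = df.withColumn("rn", row_number.over(w)).where($"rn" <= 2).drop("rn")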

Plain SQL aggregation followed by join:

Alternatively you can join with aggregated data frame:

val dfMax = df.groupBy($"Hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"Hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

It will keep duplicate rows (if more than one category per hour ties for the same total value). You can remove these as follows:

dfTopByJoin
  .groupBy($"Hour")
  .agg(
    first("Category").alias("Category"),
    first("TotalValue").alias("TotalValue"))

Using ordering over structs:

A neat, although not very well tested, trick that requires neither joins nor window functions. Structs are compared field by field from left to right, so putting TotalValue first makes max pick the row with the highest value:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

With the Dataset API (Spark 1.6+, 2.0+):

Spark 1.6:

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 or later:

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
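
reduceGroups returns a Dataset of (key, record) pairs; a minimal sketch to get back to a Dataset[Record] (assuming spark.implicits._ is in scope for the Record encoder):

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .map { case (_, record) => record }  // drop the grouping key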

The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should exhibit better performance than window functions and joins. They can also be used with Structured Streaming in complete output mode.

Don't use:

df.orderBy(...).groupBy(...).agg(first(...), ...)

It may seem to work (especially in local mode), but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking the relevant JIRA issue, and SPARK-30335).

The same note applies to

df.orderBy(...).dropDuplicates(...)

which internally uses an equivalent execution plan.


For Spark 2.0.2 with grouping by multiple columns:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

This is exactly the same as zero323's answer, but written as SQL queries.

Assuming that the DataFrame is created and registered as

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Window function:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+
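
To select the top N rows of each group in SQL, relax the filter on rn (a sketch, assuming N = 2):

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp where rn <= 2").show(false)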

Plain SQL aggregation followed by join:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Using ordering over structs:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

The Dataset approach and the "don't use" caveats are the same as in the original answer.


The pattern is: group by keys => do something to each group (e.g. reduce) => return to a DataFrame.

I thought the DataFrame abstraction was a bit cumbersome in this case, so I used the RDD functionality:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// reduceFunction: (Row, Row) => Row picks the "winner" of two rows (see sketch below)
val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
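
reduceFunction is left undefined above; a minimal sketch for the question's Hour/Category/TotalValue schema (assuming TotalValue is the third column) could be:

// keep whichever row has the larger TotalValue (column index 2)
val reduceFunction = (r1: Row, r2: Row) =>
  if (r1.getDouble(2) > r2.getDouble(2)) r1 else r2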

A nice way of doing this with the DataFrame API is to use the argmax logic, like so (the max of a struct compares its fields left to right, which is why TotalValue is placed first):

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+

The solution below does only one groupBy and extracts the rows of your DataFrame that contain the max value, in one shot. No further joins or windows are needed.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

// df is the DataFrame with Hour, Category, TotalValue;
// spark.implicits._ is assumed to be in scope for the Int key encoder

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df
  .groupByKey((r) => r.getInt(0))
  .mapGroups[Row]((hour: Int, rows: Iterator[Row]) => rows.maxBy((r) => r.getDouble(2)))

Note that, unlike reduceGroups, mapGroups has to materialize each whole group, so it cannot benefit from map-side combine.

Here you can do it like this:

   val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show

We can use the rank() window function (where you would choose rank = 1). rank just assigns a number to every row of a group (in this case, each hour).

Here's an example (from https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank

val dataset = spark.range(9).withColumn("bucket", 'id % 3)

val byBucket = Window.partitionBy('bucket).orderBy('id)

scala> dataset.withColumn("rank", rank over byBucket).show
+---+------+----+
| id|bucket|rank|
+---+------+----+
|  0|     0|   1|
|  3|     0|   2|
|  6|     0|   3|
|  1|     1|   1|
|  4|     1|   2|
|  7|     1|   3|
|  2|     2|   1|
|  5|     2|   2|
|  8|     2|   3|
+---+------+----+
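
Applied to the original problem, a minimal sketch (note that, unlike row_number, rank keeps every row that ties for the top TotalValue within an hour):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank

val byHour = Window.partitionBy($"Hour").orderBy($"TotalValue".desc)

df.withColumn("rank", rank over byHour)
  .where($"rank" === 1)
  .drop("rank")
  .show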
