[apache-spark] Spark SQL: apply aggregate functions to a list of columns

Is there a way to apply an aggregate function to all (or a list of) columns of a dataframe, when doing a groupBy? In other words, is there a way to avoid doing this for every column:

df.groupBy("col1")
  .agg(sum("col2").alias("col2"), sum("col3").alias("col3"), ...)

The answer is


Current answers are perfectly correct on how to create the aggregations, but none actually address the column alias/renaming that is also requested in the question.

Typically, this is how I handle this case:

val dimensionFields = List("col1")
val metrics = List("col2", "col3", "col4")
val columnOfInterests = dimensions ++ metrics

val df = spark.read.table("some_table"). 
    .select(columnOfInterests.map(c => col(c)):_*)
    .groupBy(dimensions.map(d => col(d)): _*)
    .agg(metrics.map( m => m -> "sum").toMap)
    .toDF(columnOfInterests:_*)    // that's the interesting part

The last line essentially renames every columns of the aggregated dataframe to the original fields, essentially changing sum(col2) and sum(col3) to simply col2 and col3.


Another example of the same concept - but say - you have 2 different columns - and you want to apply different agg functions to each of them i.e

f.groupBy("col1").agg(sum("col2").alias("col2"), avg("col3").alias("col3"), ...)

Here is the way to achieve it - though I do not yet know how to add the alias in this case

See the example below - Using Maps

val Claim1 = StructType(Seq(StructField("pid", StringType, true),StructField("diag1", StringType, true),StructField("diag2", StringType, true), StructField("allowed", IntegerType, true), StructField("allowed1", IntegerType, true)))
val claimsData1 = Seq(("PID1", "diag1", "diag2", 100, 200), ("PID1", "diag2", "diag3", 300, 600), ("PID1", "diag1", "diag5", 340, 680), ("PID2", "diag3", "diag4", 245, 490), ("PID2", "diag2", "diag1", 124, 248))

val claimRDD1 = sc.parallelize(claimsData1)
val claimRDDRow1 = claimRDD1.map(p => Row(p._1, p._2, p._3, p._4, p._5))
val claimRDD2DF1 = sqlContext.createDataFrame(claimRDDRow1, Claim1)

val l = List("allowed", "allowed1")
val exprs = l.map((_ -> "sum")).toMap
claimRDD2DF1.groupBy("pid").agg(exprs) show false
val exprs = Map("allowed" -> "sum", "allowed1" -> "avg")

claimRDD2DF1.groupBy("pid").agg(exprs) show false

Examples related to apache-spark

Select Specific Columns from Spark DataFrame Select columns in PySpark dataframe What is the difference between spark.sql.shuffle.partitions and spark.default.parallelism? How to find count of Null and Nan values for each column in a PySpark dataframe efficiently? Spark dataframe: collect () vs select () How does createOrReplaceTempView work in Spark? Spark difference between reduceByKey vs groupByKey vs aggregateByKey vs combineByKey Filter df when values matches part of a string in pyspark Filtering a pyspark dataframe using isin by exclusion Convert date from String to Date format in Dataframes

Examples related to dataframe

Trying to merge 2 dataframes but get ValueError How to show all of columns name on pandas dataframe? Python Pandas - Find difference between two data frames Pandas get the most frequent values of a column Display all dataframe columns in a Jupyter Python Notebook How to convert column with string type to int form in pyspark data frame? Display/Print one column from a DataFrame of Series in Pandas Binning column with python pandas Selection with .loc in python Set value to an entire column of a pandas dataframe

Examples related to apache-spark-sql

Select Specific Columns from Spark DataFrame Pyspark: Filter dataframe based on multiple conditions Select columns in PySpark dataframe What is the difference between spark.sql.shuffle.partitions and spark.default.parallelism? How to find count of Null and Nan values for each column in a PySpark dataframe efficiently? Spark dataframe: collect () vs select () How does createOrReplaceTempView work in Spark? Filter df when values matches part of a string in pyspark Convert date from String to Date format in Dataframes Take n rows from a spark dataframe and pass to toPandas()

Examples related to aggregate-functions

Spark SQL: apply aggregate functions to a list of columns GROUP BY without aggregate function GROUP BY + CASE statement must appear in the GROUP BY clause or be used in an aggregate function Naming returned columns in Pandas aggregate function? Concatenate multiple result rows of one column into one, group by another column How to include "zero" / "0" results in COUNT aggregate? Apply multiple functions to multiple groupby columns Reason for Column is invalid in the select list because it is not contained in either an aggregate function or the GROUP BY clause Optimal way to concatenate/aggregate strings