[apache-spark] Spark SQL: apply aggregate functions to a list of columns

There are multiple ways of applying aggregate functions to multiple columns.

The GroupedData class provides methods for the most common aggregate functions, including count, max, min, mean and sum, which can be used directly as follows:

  • Python

    # sqlContext is available in the PySpark shell; spark.createDataFrame accepts the same arguments
    df = sqlContext.createDataFrame(
        [(1.0, 0.3, 1.0), (1.0, 0.5, 0.0), (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2)],
        ("col1", "col2", "col3"))
    
    df.groupBy("col1").sum().show()
    
    ## +----+---------+-----------------+---------+
    ## |col1|sum(col1)|        sum(col2)|sum(col3)|
    ## +----+---------+-----------------+---------+
    ## | 1.0|      2.0|              0.8|      1.0|
    ## |-1.0|     -2.0|6.199999999999999|      0.7|
    ## +----+---------+-----------------+---------+
    
  • Scala

    // in spark-shell, sc and the implicits needed for toDF and $"col" are already in scope
    val df = sc.parallelize(Seq(
      (1.0, 0.3, 1.0), (1.0, 0.5, 0.0),
      (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2))
    ).toDF("col1", "col2", "col3")
    
    df.groupBy($"col1").min().show
    
    // +----+---------+---------+---------+
    // |col1|min(col1)|min(col2)|min(col3)|
    // +----+---------+---------+---------+
    // | 1.0|      1.0|      0.3|      0.0|
    // |-1.0|     -1.0|      0.6|      0.2|
    // +----+---------+---------+---------+
    

Optionally, you can pass the names of the columns that should be aggregated:

    df.groupBy("col1").sum("col2", "col3")

You can also pass a dictionary / map with columns as the keys and functions as the values:

  • Python

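    exprs = {x: "sum" for x in df.columns}
    df.groupBy("col1").agg(exprs).show()
    
    ## +----+---------+-----------------+---------+
    ## |col1|sum(col1)|        sum(col2)|sum(col3)|
    ## +----+---------+-----------------+---------+
    ## | 1.0|      2.0|              0.8|      1.0|
    ## |-1.0|     -2.0|6.199999999999999|      0.7|
    ## +----+---------+-----------------+---------+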
    
  • Scala

    val exprs = df.columns.map((_ -> "mean")).toMap
    df.groupBy($"col1").agg(exprs).show()
    
    // +----+---------+------------------+---------+
    // |col1|avg(col1)|         avg(col2)|avg(col3)|
    // +----+---------+------------------+---------+
    // | 1.0|      1.0|               0.4|      0.5|
    // |-1.0|     -1.0|3.0999999999999996|     0.35|
    // +----+---------+------------------+---------+
    

Finally, you can pass Column expressions as varargs:

  • Python

    from pyspark.sql import functions as F
    
    # Importing the module under an alias avoids shadowing Python's built-in min
    exprs = [F.min(x) for x in df.columns]
    df.groupBy("col1").agg(*exprs).show()
    
  • Scala

    import org.apache.spark.sql.functions.sum
    
    val exprs = df.columns.map(sum(_))
    // agg takes (Column, Column*), so split the array into head and tail
    df.groupBy($"col1").agg(exprs.head, exprs.tail: _*).show()
    

There are other ways to achieve a similar effect, but these should be more than enough most of the time.
