[python] Best way to get the max value in a Spark dataframe column

I'm trying to figure out the best way to get the largest value in a Spark dataframe column.

Consider the following example:

df = spark.createDataFrame([(1., 4.), (2., 5.), (3., 6.)], ["A", "B"])
df.show()

Which creates:

+---+---+
|  A|  B|
+---+---+
|1.0|4.0|
|2.0|5.0|
|3.0|6.0|
+---+---+

My goal is to find the largest value in column A (by inspection, this is 3.0). Using PySpark, here are four approaches I can think of:

# Method 1: Use describe()
float(df.describe("A").filter("summary = 'max'").select("A").first().asDict()['A'])

# Method 2: Use SQL
df.createOrReplaceTempView("df_table")  # registerTempTable() is deprecated since Spark 2.0
spark.sql("SELECT MAX(A) as maxval FROM df_table").first().asDict()['maxval']

# Method 3: Use groupby()
df.groupby().max('A').first().asDict()['max(A)']

# Method 4: Convert to RDD
df.select("A").rdd.max()[0]

Each of the above gives the right answer, but in the absence of a Spark profiling tool I can't tell which is best.

Any ideas from either intuition or empiricism on which of the above methods is most efficient in terms of Spark runtime or resource usage, or whether there is a more direct method than the ones above?



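In PySpark you can collect the column to the driver and take Python's built-in max (note that this pulls every value of the column to the driver):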

max(df.select('ColumnName').rdd.flatMap(lambda x: x).collect())
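One caveat with this approach: Spark's max aggregate ignores nulls, but Python 3's built-in max raises a TypeError when it compares None with a float, so a column containing nulls breaks the collect-then-max pattern. A minimal plain-Python sketch of a null-skipping version, assuming the collected list can contain None (which is how nulls come back from collect()); driver_side_max is a hypothetical helper name:

```python
def driver_side_max(values):
    """Return the max of collected values, skipping None the way Spark's max() does."""
    non_null = [v for v in values if v is not None]
    # Spark returns null for an all-null (or empty) column; mirror that with None.
    return max(non_null) if non_null else None


collected = [1.0, None, 3.0, 2.0]  # hypothetical output of ...flatMap(lambda x: x).collect()
print(driver_side_max(collected))  # 3.0
print(driver_side_max([None]))     # None
```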

Remark: Spark is intended for Big Data and distributed computing. The example DataFrame is very small, so the ranking of the methods on real-life data can differ from the ranking on this tiny example.

Slowest: Method 1, because .describe("A") computes count, mean, stddev, min, and max (five statistics over the whole column).

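Medium: Method 4, because .rdd converts the DataFrame to an RDD, and the max is then computed by Python code in the Python worker processes instead of by Spark's optimized engine, which adds serialization overhead.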

Faster: Method 3 ~ Method 2 ~ Method 5 (df.agg(), in the code below), because the logic is very similar: Spark's Catalyst optimizer produces nearly the same plan for each, with a minimal number of operations (compute the max of one column, collect a single-value DataFrame). The .asDict() call adds a little extra time to Methods 2 and 3 compared with Method 5.
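All of the single-value methods (including rdd.max(), which reduces each partition before merging) avoid shipping the whole column: each partition is reduced to one partial max, and only those partials are merged into the final value. A plain-Python sketch of that two-phase idea; the nested lists are hypothetical stand-ins for partitions, and this is an illustration, not Spark's actual implementation:

```python
from functools import reduce


def two_phase_max(partitions):
    """Reduce each partition to a partial max, then merge only the partials."""
    # Phase 1: one value per non-empty partition (empty partitions contribute nothing).
    partials = [max(p) for p in partitions if p]
    if not partials:
        return None, partials
    # Phase 2: merge the partials; only these values need to travel to the driver.
    overall = reduce(lambda a, b: a if a >= b else b, partials)
    return overall, partials


partitions = [[1.0, 2.0], [], [3.0]]  # hypothetical layout of column A over 3 partitions
overall, partials = two_phase_max(partitions)
print(partials)  # [2.0, 3.0]
print(overall)   # 3.0
```

By contrast, the collect()-then-max approach moves every value to the driver before reducing anything.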

import time

import numpy as np
import pandas as pd

# Assumes an existing SparkSession named spark
time_dict = {}

dfff = spark.createDataFrame([(1., 4.), (2., 5.), (3., 6.)], ["A", "B"])
#--  For a bigger, more realistic DataFrame, uncomment the following 3 lines
#lst = list(np.random.normal(0.0, 100.0, 100000))
#pdf = pd.DataFrame({'A': lst, 'B': lst, 'C': lst, 'D': lst})
#dfff = spark.createDataFrame(pdf)

tic1 = int(round(time.time() * 1000))
# Method 1: Use describe()
max_val = float(dfff.describe("A").filter("summary = 'max'").select("A").collect()[0].asDict()['A'])
tac1 = int(round(time.time() * 1000))
time_dict['m1']= tac1 - tic1
print (max_val)

tic2 = int(round(time.time() * 1000))
# Method 2: Use SQL
dfff.createOrReplaceTempView("df_table")
max_val = spark.sql("SELECT MAX(A) as maxval FROM df_table").collect()[0].asDict()['maxval']
tac2 = int(round(time.time() * 1000))
time_dict['m2']= tac2 - tic2
print (max_val)

tic3 = int(round(time.time() * 1000))
# Method 3: Use groupby()
max_val = dfff.groupby().max('A').collect()[0].asDict()['max(A)']
tac3 = int(round(time.time() * 1000))
time_dict['m3']= tac3 - tic3
print (max_val)

tic4 = int(round(time.time() * 1000))
# Method 4: Convert to RDD
max_val = dfff.select("A").rdd.max()[0]
tac4 = int(round(time.time() * 1000))
time_dict['m4']= tac4 - tic4
print (max_val)

tic5 = int(round(time.time() * 1000))
# Method 5: Use agg()
max_val = dfff.agg({"A": "max"}).collect()[0][0]
tac5 = int(round(time.time() * 1000))
time_dict['m5']= tac5 - tic5
print (max_val)

print(time_dict)

Results on an edge node of a cluster, in milliseconds (ms):

small DF (3 rows):        {'m1': 7096, 'm2': 205, 'm3': 165, 'm4': 211, 'm5': 180}

bigger DF (100,000 rows): {'m1': 10260, 'm2': 452, 'm3': 465, 'm4': 916, 'm5': 373}
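The repeated tic/tac lines in the benchmark can be folded into a small helper. A sketch using time.perf_counter(), a monotonic clock intended for measuring intervals, and keeping the best of several runs: the first job in a Spark session can include one-off start-up costs, which may partly explain why m1, the first method timed, is so much slower. time_call is a hypothetical name; the commented line shows how it could wrap one of the methods above, assuming the dfff DataFrame exists.

```python
import time


def time_call(fn, repeats=3):
    """Call fn() `repeats` times; return (last result, best wall-clock time in ms)."""
    best_ms = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn()
        best_ms = min(best_ms, (time.perf_counter() - start) * 1000)
    return result, best_ms


# With Spark (assumes the dfff DataFrame from the benchmark above):
# max_val, ms = time_call(lambda: dfff.agg({"A": "max"}).collect()[0][0])

# Self-contained check with a plain-Python workload:
value, ms = time_call(lambda: sum(range(1000)))
print(value)  # 499500
```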


The below example shows how to get the max value in a Spark dataframe column.

from pyspark.sql.functions import max

df = sql_context.createDataFrame([(1., 4.), (2., 5.), (3., 6.)], ["A", "B"])
df.show()
+---+---+
|  A|  B|
+---+---+
|1.0|4.0|
|2.0|5.0|
|3.0|6.0|
+---+---+

result = df.select([max("A")])
result.show()
+------+
|max(A)|
+------+
|   3.0|
+------+

print(result.collect()[0]['max(A)'])
3.0

Similarly, min, mean, etc. can be calculated as shown below:

from pyspark.sql.functions import mean, min, max

result = df.select([mean("A"), min("A"), max("A")])
result.show()
+------+------+------+
|avg(A)|min(A)|max(A)|
+------+------+------+
|   2.0|   1.0|   3.0|
+------+------+------+
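Because all three aggregates sit in one select(), Spark can evaluate them in a single aggregation over the data rather than three separate scans. The same single-pass idea in plain Python, purely as an illustration (single_pass_stats is a hypothetical helper, not a PySpark function):

```python
def single_pass_stats(values):
    """Compute (mean, min, max) in one loop, skipping None like Spark's aggregates."""
    count, total = 0, 0.0
    lo = hi = None
    for v in values:
        if v is None:
            continue
        count += 1
        total += v
        if lo is None or v < lo:
            lo = v
        if hi is None or v > hi:
            hi = v
    mean = total / count if count else None
    return mean, lo, hi


print(single_pass_stats([1.0, 2.0, 3.0]))  # (2.0, 1.0, 3.0)
```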

The max value of a particular column of a DataFrame can be obtained with:

your_max_value = df.agg({"your-column": "max"}).collect()[0][0]
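One limitation of the dict form: the argument is an ordinary Python dict, so each column can appear only once, and {"A": "max", "A": "min"} silently keeps only the last entry. For several aggregates on the same column, the function form such as df.agg(min("A"), max("A")) (with both imported from pyspark.sql.functions) works instead. The dict behaviour itself is plain Python:

```python
# A dict literal with a repeated key keeps only the last value for that key,
# so a spec like this would request only one aggregate for column "A".
spec = {"A": "max", "A": "min"}
print(spec)       # {'A': 'min'}
print(len(spec))  # 1
```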


I used another solution (by @satprem rath) already present in this thread.

To find the min value of age in the dataframe (with min imported from pyspark.sql.functions):

df.agg(min("age")).show()

+--------+
|min(age)|
+--------+
|      29|
+--------+

Edit: to add more context.

While the above method printed the result, I ran into issues when assigning the result to a variable to reuse later.

Hence, to assign just the value to a variable:

from pyspark.sql.functions import max, min  

maxValueA = df.agg(max("A")).collect()[0][0]
maxValueB = df.agg(max("B")).collect()[0][0]

First add the import line:

from pyspark.sql.functions import min, max

To find the min value of age in the dataframe:

df.agg(min("age")).show()

+--------+
|min(age)|
+--------+
|      29|
+--------+

To find the max value of age in the dataframe:

df.agg(max("age")).show()

+--------+
|max(age)|
+--------+
|      77|
+--------+

Here is a lazy way of doing this: compute the statistics once with ANALYZE TABLE, then read them back:

df.write.mode("overwrite").saveAsTable("sampleStats")
Query = "ANALYZE TABLE sampleStats COMPUTE STATISTICS FOR COLUMNS " + ','.join(df.columns)
spark.sql(Query)

df.describe('ColName').show()

or

spark.sql("SELECT * FROM sampleStats").describe('ColName').show()

or you can open a Hive shell and run:

DESCRIBE FORMATTED sampleStats;

You will see the statistics in the properties - min, max, distinct, nulls, etc.
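The query string above breaks if a column name contains spaces or other special characters. Quoting each name in backticks, Spark SQL's identifier quoting (a literal backtick inside a name is written as two backticks), is one way around that. A sketch that only builds the string; build_analyze_query is a hypothetical helper and the list of names stands in for df.columns:

```python
def build_analyze_query(table, columns):
    """Build an ANALYZE TABLE ... FOR COLUMNS statement with backtick-quoted names."""
    quoted = ", ".join("`" + c.replace("`", "``") + "`" for c in columns)
    return f"ANALYZE TABLE {table} COMPUTE STATISTICS FOR COLUMNS {quoted}"


query = build_analyze_query("sampleStats", ["A", "B", "unit price"])
print(query)
# ANALYZE TABLE sampleStats COMPUTE STATISTICS FOR COLUMNS `A`, `B`, `unit price`
# spark.sql(query)  # assumes a SparkSession named spark and the saved table
```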


In case anyone wonders how to do it in Scala (Spark 2.0+), here you go:

scala> df.createOrReplaceTempView("TEMP_DF")
scala> val myMax = spark.sql("SELECT MAX(x) as maxval FROM TEMP_DF").
    collect()(0).getInt(0)
scala> print(myMax)
117

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import spark.implicits._  // needed for toDF and .as[...]; assumes a SparkSession value named spark

val testDataFrame = Seq(
  (1.0, 4.0), (2.0, 5.0), (3.0, 6.0)
).toDF("A", "B")

val (maxA, maxB) = testDataFrame.select(max("A"), max("B"))
  .as[(Double, Double)]
  .first()
println(maxA, maxB)

And the result is (3.0,6.0), the same values as testDataFrame.agg(max($"A"), max($"B")).collect()(0). However, testDataFrame.agg(max($"A"), max($"B")).collect()(0) returns a Row, printed as [3.0,6.0], rather than a tuple.


Another way of doing it:

import pyspark.sql.functions as f
df.select(f.max(f.col("A")).alias("MAX")).limit(1).collect()[0].MAX

On my data, I got these benchmarks:

df.select(f.max(f.col("A")).alias("MAX")).limit(1).collect()[0].MAX
CPU times: user 2.31 ms, sys: 3.31 ms, total: 5.62 ms
Wall time: 3.7 s

df.select("A").rdd.max()[0]
CPU times: user 23.2 ms, sys: 13.9 ms, total: 37.1 ms
Wall time: 10.3 s

df.agg({"A": "max"}).collect()[0][0]
CPU times: user 0 ns, sys: 4.77 ms, total: 4.77 ms
Wall time: 3.75 s

All of them give the same answer.


I believe the best solution is to use head().

Considering your example:

+---+---+
|  A|  B|
+---+---+
|1.0|4.0|
|2.0|5.0|
|3.0|6.0|
+---+---+

Using agg with PySpark's max function, we can get the value as follows:

from pyspark.sql.functions import max
df.agg(max(df.A)).head()[0]

This will return: 3.0

Make sure you have the correct import:

from pyspark.sql.functions import max

The max function used here is the PySpark SQL library function, not Python's built-in max.
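Importing max that way rebinds the name max for the whole module, so any later plain-Python call such as max(some_list), for instance the collect()-based answer above, no longer reaches the built-in. A common convention is import pyspark.sql.functions as F and writing F.max("A"). The shadowing itself is ordinary Python; in this sketch spark_max is a hypothetical stand-in for pyspark.sql.functions.max (the real function returns a Column, not a string):

```python
import builtins


def spark_max(col):
    # Hypothetical stand-in: the real pyspark.sql.functions.max returns a Column expression.
    return f"max({col})"


max = spark_max  # the name binding that `from pyspark.sql.functions import max` performs

print(max("A"))                       # max(A): the module-level name now builds an expression
print(builtins.max([1.0, 3.0, 2.0]))  # 3.0: the built-in is still reachable via builtins
```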


To get just the value, use any of these:

  1. df1.agg({"x": "max"}).collect()[0][0]
  2. df1.agg({"x": "max"}).head()[0]
  3. df1.agg({"x": "max"}).first()[0]

Similarly, for min:

from pyspark.sql.functions import min, max
df1.agg(min("id")).collect()[0][0]
df1.agg(min("id")).head()[0]
df1.agg(min("id")).first()[0]
