I work on a dataframe with two columns, mvv and count.
+---+-----+
|mvv|count|
+---+-----+
|  1|    5|
|  2|    9|
|  3|    3|
|  4|    1|
+---+-----+
I would like to obtain two lists containing the mvv values and the count values, something like
mvv = [1,2,3,4]
count = [5,9,3,1]
So I tried the following code. The first line should return a Python list of rows; I wanted to see the first value:
mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)
But I get an error message with the second line:
AttributeError: getInt
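For reference, getInt belongs to the Scala/Java Row API; a PySpark Row has no such method. As a minimal sketch (reusing the question's mvv_count_df), a collected Row supports positional, key, and attribute access instead:
mvv_list = mvv_count_df.select('mvv').collect()

# PySpark Rows are accessed by index, column name, or attribute,
# not with getInt():
firstvalue = mvv_list[0][0]       # positional index
firstvalue = mvv_list[0]['mvv']   # by column name
firstvalue = mvv_list[0].mvv      # attribute access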
Let's create the dataframe in question
df_test = spark.createDataFrame(
    [
        (1, 5),
        (2, 9),
        (3, 3),
        (4, 1),
    ],
    ['mvv', 'count']
)
df_test.show()
Which gives
+---+-----+
|mvv|count|
+---+-----+
| 1| 5|
| 2| 9|
| 3| 3|
| 4| 1|
+---+-----+
and then apply rdd.flatMap(list).collect() to get the values as a flat list:
test_list = df_test.select("mvv").rdd.flatMap(list).collect()
print(type(test_list))
print(test_list)
which gives
<type 'list'>
[1, 2, 3, 4]
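The same pattern applied to the count column yields the second list from the question (a straightforward variant of the snippet above):
count_list = df_test.select("count").rdd.flatMap(list).collect()
print(count_list)
# [5, 9, 3, 1]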
If you get the error below:
AttributeError: 'list' object has no attribute 'collect'
this code will solve your issue:
mvv_list = mvv_count_df.select('mvv').collect()
mvv_array = [int(i.mvv) for i in mvv_list]
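Note that this attribute-access style does not carry over to the count column: PySpark's Row is a tuple subclass, so row.count resolves to the built-in tuple.count method rather than the column value. A sketch of the safe alternative, indexing the Row by column name:
count_rows = mvv_count_df.select('count').collect()

# row.count would return tuple's count() method, not the value,
# so index by column name (or position) instead:
count_array = [int(row['count']) for row in count_rows]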
The following one-liner gives the list you want:
mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()
I ran a benchmarking analysis and list(mvv_count_df.select('mvv').toPandas()['mvv'])
is the fastest method. I'm very surprised.
I ran the different approaches on 100 thousand / 100 million row datasets using a 5 node i3.xlarge cluster (each node has 30.5 GB of RAM and 4 cores) with Spark 2.4.5. Data was evenly distributed across 20 snappy-compressed Parquet files with a single column.
Here's the benchmarking results (runtimes in seconds):
+--------------------------------------------------------------+---------+-------------+
| Code                                                         | 100,000 | 100,000,000 |
+--------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()     | 0.4     | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])           | 0.4     | 17.5        |
| df.select('col_name').rdd.map(lambda row: row[0]).collect()  | 0.9     | 69          |
| [row[0] for row in df.select('col_name').collect()]          | 1.0     | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()]  | 1.2     | *           |
+--------------------------------------------------------------+---------+-------------+
* cancelled after 800 seconds
A golden rule to follow when collecting data on the driver node: toPandas was significantly improved in Spark 2.3, so it's probably not the best approach if you're using a Spark version earlier than 2.3. See here for more details / benchmarking results.
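Much of that improvement comes from Apache Arrow. As a hedged sketch, the Arrow-based toPandas path can be enabled explicitly; the flag below is the Spark 2.3/2.4 name (Spark 3.x renamed it to spark.sql.execution.arrow.pyspark.enabled):
# Enable Arrow-backed columnar transfer for toPandas (Spark 2.3/2.4 flag name)
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

col_list = list(df.select("col_name").toPandas()["col_name"])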
The following code will help you:
mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()
This will give you all the elements as a list.
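If you need both lists, a variant of the same idea (not from the original answer) collects the (mvv, count) pairs in one pass and unzips them:
# One collect() for both columns, then transpose the pairs
pairs = mvv_count_df.select('mvv', 'count').rdd.map(lambda row: (row[0], row[1])).collect()
mvv, count = map(list, zip(*pairs))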
mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)
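Since toPandas pulls whole columns to the driver anyway, both lists can come from a single conversion; a sketch along the same lines:
pdf = mvv_count_df.select('mvv', 'count').toPandas()
mvv = list(pdf['mvv'])
count = list(pdf['count'])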
On my data I got these benchmarks:
>>> data.select(col).rdd.flatMap(lambda x: x).collect()
0.52 sec
>>> [row[col] for row in data.collect()]
0.271 sec
>>> list(data.select(col).toPandas()[col])
0.427 sec
The result is the same in all three cases.
A possible solution is using the collect_list() function from pyspark.sql.functions. This will aggregate all column values into a PySpark array that is converted into a Python list when collected:
from pyspark.sql.functions import collect_list

mvv_list = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0]
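Both aggregates can also be computed in a single job rather than two; collect()[0] is just the one result Row:
row = df.select(collect_list("mvv"), collect_list("count")).collect()[0]
mvv_list, count_list = row[0], row[1]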