
Best way to get the max value in a Spark dataframe column

I want to run aggregations (distinct count, min, max) on every column of a table, something like:

for colName in df.columns:
    dt = df.select(colName).distinct().count()
    mx = df.agg({colName: "max"}).collect()[0][0]
    mn = df.agg({colName: "min"}).collect()[0][0]
    print(colName, dt, mx, mn)

This can easily be done by computing statistics. The statistics from Hive and Spark differ:

  • Hive gives: distinct, max, min, nulls, length, version
  • Spark gives: count, mean, stddev, min, max

It looks like quite a few statistics are calculated. How can I get all of them for all columns using one command?

However, I have thousands of columns, and doing this serially is very slow. Suppose I want to compute some other function, say standard deviation, on each of the columns: how can that be done in parallel?

user 923227
  • You should just use [`pyspark.sql.DataFrame.describe`](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame.describe). This computes count, mean, min, max, and stddev for all numerical columns. – pault Sep 12 '18 at 14:04
  • I was looking for that describe! Thanks! – user 923227 Sep 12 '18 at 18:46
  • How do I get `df.describe` for all columns? – user 923227 Sep 12 '18 at 19:04

2 Answers


You can use pyspark.sql.DataFrame.describe() to get aggregate statistics like count, mean, min, max, and standard deviation for all columns where such statistics are applicable. (If you don't pass any arguments, statistics for all columns are returned by default.)

df = spark.createDataFrame(
    [(1, "a"),(2, "b"), (3, "a"), (4, None), (None, "c")],["id", "name"]
)
df.describe().show()
#+-------+------------------+----+
#|summary|                id|name|
#+-------+------------------+----+
#|  count|                 4|   4|
#|   mean|               2.5|null|
#| stddev|1.2909944487358056|null|
#|    min|                 1|   a|
#|    max|                 4|   c|
#+-------+------------------+----+

As you can see, these statistics ignore any null values.

If you're using Spark 2.3 or later, there is also pyspark.sql.DataFrame.summary(), which supports the following aggregates:

  • count
  • mean
  • stddev
  • min
  • max
  • arbitrary approximate percentiles specified as a percentage (e.g., 75%)

df.summary("count", "min", "max").show()
#+-------+------------------+----+
#|summary|                id|name|
#+-------+------------------+----+
#|  count|                 4|   4|
#|    min|                 1|   a|
#|    max|                 4|   c|
#+-------+------------------+----+

If you wanted some other aggregate statistic for all columns, you could also use a list comprehension with pyspark.sql.DataFrame.agg(). For example, if you wanted to replicate what you say Hive gives (distinct, max, min and nulls - I'm not sure what length and version mean):

import pyspark.sql.functions as f
from itertools import chain

agg_distinct = [f.countDistinct(c).alias("distinct_"+c) for c in df.columns]
agg_max = [f.max(c).alias("max_"+c) for c in df.columns]
agg_min = [f.min(c).alias("min_"+c) for c in df.columns]
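# when() treats the plain string c as a literal value, not a column reference,
# so every null row yields a non-null value that count() then tallies.
agg_nulls = [f.count(f.when(f.isnull(c), c)).alias("nulls_"+c) for c in df.columns]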

df.agg(
    *(chain.from_iterable([agg_distinct, agg_max, agg_min, agg_nulls]))
).show()
#+-----------+-------------+------+--------+------+--------+--------+----------+
#|distinct_id|distinct_name|max_id|max_name|min_id|min_name|nulls_id|nulls_name|
#+-----------+-------------+------+--------+------+--------+--------+----------+
#|          4|            3|     4|       c|     1|       a|       1|         1|
#+-----------+-------------+------+--------+------+--------+--------+----------+

Note that this method returns a single row, rather than one row per statistic as describe() and summary() do.
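
If one row per statistic is preferred, the single wide row can be regrouped on the driver after it is collected. This is only a minimal sketch, assuming the `stat_column` alias scheme used above. `pivot_stats` is a hypothetical helper, not part of PySpark, and the dict below stands in for `df.agg(...).first().asDict()`:

```python
def pivot_stats(flat, columns, stats):
    """Regroup {"max_id": 4, ...} into {"max": {"id": 4, ...}, ...}.

    Keys are rebuilt from the known statistic and column names instead of
    being split on "_", so column names containing underscores still work.
    """
    return {
        stat: {col: flat[f"{stat}_{col}"] for col in columns}
        for stat in stats
    }


# Stand-in for df.agg(...).first().asDict(), using the values shown above.
flat_row = {
    "distinct_id": 4, "distinct_name": 3,
    "max_id": 4, "max_name": "c",
    "min_id": 1, "min_name": "a",
    "nulls_id": 1, "nulls_name": 1,
}

table = pivot_stats(flat_row, ["id", "name"], ["distinct", "max", "min", "nulls"])
for stat, values in table.items():
    print(stat, values)
```

Because the aggregated result is a single row, this reshaping is cheap and runs on the driver regardless of how large the source table is.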

pault

You can put as many expressions into an agg as you want; when you collect, they are all computed at once. The result is a single row with all the values. Here's an example:

from pyspark.sql.functions import min, max, countDistinct

r = df.agg(
  min(df.col1).alias("minCol1"),
  max(df.col1).alias("maxCol1"),
  (max(df.col1) - min(df.col1)).alias("diffMinMax"),
  countDistinct(df.col2).alias("distinctItemsInCol2"))
r.printSchema()
# root
#  |-- minCol1: long (nullable = true)
#  |-- maxCol1: long (nullable = true)
#  |-- diffMinMax: long (nullable = true)
#  |-- distinctItemsInCol2: long (nullable = false)

row = r.collect()[0]
print(row.distinctItemsInCol2, row.diffMinMax)
# (10, 9)

You can also use the dictionary syntax here, but it's harder to manage for more complex things.
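
For a table with thousands of columns, the expressions can be generated in a loop and still submitted as one aggregation, rather than one job per column. This is a minimal sketch, assuming SQL expression strings passed to `DataFrame.selectExpr`. `build_stat_exprs` and its alias scheme are hypothetical, and `avg`, `stddev` and `variance` are only meaningful for numeric columns:

```python
def build_stat_exprs(columns):
    """Return Spark SQL aggregate expression strings for each column name."""
    exprs = []
    for c in columns:
        # Backticks quote the identifier; a literal backtick is doubled.
        safe = c.replace("`", "``")
        q = f"`{safe}`"
        exprs += [
            f"max({q}) - min({q}) AS `range_{safe}`",
            f"avg({q}) AS `avg_{safe}`",
            f"stddev({q}) AS `stddev_{safe}`",
            f"variance({q}) AS `variance_{safe}`",
        ]
    return exprs


exprs = build_stat_exprs(["col1", "col2"])
for e in exprs:
    print(e)

# With a DataFrame `df` (assumed to exist), all statistics come back in one row:
# wide_row = df.selectExpr(*build_stat_exprs(df.columns)).first()
```

The same loop could build Column objects for `df.agg(...)` instead; strings are used here only so the sketch runs without a Spark session.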

Bi Rico
  • When I have 1000s of columns, how do I do this? Say we take a few operations like max-min, avg, σ². I am looking for something like iterating over the columns and building the agg. Or is it better to build it in SQL? – user 923227 Sep 12 '18 at 18:50