
There are other threads on how to rename columns in a PySpark DataFrame; see here, here, and here. I don't think the existing solutions are sufficiently performant or generic (I have a solution that should be better, but I'm stuck on an edge case bug). Let's start by reviewing the issues with the current solutions:

  • Calling withColumnRenamed repeatedly will probably have the same performance problems as calling withColumn a lot, as outlined in this blog post. See Option 2 in this answer.
  • The toDF approach (sketched just after this list) relies on schema inference and does not necessarily retain the nullable property of columns (toDF should be avoided in production code). I'm guessing this approach is slow as well.
  • This approach is close, but it's not generic enough and would be way too much manual work for a lot of columns (e.g. if you're trying to convert 2,000 column names to snake_case).
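
For concreteness, here's roughly what that toDF-style rename looks like (just a sketch, not code from the linked answers; df stands in for whatever DataFrame you're renaming):

# Pass the full list of new names to toDF, one per existing column.
# As noted above, this may not keep the nullable property of the schema.
df = df.toDF(*[col_name.replace(" ", "_") for col_name in df.columns])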

I created a function that's generic and works for all column types, except for column names that include dots:

import pyspark.sql.functions as F

def with_columns_renamed(fun):
    def _(df):
        # Apply the renaming function to every column name, alias each
        # column accordingly, and do a single select with all of them.
        cols = list(map(
            lambda col_name: F.col(col_name).alias(fun(col_name)),
            df.columns
        ))
        return df.select(*cols)
    return _

Suppose you have the following DataFrame:

+-------------+-----------+
|i like cheese|yummy stuff|
+-------------+-----------+
|         jose|          a|
|           li|          b|
|          sam|          c|
+-------------+-----------+
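
If you want to follow along, here's one way to create this DataFrame (assuming an active SparkSession named spark):

df = spark.createDataFrame(
    [("jose", "a"), ("li", "b"), ("sam", "c")],
    ["i like cheese", "yummy stuff"],
)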

Here's how to replace all the whitespaces in the column names with underscores:

def spaces_to_underscores(s):
    return s.replace(" ", "_")

df.transform(with_columns_renamed(spaces_to_underscores)).show()
+-------------+-----------+
|i_like_cheese|yummy_stuff|
+-------------+-----------+
|         jose|          a|
|           li|          b|
|          sam|          c|
+-------------+-----------+

The solution works perfectly, except when a column name contains dots.

Suppose you have this DataFrame:

+-------------+-----------+
|i.like.cheese|yummy.stuff|
+-------------+-----------+
|         jose|          a|
|           li|          b|
|          sam|          c|
+-------------+-----------+
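
As before, you can build this one for testing (again assuming a SparkSession named spark):

df = spark.createDataFrame(
    [("jose", "a"), ("li", "b"), ("sam", "c")],
    ["i.like.cheese", "yummy.stuff"],
)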

This code will error out:

def dots_to_underscores(s):
    return s.replace(".", "_")

df.transform(with_columns_renamed(dots_to_underscores))

Here's the error message: pyspark.sql.utils.AnalysisException: "cannot resolve 'i.like.cheese' given input columns: [i.like.cheese, yummy.stuff];;\n'Project ['i.like.cheese AS i_like_cheese#242, 'yummy.stuff AS yummy_stuff#243]\n+- LogicalRDD [i.like.cheese#231, yummy.stuff#232], false\n"

How can I modify this solution to work for column names that have dots? I'm also assuming that the Catalyst optimizer will have the same optimization problems for multiple withColumnRenamed calls as it does for multiple withColumn calls. Let me know if Catalyst handles multiple withColumnRenamed calls better for some reason.

Powers
  • @murtihash - Using `reduce` and `withColumnRenamed` works, but I'm specifically trying to avoid that approach for [the reasons outlined in this post](https://medium.com/@manuzhang/the-hidden-cost-of-spark-withcolumn-8ffea517c015). The escaping approach sounds promising, please feel free to write an answer with a working code snippet. – Powers Jul 16 '20 at 17:57
  • I see, just read that post; that's interesting that so many withColumn calls could lead to a bottleneck. You could stick to escaping with ` in the map with select. – murtihash Jul 16 '20 at 18:01

2 Answers


Try escaping the column names with backticks (`), so Spark resolves each dotted name as a single column instead of treating the dots as nested field access:

import pyspark.sql.functions as F

def with_columns_renamed(fun):
    def _(df):
        # Wrap each name in backticks so the dots are read as literal
        # characters in the column name rather than struct field access.
        cols = list(map(
            lambda col_name: F.col("`{0}`".format(col_name)).alias(fun(col_name)),
            df.columns
        ))
        return df.select(*cols)
    return _
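
For example, reusing the dots_to_underscores function and the dotted DataFrame from the question, this version should resolve the columns without the AnalysisException and give output like:

def dots_to_underscores(s):
    return s.replace(".", "_")

df.transform(with_columns_renamed(dots_to_underscores)).show()
+-------------+-----------+
|i_like_cheese|yummy_stuff|
+-------------+-----------+
|         jose|          a|
|           li|          b|
|          sam|          c|
+-------------+-----------+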

Or use withColumnRenamed with reduce:

from functools import reduce

reduce(
    lambda new_df, col: new_df.withColumnRenamed(col, col.replace('.', '_')),
    df.columns,
    df
)
murtihash

You could do something simple like this:

import pyspark.sql.functions as F

def with_columns_renamed(fun):
    def _(df):
        cols = list(map(
            lambda col_name: F.col('`' + col_name + '`').alias(fun(col_name)),
            df.columns
        ))
        return df.select(*cols)
    return _
Kirubakar