How to apply "Initcap" only to records whose values are not all capital letters in a PySpark DataFrame?


I have a PySpark DataFrame and I want to apply "Initcap" to a specific column, but only to records whose value is not already all capital letters. For example, in the sample dataset below, I don't want to apply "Initcap" to USA:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Prepare data
data = [(1, "Italy"),
        (2, "italy"),
        (3, "USA"),
        (4, "China"),
        (5, "china")]

# Create DataFrame
columns = ["ID", "Country"]
df = spark.createDataFrame(data=data, schema=columns)
df.show(truncate=False)


The expected output will be:

+---+-------+
|ID |Country|
+---+-------+
|1  |Italy  |
|2  |Italy  |
|3  |USA    |
|4  |China  |
|5  |China  |
+---+-------+
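To pin down the intended rule, here is a minimal plain-Python sketch (not Spark code) that reproduces the expected output above for the sample data. `initcap_like` and `normalize_country` are hypothetical helpers; `initcap_like` assumes Spark's `initcap` lower-cases the string and then capitalizes the first letter of each space-separated word.

```python
def initcap_like(s):
    # Hypothetical approximation of Spark's initcap: lower-case everything,
    # then upper-case the first character of each space-separated word
    return " ".join(w[:1].upper() + w[1:] for w in s.lower().split(" "))


def normalize_country(value):
    # Leave nulls and values that are already entirely upper case (e.g. "USA") untouched
    if value is None or value == value.upper():
        return value
    return initcap_like(value)


data = [(1, "Italy"), (2, "italy"), (3, "USA"), (4, "China"), (5, "china")]
for row_id, country in data:
    print(row_id, normalize_country(country))
```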

There are 3 answers below.

Answer by wwnde (accepted answer):
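from pyspark.sql.functions import initcap, upper, when
# Keep values that are already all upper case; apply initcap to everything else
df.withColumn('Country', when(df.Country == upper(df.Country), df.Country).otherwise(initcap('Country'))).show(truncate=False)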

+---+-------+
|ID |Country|
+---+-------+
|1  |Italy  |
|2  |Italy  |
|3  |USA    |
|4  |China  |
|5  |China  |
+---+-------+
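The condition `df.Country == upper(df.Country)` keeps a value only when upper-casing leaves it unchanged. The plain-Python sketch below mirrors that test with a hypothetical `is_all_caps` helper to show two consequences: values with no letters at all (such as "123") also count as all caps, unlike Python's `str.isupper()`, and mixed-case values such as "McDonald" are not kept, so `initcap` rewrites them. In Spark, a null `Country` makes the comparison null, so the `otherwise` branch runs and `initcap(null)` returns null.

```python
def is_all_caps(value):
    # Mirrors the Spark condition Country == upper(Country):
    # true when upper-casing leaves the string unchanged
    return value == value.upper()


for value in ["USA", "Italy", "italy", "U.S.A.", "123", "McDonald"]:
    # str.isupper() is shown for comparison: it is False when the
    # string contains no cased characters at all
    print(f"{value!r}: equals upper -> {is_all_caps(value)}, isupper -> {value.isupper()}")
```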
Answer by Muhammad Ali:

I hope this works for you. Another option is to define a custom UDF:

import findspark
findspark.init()  # only needed when PySpark is not already importable (locates SPARK_HOME)

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# create a Spark Session
spark = SparkSession.builder.appName('StackOverflowInitCap').getOrCreate()

data = [(1, "Italy"),
        (2, "italy"),
        (3, "USA"),
        (4, "China"),
        (5, "china")]
 
# Create DataFrame
columns = ["ID", "Country"]
df = spark.createDataFrame(data=data, schema=columns)
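
def initCap(x):
    # Upper-case the first character only when it is lower case;
    # None, empty strings and values such as "USA" are returned unchanged
    if x and x[0].islower():
        return x[0].upper() + x[1:]
    return x

upperCaseUDF = udf(initCap, StringType())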
df.withColumn("Country", upperCaseUDF(df.Country)) \
  .show(truncate=False)
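Note that this UDF is not equivalent to the built-in `initcap`: it only upper-cases a lower-case first character and leaves the rest of the string alone. The plain-Python sketch below reproduces the UDF's logic (with a guard for None and empty strings added) and compares it against the results Spark's `initcap` gives according to its documented behaviour (first letter of each word upper case, all other letters lower case). The `builtin_initcap` table is written by hand for illustration, not produced by Spark.

```python
def initCap(x):
    # Same first-character logic as the UDF (plus a guard for None and "")
    if x and x[0].islower():
        return x[0].upper() + x[1:]
    return x


# Hand-written results of Spark's built-in initcap for the same inputs,
# following its documented behaviour
builtin_initcap = {"italy": "Italy", "new york": "New York", "iTALY": "Italy"}

for value, expected in builtin_initcap.items():
    print(f"{value!r}: UDF -> {initCap(value)!r}, initcap -> {expected!r}")
```

For the sample data in the question both approaches agree; for general input they diverge, and the built-in `when`/`initcap` expression also avoids the cost of shipping rows to a Python worker for the UDF.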
Answer by Rakesh Chintha:

You can also write the same logic as a plain Hive/Spark SQL CASE expression:

import pyspark.sql.functions as sf

# Keep all-caps values as they are; apply initcap to everything else
df.withColumn("Country", sf.expr("""
case
  when Country = upper(Country) then Country
  else initcap(Country)
end
""")).show(truncate=False)