Recursive CTE in Spark SQL

; WITH  Hierarchy as 
        (
            select distinct PersonnelNumber
            , Email
            , ManagerEmail 
            from dimstage
            union all
            select e.PersonnelNumber
            , e.Email           
            , e.ManagerEmail 
            from dimstage  e
            join Hierarchy as  h on e.Email = h.ManagerEmail
        )
        select * from Hierarchy

Can you help me achieve the same in Spark SQL?


There are 4 best solutions below


This is not possible in Spark SQL. The WITH clause exists, but it does not support CONNECT BY as in, say, Oracle, or recursive CTEs as in DB2.
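For what it's worth, you can see this directly by submitting the question's query through spark.sql. A minimal sketch (assuming a dimstage table exists; in Spark 2.x/3.x the inner Hierarchy reference is not treated as the CTE itself, so the statement does not recurse and typically fails analysis with an AnalysisException):

# The self-referencing CTE from the question, submitted as-is to Spark SQL.
# Spark 2.x/3.x does not resolve the inner "Hierarchy" reference against the CTE,
# so nothing recurses; the query typically fails analysis instead.
spark.sql("""
    WITH Hierarchy AS (
        SELECT DISTINCT PersonnelNumber, Email, ManagerEmail
        FROM dimstage
        UNION ALL
        SELECT e.PersonnelNumber, e.Email, e.ManagerEmail
        FROM dimstage e
        JOIN Hierarchy h ON e.Email = h.ManagerEmail
    )
    SELECT * FROM Hierarchy
""").show()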


The Spark documentation provides a "CTE in CTE definition" example, reproduced below:

-- CTE in CTE definition
WITH t AS (
    WITH t2 AS (SELECT 1)
    SELECT * FROM t2
)
SELECT * FROM t;
+---+
|  1|
+---+
|  1|
+---+

You can extend this to multiple nested queries, but the syntax quickly becomes awkward. My suggestion is to use comments to make it clear where each SELECT statement is pulling from. Essentially, start with the innermost query (t1 below) and place additional CTE statements above and below as needed:

WITH t3 AS (
    WITH t2 AS (
        WITH t1 AS (
            SELECT DISTINCT b.col1
            FROM data_a AS a, data_b AS b
            WHERE a.col2 = b.col2
              AND a.col3 = b.col3
        )
        -- select from t1
        SELECT DISTINCT b.col1, b.col2, b.col3
        FROM t1 AS a, data_b AS b
        WHERE a.col1 = b.col1
    )
    -- select from t2
    SELECT DISTINCT b.col1
    FROM t2 AS a, data_b AS b
    WHERE a.col2 = b.col2
      AND a.col3 = b.col3
)
-- select from t3
SELECT DISTINCT b.col1, b.col2, b.col3
FROM t3 AS a, data_b AS b
WHERE a.col1 = b.col1;
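If the nesting gets too deep, the same logic may read better as sequential sibling CTEs, which Spark SQL also accepts (each CTE can reference the ones defined before it, though still not itself, so this is an equivalent rewrite rather than recursion). A sketch using the same hypothetical data_a/data_b tables, wrapped in spark.sql:

# Same logic as the nested version above, written as sibling CTEs.
# Each CTE may reference earlier CTEs in the same WITH clause; no recursion involved.
result = spark.sql("""
    WITH t1 AS (
        SELECT DISTINCT b.col1
        FROM data_a AS a, data_b AS b
        WHERE a.col2 = b.col2
          AND a.col3 = b.col3
    ),
    t2 AS (
        SELECT DISTINCT b.col1, b.col2, b.col3
        FROM t1 AS a, data_b AS b
        WHERE a.col1 = b.col1
    ),
    t3 AS (
        SELECT DISTINCT b.col1
        FROM t2 AS a, data_b AS b
        WHERE a.col2 = b.col2
          AND a.col3 = b.col3
    )
    SELECT DISTINCT b.col1, b.col2, b.col3
    FROM t3 AS a, data_b AS b
    WHERE a.col1 = b.col1
""")
result.show()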

This is quite late, but today I tried to implement the recursive CTE query using PySpark SQL.

Here I have a simple dataframe. What I want to do is find the newest ID of each ID.

The original dataframe:

+-----+-----+
|OldID|NewID|
+-----+-----+
|    1|    2|
|    2|    3|
|    3|    4|
|    4|    5|
|    6|    7|
|    7|    8|
|    9|   10|
+-----+-----+

The result I want:

+-----+-----+
|OldID|NewID|
+-----+-----+
|    1|    5|
|    2|    5|
|    3|    5|
|    4|    5|
|    6|    8|
|    7|    8|
|    9|   10|
+-----+-----+

Here is my code:

from pyspark.sql.functions import broadcast

df = sqlContext.createDataFrame([(1, 2), (2, 3), (3, 4), (4, 5), (6, 7), (7, 8), (9, 10)], "OldID integer,NewID integer").checkpoint().cache()

dfcheck = df.drop('NewID')
dfdistinctID = df.select('NewID').distinct()
dfidfinal = dfdistinctID.join(dfcheck, [dfcheck.OldID == dfdistinctID.NewID], how="left_anti")  # The IDs that have never been replaced, i.e. the end of each chain

dfcurrent = df.join(dfidfinal, [dfidfinal.NewID == df.NewID], how="left_semi").checkpoint().cache()  # The rows whose NewID is one of those final IDs; these seed dfcurrent
dfresult = dfcurrent
dfdifferentalias = df.select(df.OldID.alias('id1'), df.NewID.alias('id2')).checkpoint().cache()

# Walk each chain back one hop per iteration until no more rows link up
while dfcurrent.count() > 0:
  dfcurrent = dfcurrent.join(broadcast(dfdifferentalias), [dfcurrent.OldID == dfdifferentalias.id2], how="inner").select(dfdifferentalias.id1.alias('OldID'), dfcurrent.NewID.alias('NewID')).cache()
  dfresult = dfresult.unionAll(dfcurrent)

display(dfresult.orderBy('OldID'))

Databricks notebook screenshot
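One note on the .checkpoint() calls above: they only work once a checkpoint directory has been configured for the session. A minimal sketch (the path is only an example; on Databricks you would typically point it at a DBFS location):

# Required once per session before any .checkpoint() call; the path is only an example.
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")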

I know that the performance is quite bad, but at least it gives the answer I need.

This is the first time I've posted an answer on Stack Overflow, so forgive me if I made any mistakes.


You can repeatedly call createOrReplaceTempView to build up a recursive query. It's not going to be fast, nor pretty, but it works. Following @Pblade's example, in PySpark:

from pyspark.sql import functions as F

def recursively_resolve(df):
    rec = df.withColumn('level', F.lit(0))

    sql = """
        select this.oldid
             , coalesce(next.newid, this.newid) as newid
             , this.level + case when next.newid is not null then 1 else 0 end as level
             , next.newid is not null as is_resolved
          from rec this
          left outer
          join rec next
            on next.oldid = this.newid
    """
    find_next = True
    while find_next:
        rec.createOrReplaceTempView("rec")
        rec = spark.sql(sql)
        # check if any rows resolved in this iteration
        # go deeper if they did
        find_next = rec.selectExpr("ANY(is_resolved = True)").collect()[0][0]

    return rec.drop('is_resolved')
        

Then:

src = spark.createDataFrame([(1, 2), (2, 3), (3, 4), (4, 5), (6, 7), (7, 8),(9, 10)], "OldID integer,NewID integer")
result = recursively_resolve(src)
result.show()

Prints:

+-----+-----+-----+
|oldid|newid|level|
+-----+-----+-----+
|    2|    5|    2|
|    4|    5|    0|
|    3|    5|    1|
|    7|    8|    0|
|    6|    8|    1|
|    9|   10|    0|
|    1|    5|    2|
+-----+-----+-----+
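One caveat: each pass through the while loop re-runs the SQL on top of the plan produced by the previous pass, so the logical plan keeps growing with the depth of the chains. If that becomes a problem, a possible mitigation (a sketch, assuming a writable checkpoint directory has been set) is to truncate the lineage each iteration:

# Hypothetical tweak to the loop above: checkpoint each round to cut the lineage.
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")  # example path

while find_next:
    rec.createOrReplaceTempView("rec")
    rec = spark.sql(sql).checkpoint()  # materialize and truncate the plan
    find_next = rec.selectExpr("ANY(is_resolved = True)").collect()[0][0]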