I have a tiny PySpark DataFrame with relations and a function that computes the transitive closure. I already know a couple of ways to improve the function (including getting rid of the `groupBy`), but let's stick with this version. When I apply the closure function iteratively, the computation time increases exponentially and Spark even runs out of heap memory.
```python
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import col, min

def closure(eq: DataFrame) -> DataFrame:
    eqrev = eq.select(col("ID2").alias("ID1"), col("ID1").alias("ID2"))
    bi = eq.union(eqrev).distinct().cache()
    oldCount = 0
    nextCount = bi.count()
    while True:
        oldCount = nextCount
        newEdges = bi.alias("b1").join(bi.alias("b2"), col("b1.ID1") == col("b2.ID1")) \
            .select(col("b1.ID2").alias("ID1"), col("b2.ID2"))
        bi = bi.union(newEdges).distinct().cache()
        nextCount = bi.count()
        if nextCount == oldCount:
            break
    return bi.alias("b1").filter(col("b1.ID1") > col("b1.ID2")) \
        .groupBy("ID1").agg(min("ID2").alias("ID2")).cache()
```
```python
b0 = sqlContext.createDataFrame([[22, 18], [20, 15], [25, 26], [25, 29]], ["ID1", "ID2"])
b1 = closure(b0)
display(b1)
b2 = closure(b1)
display(b2)
b3 = closure(b2)
display(b3)
b4 = closure(b3)
```
`b1`, `b2`, and `b3` all have 4 rows and 200 partitions (which are introduced by the `join`). The execution plan grows linearly: for `b4` it is 13 stages.
On my small cluster, the computation of `b2` takes 8 seconds, `b3` takes 40 seconds, and `b4` fails with `java.lang.OutOfMemoryError: Java heap space` after a few minutes.
I would have expected that, since I'm caching the result of each closure, the Spark engine would be able to handle this.
Some related articles:
- *Spark iteration time increasing exponentially when using join*: the accepted answer there says that the number of partitions grows exponentially. That is not the case for me; it stays at 200.
- *spark out of memory multiple iterations*: there it is suggested to use `localCheckpoint`.
If I change `.cache()` in the last line of the function to `.localCheckpoint()`, I no longer see increasing execution times or an out-of-memory exception. The documentation of `localCheckpoint` says:

> Checkpointing can be used to truncate the logical plan of this DataFrame, which is especially useful in iterative algorithms where the plan may grow exponentially. Local checkpoints are stored in the executors using the caching subsystem and therefore they are not reliable.
I now have the following questions:
1. I'm already running into trouble after 4 iterations with almost no data. Is that really to be expected?
2. Why does the computation time increase so rapidly, and why does the engine run out of heap space? The execution plan still fits on my screen.
3. What are the repercussions of using `localCheckpoint` in case of failures?