
I am trying to use a Spark cluster (running on AWS EMR) to link groups of items that have common elements. Essentially, I have groups of elements, and if some elements appear in multiple groups, I want to produce one group that contains the elements from all of those groups.

I know about the GraphX library and I tried to use the graphframes package (ConnectedComponents algorithm) to solve this task, but it seems that the graphframes package is not yet mature enough and is very wasteful with resources... Running it on my data set (circa 60 GB), it just runs out of memory no matter how much I tune the Spark parameters, how I partition and re-partition my data, or how big a cluster I create (the graph IS huge).

So I wrote my own code to accomplish the task. The code works and it solves my problem, but it slows down with every iteration. Since it can sometimes take around 10 iterations to finish, it can run very long, and I could not figure out what the problem is.

I start with a table (DataFrame) item_links that has two columns: item and group_name. Items are unique within each group, but not within this table. One item can be in multiple groups. If two items each have a row with the same group name, they both belong to the same group.

I first group by item and find for every item the smallest of all group names from all the groups it belongs to. I append this information as an extra column to the original DataFrame. Then I create a new DataFrame by grouping by the group name and finding the smallest value of this new column within every group. I join this DataFrame with my original table on the group name and replace the group name column with the minimum value from that new column. The idea is that if a group contains an item that also belongs to some smaller group, this group will be merged into it. Every iteration links groups that were connected indirectly through longer and longer chains of shared items.
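To make that concrete, here is a tiny hypothetical input and the fixed point the iterations should converge to (the data and the spark session variable are just assumptions for the example):

# Hypothetical toy input: item 'b' links g1 and g2, item 'c' links g2 and g3,
# so all three groups transitively form one group.
item_links = spark.createDataFrame(
    [('a', 'g1'), ('b', 'g1'),
     ('b', 'g2'), ('c', 'g2'),
     ('c', 'g3'), ('d', 'g3')],
    ['item', 'group_name'])

# Expected result after the iterations below converge (two levels here,
# because g3 is only linked to g1 indirectly, via g2):
# (a, g1), (b, g1), (c, g1), (d, g1)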

The code that I am running looks like this:

print(" Merging groups that have common items...")

n_partitions = 32

merge_level = 0

min_new_group = "min_new_group_{}".format(merge_level)

# For every item identify the (alphabetically) first group in which this item was found
# and add a new column min_new_group with that information for every item.
first_group = item_links \
                    .groupBy('item') \
                    .agg( min('group_name').alias(min_new_group) ) \
                    .withColumnRenamed('item', 'item_id') \
                    .coalesce(n_partitions) \
                    .cache()

item_links = item_links \
                .join( first_group,
                       item_links['item'] == first_group['item_id'] ) \
                .drop(first_group['item_id']) \
                .coalesce(n_partitions) \
                .cache()

first_group.unpersist()

# In every group find the (alphabetically) smallest min_new_group value.
# If the group contains an item that was in some other group,
# this value will be different from the current group_name.
merged_groups = item_links \
                    .groupBy('group_name') \
                    .agg(
                        min(col(min_new_group)).alias('merged_group')
                    ) \
                    .withColumnRenamed('group_name', 'group_to_merge') \
                    .coalesce(n_partitions) \
                    .cache()

# Replace the group_name column with the lowest group that any of the items in the group had.
item_links = item_links \
                .join( merged_groups,
                       item_links['group_name'] == merged_groups['group_to_merge'] ) \
                .drop(item_links['group_name']) \
                .drop(merged_groups['group_to_merge']) \
                .drop(item_links[min_new_group]) \
                .withColumnRenamed('merged_group', 'group_name') \
                .coalesce(n_partitions) \
                .cache()

# Count the number of groups that get merged into another (smaller) group
common_items_count = merged_groups.filter(col('merged_group') != col('group_to_merge')).count()

merged_groups.unpersist()

# just some debug output
print("  level {}: found {} common items".format(merge_level, common_items_count))

# As long as the number of groups keeps decreasing (groups are merged together), repeat the operation.
while (common_items_count > 0):
    merge_level += 1

    min_new_group = "min_new_group_{}".format(merge_level)

    # for every item find new minimal group...
    first_group = item_links \
                        .groupBy('item') \
                        .agg(
                            min(col('group_name')).alias(min_new_group)
                        ) \
                        .withColumnRenamed('item', 'item_id') \
                        .coalesce(n_partitions) \
                        .cache() 

    item_links = item_links \
                    .join( first_group,
                           item_links['item'] == first_group['item_id'] ) \
                    .drop(first_group['item_id']) \
                    .coalesce(n_partitions) \
                    .cache()

    first_group.unpersist()

    # find groups that have items from other groups...
    merged_groups = item_links \
                        .groupBy(col('group_name')) \
                        .agg(
                            min(col(min_new_group)).alias('merged_group')
                        ) \
                        .withColumnRenamed('group_name', 'group_to_merge') \
                        .coalesce(n_partitions) \
                        .cache()

    # merge the groups with items from other groups...
    item_links = item_links \
                    .join( merged_groups,
                           item_links['group_name'] == merged_groups['group_to_merge'] ) \
                    .drop(item_links['group_name']) \
                    .drop(merged_groups['group_to_merge']) \
                    .drop(item_links[min_new_group]) \
                    .withColumnRenamed('merged_group', 'group_name') \
                    .coalesce(n_partitions) \
                    .cache()

    common_items_count = merged_groups.filter(col('merged_group') != col('group_to_merge')).count()

    merged_groups.unpersist()

    print("  level {}: found {} common items".format(merge_level, common_items_count))

As I said, it works, but the problem is that it slows down with every iteration. Iterations 1-3 run in just a few seconds or minutes. Iteration 5 runs around 20-40 minutes. Iteration 6 sometimes doesn't even finish, because the controller runs out of memory (14 GB for the controller, around 140 GB of RAM for the entire cluster with 20 CPU cores... the test data is around 30 GB).

When I monitor the cluster in Ganglia, I see that after every iteration the workers perform less and less work and the controller performs more and more. The network traffic also goes down to zero. Memory usage is rather stable after the initial phase.

I have read a lot about re-partitioning, tuning Spark parameters and the background of shuffle operations, and I did my best to optimize everything, but I have no idea what's going on here. Below is the load of my cluster nodes (yellow for the controller node) over time as the code above is running.

[Image: load of cluster nodes, yellow is controller]

grepe
  • I'm interested in generating some test data to use your example for some Spark benchmarking. Would you be able to share some more statistics about the data you're working with? (e.g. number of items, number of groups, and distribution of items to groups) Thanks! – Michael Mior Feb 20 '18 at 17:02

4 Answers


A simple reproduction scenario:

import time
from pyspark import SparkContext

sc = SparkContext()

def push_and_pop(rdd):
    # two transformations: moves the head element to the tail
    first = rdd.first()
    return rdd.filter(
        lambda obj: obj != first
    ).union(
        sc.parallelize([first])
    )

def serialize_and_deserialize(rdd):
    # perform a collect() action to evaluate the rdd and create a new instance
    return sc.parallelize(rdd.collect())

def do_test(serialize=False):
    rdd = sc.parallelize(range(1000))
    for i in range(25):
        t0 = time.time()
        rdd = push_and_pop(rdd)
        if serialize:
            rdd = serialize_and_deserialize(rdd)
        print "%.3f" % (time.time() - t0)

do_test()

Shows major slowdown during the 25 iterations:

0.597 0.117 0.186 0.234 0.288 0.309 0.386 0.439 0.507 0.529 0.553 0.586 0.710 0.728 0.779 0.896 0.866 0.881 0.956 1.049 1.069 1.061 1.149 1.189 1.201

(first iteration is relatively slow because of initialization effects, second iteration is quick, every subsequent iteration is slower).

The cause seems to be the growing chain of lazy transformations. We can test the hypothesis by rolling up the RDD using an action.

do_test(True)

0.897 0.256 0.233 0.229 0.220 0.238 0.234 0.252 0.240 0.267 0.260 0.250 0.244 0.266 0.295 0.464 0.292 0.348 0.320 0.258 0.250 0.201 0.197 0.243 0.230

The collect()/parallelize() round trip adds about 0.1 seconds to each iteration, but it completely eliminates the incremental slowdown.
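The same fix works for the DataFrame code in the question. A sketch (my suggestion, not code from the question) using Spark's reliable checkpointing, available on DataFrames since Spark 2.1:

# Cut the accumulated plan once per iteration: checkpoint() materializes
# the DataFrame to the checkpoint directory and returns a new DataFrame
# with a truncated lineage. The path is illustrative.
spark.sparkContext.setCheckpointDir('hdfs:///tmp/checkpoints')

item_links = item_links.checkpoint()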

Freek Wiekmeijer
  • Interesting. I did try to force the DF evaluation by adding `.count()` after each iteration. It helped me in some other cases, but not in the code described above. – grepe Mar 06 '17 at 13:38
  • When you do a `count()`, you do perform an action. This enforces evaluation of the lazy transformations. But because you're not overwriting the `rdd` reference variable in the iteration, that one is still stored like before the `count()`, i.e. `original rdd.filter().union().filter().union() ... `. – Freek Wiekmeijer Mar 06 '17 at 14:20

I resolved this issue by saving the DataFrame to HDFS at the end of every iteration and reading it back from HDFS at the beginning of the next one.

Since I started doing that, the program runs like a breeze and doesn't show any signs of slowing down, overusing memory or overloading the driver.
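In rough outline, the round trip looks like this (the exact path and the Parquet format here are illustrative):

# Use a fresh path per iteration so we never overwrite files that the
# current DataFrame may still be lazily reading from.
checkpoint_path = 'hdfs:///tmp/item_links_{}'.format(merge_level)

# Materialize the current result, dropping the accumulated lineage...
item_links.write.mode('overwrite').parquet(checkpoint_path)

# ...and start the next iteration from a plain on-disk DataFrame.
item_links = spark.read.parquet(checkpoint_path)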

I still don't understand why this happens, so I'm leaving the question open.

grepe
  • Can you add the code (the functions you used for it), please? – Maria May 09 '18 at 16:43
  • Confirmed this fixed my issue. I just write and then read back the result every 100 iterations. It basically forces a synchronization. Thanks for the answer. – Nicholas Leonard Jan 14 '21 at 20:00

Your code has the correct logic. It is just that you never call item_links.unpersist(), so it first slows down (trying to swap to local disk) and then runs out of memory.

Memory usage in Ganglia may be misleading. It won't change, since memory is allocated to the executors at the start of the script, regardless of whether they actually use it later. You can check the Spark UI instead for storage/executor status.
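As a sketch (an assumed pattern, not the question's original code): keep a reference to the previously cached DataFrame and release it only after the new one has been materialized:

previous = item_links  # handle to the currently cached DataFrame

item_links = item_links \
                .join( merged_groups,
                       item_links['group_name'] == merged_groups['group_to_merge'] ) \
                .drop(item_links['group_name']) \
                .drop(merged_groups['group_to_merge']) \
                .drop(item_links[min_new_group]) \
                .withColumnRenamed('merged_group', 'group_name') \
                .coalesce(n_partitions) \
                .cache()

item_links.count()    # action that actually fills the new cache
previous.unpersist()  # only now release the old cached copy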

shuaiyuancn
  • I just tried unpersisting dataframes within the loop and it did not really help. The issue is suspiciously similar to the situation described here: http://stackoverflow.com/questions/31659404/spark-iteration-time-increasing-exponentially-when-using-join?rq=1 except that I am keeping partitions at a sane level, both with Spark parameters (`spark.default.parallelism` and `spark.sql.shuffle.partitions`) and by doing the coalesce manually, just to be sure. Any other ideas? – grepe Aug 24 '16 at 09:04
  • You have been controlling the partitions, so that shouldn't be the cause. Given that a few worker nodes were still running, I still suspect this is a memory issue. What happens if you remove all caching (or try to persist to memory AND disk)? Excessive caching (even if you unpersist it later) can greatly impact performance, especially when data > memory. – shuaiyuancn Aug 24 '16 at 10:09
  • When I remove the `.coalesce().cache()` I get a significant slowdown. I will try `.persist()` to disk instead of `cache()` and see what happens. – grepe Aug 24 '16 at 10:38
  • UPDATE: no luck with changing the persist type – grepe Aug 24 '16 at 11:03

Try printing dataFrame.explain() to see the logical plan. With every iteration, the transformations on this DataFrame keep adding to the logical plan, so the evaluation time keeps growing as well.

You can use the solution below as a workaround:

dataFrame.rdd.localCheckpoint()

This writes the RDD backing this DataFrame to memory, removes the lineage, and then recreates the RDD from the data that was written to memory.

The good thing about this is that you don't need to write your RDD to HDFS or disk. However, it also brings some issues with it, which may or may not affect you. You can read the documentation of the localCheckpoint method for details.
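As a side note (based on the Spark API, not part of the original answer): since Spark 2.3 you can call localCheckpoint() on the DataFrame itself, which avoids the detour through the RDD API:

# Spark 2.3+: eagerly checkpoint to executor storage and return a new
# DataFrame whose plan starts from the checkpointed data.
item_links = item_links.localCheckpoint()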

Amanpreet Khurana