Ingest tables in parallel with an Apache Spark notebook using multithreading

In a few projects I have been asked to ingest a full relational database into a data lake (file storage) every day, or sometimes more frequently. A few tables may be big enough to justify delta (incremental) loads, but most tables are easiest to handle with a full overwrite each time. The challenge comes when we want a single Apache Spark notebook to do the whole job. The simple code that loops through the list of tables runs one table after another (sequentially). If none of the tables is very big, it is quicker to have Spark load several tables concurrently (in parallel). There are different ways to do this, but I am sharing the easiest one I have found when working with a notebook in Databricks, Azure Synapse Spark, Jupyter, or Zeppelin.

You can check out the video to see me walk through it or continue reading for the written steps and code.

Setup code

The first step in the notebook is to set the key variables for connecting to a relational database. In this example I use Azure SQL Database, but other databases can be read using their standard JDBC drivers.

If running on Databricks, store your secrets in a secret scope so they are not saved in clear text with the notebook. The commands that set db_user and db_password read the secrets named sql-user-stackoverflow and sql-pwd-stackoverflow from my secret scope, demo.

database = "StackOverflow2010"
db_host_name = ""  # set to your database server host name
db_url = f"jdbc:sqlserver://{db_host_name};databaseName={database}"
db_user = dbutils.secrets.get("demo", "sql-user-stackoverflow")  # Databricks secret scope
db_password = dbutils.secrets.get("demo", "sql-pwd-stackoverflow")  # Databricks secret scope
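The db_url value follows the SQL Server JDBC URL format. It has the host, an optional port (1433 when omitted), and then semicolon-separated connection properties. If you need extra properties, a small helper keeps the string consistent across notebooks. The helper name is hypothetical and this is only a sketch, not part of any library. Check property names such as encrypt against the documentation for your driver version.

```python
def build_sqlserver_jdbc_url(host, database, port=1433, **properties):
    """Build a SQL Server JDBC URL (hypothetical helper, not part of any library).

    Extra keyword arguments are appended as ;key=value connection properties.
    """
    parts = [f"jdbc:sqlserver://{host}:{port}", f"databaseName={database}"]
    parts += [f"{key}={value}" for key, value in properties.items()]
    return ";".join(parts)


# The host name below is a placeholder, not a real server
url = build_sqlserver_jdbc_url("myserver.database.windows.net", "StackOverflow2010", encrypt="true")
print(url)  # jdbc:sqlserver://myserver.database.windows.net:1433;databaseName=StackOverflow2010;encrypt=true
```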

If running in an Azure Synapse notebook, access secrets through a Key Vault linked service and mssparkutils, as in the example below. Here demokv is the name of the Key Vault linked service.

database = "StackOverflow2010"
db_host_name = ""
db_url = f"jdbc:sqlserver://{db_host_name};databaseName={database}"
db_user = mssparkutils.credentials.getSecretWithLS("demokv", "sql-user-stackoverflow")
db_password = mssparkutils.credentials.getSecretWithLS("demokv", "sql-pwd-stackoverflow")

Alternatively, you could skip the secrets and hard-code the values if you really want. PLEASE DO NOT DO THIS WITH REAL CODE!

database = "StackOverflow2010"
db_host_name = ""
db_url = f"jdbc:sqlserver://{db_host_name};databaseName={database}"
db_user = 'stackoverflow_reader'
db_password = 'Th1s1sABadIdea!'
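Plain Jupyter or Zeppelin environments may have no secret scope or Key Vault. In that case, one middle ground between secrets and hard-coding is to read the credentials from environment variables set outside the notebook. This is a minimal sketch. The variable names SQL_USER and SQL_PASSWORD and the helper get_db_credentials are hypothetical.

```python
import os


def get_db_credentials(user_var="SQL_USER", password_var="SQL_PASSWORD"):
    """Read database credentials from environment variables.

    The default variable names are hypothetical; use whatever your environment sets.
    """
    try:
        return os.environ[user_var], os.environ[password_var]
    except KeyError as missing:
        # Fail fast with a clear message rather than connecting with blank credentials
        raise RuntimeError(f"Environment variable {missing} is not set") from None


# Usage in the setup cell, once SQL_USER and SQL_PASSWORD are set:
# db_user, db_password = get_db_credentials()
```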

Load Table Function

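Next you need a function to load a table. This is the code that tells Spark to read a table and write it out. You can change the format and option values if you are reading from a different source. This example reads from SQL Server with the Apache Spark connector for SQL Server, an optimized driver that may need to be installed on your cluster. With only the standard JDBC driver available, the built-in jdbc format accepts the same url, dbtable, user, and password options.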

Called on its own, this function blocks. The notebook waits until the table load is complete before it processes the next command.
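A plain loop over the tables is slow because each call blocks. The self-contained sketch below illustrates this with a hypothetical fake_load_table that sleeps instead of running a Spark job. Because each call waits for the previous one to finish, the total time is roughly the sum of the individual loads.

```python
import time


def fake_load_table(table, seconds=0.2):
    """Hypothetical stand-in for load_table: sleeps instead of reading and writing."""
    time.sleep(seconds)
    return table


tables = ["Badges", "Comments", "LinkTypes", "PostLinks"]

start = time.perf_counter()
loaded = [fake_load_table(table) for table in tables]  # each call waits for the previous one
elapsed = time.perf_counter() - start

print(f"Loaded {len(loaded)} tables sequentially in {elapsed:.2f}s")  # roughly 4 x 0.2s
```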

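def load_table(table):
    destination_table = "raw_stackoverflow." + table

    # Read the full source table. "com.microsoft.sqlserver.jdbc.spark" is the
    # Apache Spark connector for SQL Server; use "jdbc" for the standard driver.
    df = (
        spark.read
        .format("com.microsoft.sqlserver.jdbc.spark")
        .option("url", db_url)
        .option("dbtable", table)
        .option("user", db_user)
        .option("password", db_password)
        .load()
    )

    # Overwrite the destination table in full on every run
    df.write.mode("overwrite").saveAsTable(destination_table)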

Create Database and Table List

This small snippet creates the database if it's the first run and then sets the list of tables to read from the source. Sometimes I read the list of tables from the database itself, but for this example it is a simple Python list. Be sure to set the LOCATION string to a path that works in your environment. A mounted directory, an Azure Storage path, or an AWS S3 path are typical options, but even a local folder can work if you are not running in a managed environment.

spark.sql("CREATE DATABASE IF NOT EXISTS raw_stackoverflow LOCATION '/demo/raw_stackoverflow'")

table_list = ["Badges", "Comments", "LinkTypes", "PostLinks", "Posts", "PostTypes", "Users", "Votes", "VoteTypes"]
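The table list can also be read from the source database. On SQL Server, one option is to query INFORMATION_SCHEMA.TABLES. The sketch below assumes you only want base tables (not views) in the dbo schema. The query text and the helper table_names_from_rows are illustrations, not part of the original notebook. In the notebook, the rows would come from spark.read.format("jdbc") with the url, user, and password options, plus .option("query", TABLE_LIST_QUERY) in place of dbtable, followed by .collect(). Here, stub tuples stand in for those rows.

```python
# Catalog query for SQL Server: user tables (not views) in the dbo schema.
TABLE_LIST_QUERY = (
    "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
    "WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_SCHEMA = 'dbo'"
)


def table_names_from_rows(rows, exclude=()):
    """Turn catalog rows (first column = table name) into a sorted table list.

    Works with plain tuples or Spark Row objects, which also support row[0].
    """
    skip = set(exclude)
    return sorted(row[0] for row in rows if row[0] not in skip)


# Stub rows standing in for the result of .collect() on the catalog query
rows = [("Posts",), ("Badges",), ("sysdiagrams",), ("Users",)]
print(table_names_from_rows(rows, exclude=["sysdiagrams"]))  # ['Badges', 'Posts', 'Users']
```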

Run function concurrently

Here is the important code with a bit of explanation. First, import the libraries and set up a Queue, which will hold all the values that need to be passed to the function that does the work (in our case, load_table). You also define a worker count to limit how many tables are loaded in parallel. In my experience 2 is fairly low, so I usually test 3 or 4 workers.

from threading import Thread
from queue import Queue

q = Queue()

worker_count = 2
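queue.Queue is a thread-safe FIFO queue. That is what lets several worker threads pull from the same list without two of them taking the same table. Its join() method blocks until task_done() has been called once for every item that was put(). Below is a minimal single-threaded illustration. It uses a separate demo_q so it does not disturb the notebook's q.

```python
from queue import Queue

demo_q = Queue()
for name in ["Badges", "Comments", "Posts"]:
    demo_q.put(name)  # enqueue in list order

taken = []
while not demo_q.empty():
    taken.append(demo_q.get())  # FIFO: items come back in the order they were put
    demo_q.task_done()          # one task_done() per get() so join() can return

demo_q.join()  # returns immediately: every item has been marked done
print(taken)   # ['Badges', 'Comments', 'Posts']
```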

This loop puts each table in the queue, which lets the worker threads share the list of work to do.

for table in table_list:
    q.put(table)

The run_tasks function controls the work to be done. It gets the next table name from the queue and calls the function with it. The task_done call marks that item as complete for the queue. You could embed the load_table call directly in here, but instead I pass the function in as an argument so it's easy to swap out the logic when needed.

def run_tasks(function, q):
    while not q.empty():
        value = q.get()
        function(value)  # e.g. load_table(value)
        q.task_done()  # mark this item as processed
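A loop built on while not q.empty() has two weak spots. First, if the function raises, task_done() is never called for that item, so q.join() waits forever. Second, another worker can take the last item between the empty() check and get(). That leaves this worker blocked on get(). The daemon flag keeps this from hanging the process, but the thread never finishes. The sketch below is a more defensive variant, not the notebook's code. It uses get_nowait() with try/finally and records failures instead of letting the worker die. The names run_tasks_safe and failures are hypothetical.

```python
from queue import Empty, Queue
from threading import Lock

failures = []           # (table, exception) pairs collected from all workers
failures_lock = Lock()  # guards failures when several workers append at once


def run_tasks_safe(function, q):
    """Worker loop: process queue items until the queue is drained."""
    while True:
        try:
            value = q.get_nowait()  # never blocks; raises Empty once nothing is left
        except Empty:
            return
        try:
            function(value)
        except Exception as exc:
            # Record the failure and keep going instead of killing this worker
            with failures_lock:
                failures.append((value, exc))
        finally:
            q.task_done()  # always mark the item done so q.join() can return


# Usage with a stub that fails on one table; in the notebook you would pass load_table
def flaky_load(table):
    if table == "Votes":
        raise RuntimeError("simulated load failure")


demo = Queue()
for table in ["Badges", "Votes", "Users"]:
    demo.put(table)

run_tasks_safe(flaky_load, demo)
demo.join()  # returns because every item, including the failed one, was marked done
print([table for table, _ in failures])  # ['Votes']
```

In the notebook, this variant would be passed to Thread as target=run_tasks_safe. After q.join(), you can print failures to see which tables need a rerun.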

The last part of the code creates as many worker threads as requested in the worker_count variable. In the Thread creation, target is the function to run and args holds the values for the target function's parameters. The Queue object is always passed in as one of them. Setting daemon to True means a worker thread that is still waiting will not keep the Python process from exiting. Each thread then has to be started. Finally, q.join tells the notebook to wait until every item in the queue has been processed before continuing.

for _ in range(worker_count):
    t = Thread(target=run_tasks, args=(load_table, q))
    t.daemon = True
    t.start()

q.join()
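You can try the whole pattern without a database or Spark. The self-contained version below uses a hypothetical stub_load_table that sleeps for 0.2 seconds in place of load_table. With three workers, six loads should finish in roughly a third of the sequential time. Threads help because each worker spends most of its time waiting rather than executing Python code, so the GIL is not the bottleneck. Here the wait is on sleep; in the notebook it is on the Spark job. Run it in a scratch session, since it defines its own run_tasks.

```python
import time
from queue import Queue
from threading import Thread

done_tables = []  # filled by the stub loader; list.append is thread-safe in CPython


def stub_load_table(table, seconds=0.2):
    """Hypothetical stand-in for load_table: sleeps instead of running a Spark job."""
    time.sleep(seconds)
    done_tables.append(table)


def run_tasks(function, q):
    while not q.empty():
        value = q.get()
        try:
            function(value)
        finally:
            q.task_done()  # mark done even if the load fails


demo_tables = ["Badges", "Comments", "LinkTypes", "PostLinks", "Posts", "PostTypes"]
demo_queue = Queue()
for table in demo_tables:
    demo_queue.put(table)

start = time.perf_counter()
for _ in range(3):  # three workers
    t = Thread(target=run_tasks, args=(stub_load_table, demo_queue))
    t.daemon = True
    t.start()
demo_queue.join()  # wait until every table has been marked done
parallel_elapsed = time.perf_counter() - start

# 6 loads of 0.2s on 3 workers: roughly 0.4s, versus roughly 1.2s one after another
print(f"Loaded {len(done_tables)} tables in {parallel_elapsed:.2f}s")
```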


Now you try

That concludes the walkthrough of loading data in parallel within a single Python notebook. You can find the full notebooks below and try them out in your own environment.

Databricks example notebook

Synapse example notebook
