Merging and Cleaning Up Duplicate Columns

Before joining data frames, make sure no columns are repeated between tables.

If the same column name appears in two tables, after joining you'll have two columns of that name in the dataframe. This causes ambiguity errors when you try to select or operate on either of them, and dropping the column by name will drop every column that shares it. So make sure you don't have any columns of the same name between any of your tables.
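As a tiny illustration (toy dataframes, not from this project), this is the kind of failure you'd hit:

# two toy tables that share the column name "status"
left_df= spark.createDataFrame([(1, "a")], ["cust_id", "status"])
right_df= spark.createDataFrame([(1, "b")], ["cust_id", "status"])

joined= left_df.join(right_df, left_df.cust_id == right_df.cust_id, "inner")
# joined.select("status")   # AnalysisException: reference 'status' is ambiguous
# joined.drop("status")     # drops BOTH 'status' columns, so the data is gone either way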

In my case, I had 14 tables to join, each with more than 200 columns; there was no way to do that manually. In the code below, I show a programmatic way. Though it's fragile and may not apply cleanly to every case, it works. You do need to verify and tweak it according to your needs.

Imports and setup

import pyspark.sql.functions as f
from pyspark.sql.types import *
from pyspark.sql.window import Window
from pyspark.ml.feature import VectorAssembler, StringIndexer #, OneHotEncoder

import pandas as pd
import numpy as np
from collections import Counter

Some needed global variables

RANDOM_SEED_VALUE=42
TARGET_VARIABLE= "label_column"
ID_COLUMN= "id_column"
TIMESTAMP_COLUMN= "timestamp_column" 

The timestamp column is used for ordering, the id column for partitioning, and the target variable for correlations and predictions. I will use the same random seed throughout the project notebooks. The columns above (target, id, timestamp) should also be omitted from the predictor set when you're predicting; singling them out helps you remember to drop them before modeling.
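For instance, when you get to modeling, the predictor list could be built like this (just a sketch; merged stands for whatever final dataframe you feed the model):

predictor_columns= [c for c in merged.columns
	if c not in {TARGET_VARIABLE, ID_COLUMN, TIMESTAMP_COLUMN}]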

Exploring files in a blob storage path. This cell is DataBricks specific

def files_in_path(path):
	"""Return only the folder/file names under `path`."""
	# parse the name out of each FileInfo entry and strip the trailing "/" from folder names
	lst= [str(item).split("name=")[1].split(",")[0].split("'")[1].strip("/") for item in dbutils.fs.ls(path)]
	return lst

This function returns a list of the folder names only, whereas dbutils.fs.ls returns the folder names along with other information about each entry.
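For example (the mount path here is made up):

files_in_path("dbfs:/mnt/my_container/parquet_tables")
# returns something like ['table1', 'table2', 'table3']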

Reading in tables, and creating lists

In the next step, read in the parquet files for the tables you want; in my case there were 14 of them: table1= spark.read.parquet("path_to_table1") and so on. Then I create a list of these tables, for ease of looping through them for repeated procedures. I also create a string list of the corresponding table names, for ease of tracking and knowing what is where.

tables_list= [table1, table2, ... , table_n]
tables_list_names= ["table1", "table2", ... , "table_n"]
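If you'd rather not write out 14 read statements and the two lists by hand, a loop over a dictionary of paths gets you the same thing. A sketch (the paths below are placeholders):

table_paths= {
	"table1": "path_to_table1",
	"table2": "path_to_table2",
	# ... one entry per table
}
tables_dict= {name: spark.read.parquet(path) for name, path in table_paths.items()}

# dicts preserve insertion order, so the two lists stay aligned
tables_list= list(tables_dict.values())
tables_list_names= list(tables_dict.keys())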

In my case, and it's probably commonplace, some id columns exist in some tables but not others, and other tables have other id columns to merge on. So you need to locate the shared id columns to merge tables on. First, list all the id columns you have; there can be 2 or more: known_id_columns=["idcol1", "idcol2"]

Adding counts of a collection of columns, if needed

Sometimes you will have a repeated quantity with different values, for example diagnosis or procedure codes per patient, or purchases per customer. How many of those columns are populated per row is a powerful statistical feature. The next snippet counts them.

def count_codes(df1, colfamily: str):
	"""`colfamily` is the shared starting phrase of the columns you want to count, e.g. diagnosis code columns or purchased-item columns"""

	# step 1 - binarize each column: null/empty -> 0, any value -> 1; we're only counting the existence of a value here
	colfamily_columns= [k for k in df1.columns if k.startswith(colfamily)] # adjust this condition to grab the columns you want to count, e.g. k.startswith("purchased") and not k.endswith("place")
	for item in colfamily_columns:
		df1= df1\
		.withColumn(f"{item}_INT",
			f.when(f.length(f.col(item)) > 1, 1).otherwise(0)) # in my case it wasn't nulls, there was a space character that looks empty but isn't. If you have nulls, use the condition `f.col(item).isNull()`

	# step 2 - create an array column of the 1/0 values created above
	df1= df1\
	.withColumn(colfamily,
	f.array(*[k for k in df1.columns if (colfamily in k) and (k.endswith("_INT"))]))

	# step 3 - sum the 1/0 array column, giving the count
	cond0= f"AGGREGATE({colfamily}, 0, (x,y) -> x+y)" # this is a SQL expression

	df1= df1\
	.withColumn(f"{colfamily}_count", f.expr(cond0))

	# step 4 - clean up the columns created in the process (the integer flags and the array column), keeping the originals in case we need them later
	df1= df1\
	.drop(*[item for item in df1.columns if item.endswith("_INT")], colfamily)
	# you can also drop the originals in this step should you choose to

	return df1

Now simply apply it for however many families of columns you need to count; I had more than 2 in my case. table1= count_codes(df1= table1, colfamily="diag_cd")

IMPORTANT: in my case, one family of columns were all diagnosis codes that started with "diag_cd" and took suffixes between 2 and 25, so they looked like "diag_cd2", "diag_cd3", etc. The other family I needed to count started with "icd_proc_", because those columns were of the format "icd_proc_02" etc. I made the code capture the starting phrase of the family of columns I wanted to count. Modify the conditions as you see fit.
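If you have several such families, you can apply the function in a loop; the prefixes below are just examples mirroring the ones mentioned above:

code_families= ["diag_cd", "icd_proc_"] # the starting phrases of the column families to count
for prefix in code_families:
	table1= count_codes(df1= table1, colfamily= prefix)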

Table explorations needed before joining

Things to consider: if you don't need all the columns from all the tables, in this step reassign the table variable names with the new column selections, table1= table1.select(...). Remember to always select the id columns and the timestamp from each table as well!

This exploration is needed before joining the tables, because tables with a higher row count may have more values of the ID column, and thus more information we don't want to lose. Another thing to consider is the refresh rate of each table: when was the table last updated, and how often does it get updated? While you can ask your data stewards or engineers about that, don't take anyone's word for it; do the exploration yourself, it's a simple query. See details below, under the "Explorations before joining" paragraph.

Sorting tables w.r.t. row count

So we can use the appropriate type of join to favor the most important table in your opinion, or to favor the bigger and more recently updated table. This is going to be a partially manual step: you put the tables in a list from biggest and most recently updated, to smallest and least recently updated.

This snippet gets the counts and sorts them in descending order, such that, all being equal in refresh rate, we can always use a "left" join to favor the bigger table.

tables_count_dictionary= {k:v for k,v in zip(tables_list_names, [2758781, 3104096524, ...])}

ordered_tables_list_names= sorted(tables_count_dictionary, key=tables_count_dictionary.get, reverse=True)
# source: https://stackoverflow.com/a/3177911/11381214

ordered_tables_list= [table11, table2, table3, ...]

NOTES

  • [2758781, 3104096524, ...] is a list of the row counts of each table in tables_list, preserving order. We get this with the table.count() command. This can be computationally expensive, so it's up to you whether you wanna break it down and do it one table at a time, or decide you can handle it in a for-loop in one cell, which might break depending on your cluster capabilities. I usually prefer to break down heavy computations as much as possible to keep things flowing and avoid bottlenecks. I've noticed that after a particularly big query that pushes the cluster's limits, it becomes bogged down and not as responsive or fast in executing the next commands, even on good cluster management software like DataBricks.

  • ordered_tables_list is the same as ordered_tables_list_names but as variables, not as text. This was a manual step for me: copy, paste, and remove quotations, since I kept the variable names the same as their string names. (Both notes can be automated; see the sketch after this list.)
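A sketch of that automation, reusing tables_list and tables_list_names from above:

# row counts, one table at a time (each .count() triggers a full scan, so this is the expensive part)
tables_count_dictionary= {}
for name, table in zip(tables_list_names, tables_list):
	tables_count_dictionary[name]= table.count()

ordered_tables_list_names= sorted(tables_count_dictionary, key=tables_count_dictionary.get, reverse=True)

# build the ordered list of dataframe variables from the ordered names,
# instead of copy-pasting and removing quotes by hand
name_to_table= dict(zip(tables_list_names, tables_list))
ordered_tables_list= [name_to_table[name] for name in ordered_tables_list_names]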

Rename shared ID columns to avoid confusion after merging

This way I don't end up with duplicated id columns in the same dataframe; Spark won't distinguish them, and there's no way to resolve the ambiguity afterwards.

renamed_ids_tables=[]
for n,item in enumerate(ordered_tables_list):
	temp= item # so it doesn't change the original one in list
	for k in item.columns:
		if k in known_id_columns: 
			temp= temp.withColumnRenamed(k, "{}_LOC{}".format(k,n))
		else:
			pass
	renamed_ids_tables.append(temp)

print(len(renamed_ids_tables)) # should be same as how many tables you have

Cleaning up repeated columns between tables, to avoid trouble after merging

The idea is,

  1. start with the first table's list of columns

  2. in the loop:
     a. get the next table's columns
     b. if a column is NOT an id column, AND NOT already in the bag of columns from previous tables, keep it; otherwise drop it
     c. add the set of kept columns to the bag of columns, and repeat

columns_bag= set(renamed_ids_tables[0].columns)
clean_tables=[renamed_ids_tables[0]]
for n,item in enumerate(renamed_ids_tables[1:]):
	temporary_columns= [col for col in item.columns if (col not in columns_bag) and (col not in known_id_columns)]
	columns_bag= columns_bag.union(set(temporary_columns))
	clean_tables.append(item.select(*set(temporary_columns)))
	
print(len(columns_bag))

IMPORTANT: note how this snippet is run only after we renamed the id columns according to each table's place in the ordered list of tables. If we don't rename the id columns first, we lose them in all tables but the first, meaning we can't merge the tables.
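A quick sanity check (my addition, not part of the original workflow): confirm each cleaned table still carries its renamed id columns, so the joins below have something to join on.

for n, t in enumerate(clean_tables):
	# each table should still list at least one renamed id column, e.g. "idcol1_LOC3"
	print(n, [c for c in t.columns if "_LOC" in c])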

Another way to de-duplicate columns, but less efficient,

columns_bag= renamed_ids_tables[0].columns
clean_tables=[renamed_ids_tables[0]]
for n,item in enumerate(renamed_ids_tables[1:]):
	temporary_columns= [col for col in item.columns if (col not in columns_bag) and (col not in known_id_columns)]
	for col in temporary_columns:
		columns_bag.append(col)
	clean_tables.append(item.select(*temporary_columns))
	
# test correct execution, 
print(len(columns_bag), len(set(columns_bag)))

Joining tables on shared id columns

In my use case, the first four tables in the ordered list had the same refresh rate (monthly) and the same size; each contained a different set of columns, some shared, and they had one of the id columns in common. Only one of them had an id column that could be used to join some of the other tables.

Joining the first four tables in the ordered list on the same id column, with an inner join because they share the same refresh rate and size

first_merge= clean_tables[0]\
.join(clean_tables[1], 
	on= clean_tables[0].idcol1_LOC0 == clean_tables[1].idcol1_LOC1, 
	how= "inner")\
.drop("idcol1_LOC1")\
.join(clean_tables[2], 
	on= clean_tables[0].idcol1_LOC0 == clean_tables[2].idcol1_LOC2, 
	how= "inner")\
.drop("idcol1_LOC2")\
.join(clean_tables[3], 
	on= clean_tables[0].idcol1_LOC0 == clean_tables[3].idcol1_LOC3, 
	how= "inner")\
.drop("idcol1_LOC3")

Joining another table to this merged one, on a different id column

second_merge= first_merge\
.join(clean_tables[12], 
on= clean_tables[12].idcol3_LOC12 == first_merge.idcol3_LOC0, 
how= "left")\
.drop("idcol3_LOC12")

clean_tables[12] is far down the ordered list; it's a small table that isn't refreshed frequently, so I favored the merged table with a "left" join.

Joining the rest of the tables on yet another id column, "idcol2". The remaining tables all have this id column, and it now exists in the big merged table as well.

In order to do that without repeating code for 9 tables, I need to create a shorter list out of the previously ordered and cleaned list of tables, one that doesn't contain the tables I've joined so far. Here, the tables I need are in positions 4 through 11, plus the 13th one.

to_join= []
for item in clean_tables[4:12]:
	to_join.append(item)
to_join.append(clean_tables[13])

Also, to not repeat myself while joining, and because I already have a copy of the id column with the suffix "_LOC0" in the merged table, I'm going to rename all the previously renamed idcol2 columns (the ones with a "_LOC" suffix) back to plain "idcol2", so I can put the joins in a loop. And because the tables are already ordered in descending fashion, I'm using "left" joins from one to the next.

join_ready_tables=[]
for table in to_join:
	temp= table # start from the table itself, so a table without a matching column isn't replaced by the previous iteration's temp
	for col in table.columns:
		if col.startswith("idcol2"):
			temp= temp.withColumnRenamed(col, "idcol2")
	join_ready_tables.append(temp)
	
	
merged= second_merge
for table in join_ready_tables:
	merged= merged\
	.join(table, 
		on= second_merge.idcol2_LOC0 == table.idcol2, 
		how="left")\
	.drop(table.idcol2)
		
# verification
print(len(merged.columns), len(set(merged.columns)))

Joining cleanup

To see the leftover id columns, print(sorted([item for item in merged.columns if "_LOC" in item])). We only need one copy of each, so delete the rest manually: merged= merged.drop("idcol3_LOC9", "idcol2_LOC7")

Undo the id columns renaming, now that there are no duplicated columns anywhere

To get rid of the "_LOC" suffix

for col in [item for item in merged.columns if "_LOC" in item]:
	# strip everything from "_LOC" onward; slicing off a fixed 5 characters would break for two-digit suffixes like "_LOC12"
	merged= merged.withColumnRenamed(col, col.split("_LOC")[0])

Explorations before joining

Seeing ID columns in each table

Like I said, I had more than 2 id columns across all the tables (more than 10, actually), and it became useful to see which tables have which id columns. The snippet below reveals that in a Pandas dataframe.

# create a table of zeros, fill in the id columns that exist in each table with 1
id_df= pd.DataFrame(np.zeros((len(tables_list), len(known_id_columns))), 
columns=known_id_columns, 
index= tables_list_names)
# you could instead create a table of nulls if you like, 
# id_df= pd.DataFrame(columns=known_id_columns, index=tables_list_names)

for idcol in known_id_columns:
	for n, table in enumerate(tables_list):
		if idcol in table.columns:
			id_df.loc[tables_list_names[n], idcol] = 1 # .loc avoids pandas' chained-assignment pitfalls
		else:
			pass
id_df.astype('int32') # just b/c it's easier to see than floats of ones and zeros. This will also display the dataframe

You'll have something like this,

Then you can locate which tables have a particular id column easily, id_df[id_df["idcol1"]==1].astype('int32')

Columns overlap between tables

When you merge tables from a SQL database, some of these tables might have more than the id columns in common. This duplication of columns will cause you a major headache after joining, because Spark will not distinguish between the same column name across tables: trying to select such a column will throw an "ambiguous" error, and dropping the duplicated column will drop all of them, so you won't have it anymore. In the code above, I showed how to tackle this by de-duplicating the common columns between tables and renaming the id columns you need to join on. Here, I show a more or less exploratory step to see how many columns appear in more than one table. This might not be very useful if you have too many columns, because you'll end up with too big of a dataframe to visually scan, but you can still do more statistics on it and get an idea.

# create a columns bag, that has all the columns from all the tables
columns_bag= set(tables_list[0].columns)
for n,item in enumerate(tables_list):
	temporary_columns= [col for col in item.columns if (col not in columns_bag) and (col not in known_id_columns)]
	columns_bag= columns_bag.union(set(temporary_columns))
	
# get the overlap dataframe
overlap_df= pd.DataFrame(np.zeros( (len(tables_list), len(columns_bag)) ), 
columns= list(columns_bag), index=tables_list_names) # list() so pandas gets an ordered collection for the columns

for n,item in enumerate(tables_list):
	for col in columns_bag:
		if col in item.columns:
			overlap_df.loc[tables_list_names[n], col] = 1 # .loc avoids pandas' chained-assignment pitfalls
		else:
			pass
			
overlap_df # to print it out

And if it's too big, here's an aggregated version of it, to see which columns occur more than once,

aggd_overlap= overlap_df.agg(['sum'], axis=0).T #take the transpose 

repeated_columns= aggd_overlap[aggd_overlap['sum'] > 1].index.to_list()

print(aggd_overlap['sum'].max())
aggd_overlap[aggd_overlap['sum']>1]

# get the summary
overlap_summary= pd.DataFrame(np.zeros( (len(tables_list), len(repeated_columns)) ),
columns= repeated_columns, 
index=tables_list_names)

for n,item in enumerate(tables_list):
	for col in repeated_columns:
		if col in item.columns:
			overlap_summary.loc[tables_list_names[n], col] = 1 # note: assignment, not the `==` comparison
		else:
			pass
			
overlap_summary.astype("int32")

Don't forget you can wrap a dataframe with display() in DataBricks to get the downloadable CSV to your local machine, if you wanna share it with stakeholders, or do some Excel things with it.

# displaying it as a DataBricks table, so I can download the CSV to my computer
overlap_summary['table_name']= overlap_summary.index
display(overlap_summary[['table_name']+repeated_columns])

This will show you a table like this,

Testing refresh rate

This is done by simply checking the last timestamp available and cross-referencing it with the date you collected it on. Repeat that daily, weekly, and monthly to verify the refresh rate, and/or to verify your data stewards'/engineers' word.

# a step to verify the timestamp column exists in all tables; if not, use another timestamp column that exists in those tables. 
for n,item in enumerate(tables_list):
	if TIMESTAMP_COLUMN not in item.columns:
		print(tables_list_names[n])
		
# I didn't find the TIMESTAMP_COLUMN in a couple of tables, say 9 and 11. Which timestamp column exists in both, for ease of processing? Uncover that now
set([col[0] for col in table9.dtypes if col[1]=='timestamp']).intersection(set([col[0] for col in table11.dtypes if col[1]=='timestamp']))

# now see the latest timestamp
for n,item in enumerate(tables_list):
	if TIMESTAMP_COLUMN in item.columns:
		print(tables_list_names[n])
		item.select(f.max(TIMESTAMP_COLUMN)).show() # .show() prints on its own; wrapping it in print() just adds a "None"
		print("\n")
	else:
		print(tables_list_names[n])
		item.select(f.max(f.col("other_timestamp"))).show()
		print("\n")

This will produce a printout like the following for each table, starting with table1.

Show data range, via oldest and newest record

dates_results0=[]
for n,table in enumerate(tables_list):
	if tables_list_names[n] not in ['table9', 'table11']:
		temp= table.select(f.min(f.col(TIMESTAMP_COLUMN)), f.max(f.col(TIMESTAMP_COLUMN)))\
		.toPandas().rename(index={0:tables_list_names[n]})
		dates_results0.append(temp)
	else:
		pass
dates_results1= pd.concat(dates_results0, axis=0)

# for the two tables {'table9', 'table11'} that didn't have the same TIMESTAMP_COLUMN, we're using the "other_timestamp" column
dates_results2= pd.concat([
	table9.select(f.min(f.col("other_timestamp")), f.max(f.col("other_timestamp"))).toPandas().rename(index={0:'table9'}),
	table11.select(f.min(f.col("other_timestamp")), f.max(f.col("other_timestamp"))).toPandas().rename(index={0:'table11'}),
]) # pd.concat instead of DataFrame.append, which was removed in newer pandas versions

dates_results1 will look something like this,

Similarly for dates_results2

Samples of dataframe to test merging

Joining tables is one of the more expensive operations in Spark, so you might want to take a small sample of the tables you're trying to join, to test on and check their data until you're confident in the direction of your joins and everything else. To do so, I'm going to loop through all the tables and take a random 100 rows from each.

sample_clean_tables=[]
for item in clean_tables:
	sample_clean_tables.append(item.orderBy(f.rand()).limit(100))
	
#source: https://stackoverflow.com/a/46015758/11381214
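Note that orderBy(f.rand()) sorts the whole table just to grab 100 rows. If an approximate sample is good enough, DataFrame.sample() is usually cheaper; a sketch (the fraction is an assumption you'd tune to your table sizes):

sample_clean_tables=[]
for item in clean_tables:
	# fraction is approximate, so the sample size won't be exactly 100
	sample_clean_tables.append(item.sample(withReplacement=False, fraction=0.001, seed=RANDOM_SEED_VALUE))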
