We will show how to apply the latest changes to a Parquet table that uses a partition key.
Approach
- Perform CDC (change data capture) and save the changes to a Parquet table with a partition key.
- Split the source dataframe into three sub-dataframes (see the sketch after this list):
  - births
  - amends
  - deaths (records that are not present in the source but exist in the target dataframe, i.e. the current records)
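The split falls out of a full outer join between source and target plus two null checks on the join keys. A minimal sketch of that idea on toy dataframes, with an assumed id column named id (the full, parameterised implementation follows below):
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df_target = spark.createDataFrame([(1, 'a'), (2, 'b')], ['id', 'val'])   # last load
df_source = spark.createDataFrame([(2, 'B'), (3, 'c')], ['id', 'val'])   # new load

# suffix the target columns so both sides can coexist after the join
tgt = df_target.select(*[F.col(c).alias(c + '_x') for c in df_target.columns])
joined = tgt.join(df_source, df_source['id'] == tgt['id_x'], 'outer')

births = joined.filter(F.col('id_x').isNull())                               # id 3: only in source
deaths = joined.filter(F.col('id').isNull())                                 # id 1: only in target
amends = joined.filter(F.col('id').isNotNull() & F.col('id_x').isNotNull())  # id 2: in both - compare values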
Code snippets
CDC implementation
import logging
from datetime import datetime

from pyspark.sql import functions as F
# NOTE: `lf` is an external helper module used below (lf.rename_cols); it is not
# defined in this post - a sketch of the assumed helper follows the function.


def identify_changes(df_source, df_target, id_col=None, change_cols=None, tablename='cdc', changetable='cdc_changes'):
    """Perform CDC and save the changes to partitioned Parquet tables.

    partition_key - live_flag (int, 0 = false and 1 = true)
    id_col        - column used for the outer join
    change_cols   - columns tracked for changes
    tablename     - Parquet table name for saving the current data
    changetable   - table for the change history (append only, partition key - loaddate)

    The source dataframe is split into:
    - births (only present in the source)
    - deaths (present in the last load but missing from the source)
    - amends
    The current-records dataframe holds records that are not present in the
    source data (only rows where live_flag = 1 are read from the target).
    """
    logging.info('START - identify changes: {}'.format(datetime.now().strftime("%Y%m%d%H%M")))
    print('START - identify changes: {}'.format(datetime.now().strftime("%Y%m%d%H%M")))
    if id_col is None:
        id_col = df_source.columns[0]   # must be a single column name, not a list
    if change_cols is None:
        change_cols = df_source.columns
    # identify births, deaths and changes - outer join with broadcast
    df_t = df_target.select(*(F.col(x).alias(x + '_x') for x in df_target.columns))
    source_target = (df_t.join(F.broadcast(df_source), df_source[id_col] == df_t[id_col + '_x'], 'outer')
                     .withColumn('cdc', F.when(F.col(id_col + '_x').isNull(), F.lit(2))   # birth
                                         .when(F.col(id_col).isNull(), F.lit(99))         # death
                                         .otherwise(F.lit(3)))                            # potential amend
                     .select('*')
                     )
    source_target = source_target.persist()
    # births - only present in the source
    births = source_target.filter(F.col(id_col + '_x').isNull()).select(df_source.columns)
    # records missing from the source - carried forward unchanged
    no_changes = source_target.filter(F.col(id_col).isNull()).select(df_t.columns)
    no_changes = lf.rename_cols(no_changes, '_x', '')
    output = births.union(no_changes).drop('live_flag')
    # apply data changes
    conditions = [F.when(source_target[c] != source_target[c + '_x'], F.lit(c)).otherwise(F.lit("")) for c in change_cols if c not in [id_col]]
    select_expr = [
        *[F.coalesce(source_target[c + '_x'], source_target[c]).alias(c) for c in df_source.columns if c not in change_cols],
        *[F.coalesce(source_target[c], source_target[c + '_x']).alias(c) for c in change_cols if c not in [id_col]],
        F.array_remove(F.array(*conditions), "").alias("data_changes")
    ]
    changes = source_target.filter(source_target.cdc == 3).select(*select_expr, 'cdc')
    # save change history as a parquet table
    changes.write.partitionBy('loaddate').mode('append').parquet(changetable)
    (source_target.filter(source_target.cdc != 3)
     .select(*df_source.columns, F.array(F.lit("")).alias('data_changes'), 'cdc')
     .write.partitionBy('loaddate')
     .mode('append').parquet(changetable)
     )
    # save data
    changes = changes.select(*df_source.columns).drop('live_flag')
    output = changes.union(output)
    # overwrite the live_flag=1 partition only (dynamic partition overwrite)
    output = output.withColumn('live_flag', F.lit(1))
    output.filter(output.live_flag == 1).write.partitionBy('live_flag') \
        .mode('overwrite').option("partitionOverwriteMode", "dynamic").parquet(tablename)
    # keep history - previous versions of amended records go to live_flag=0
    old_changes = source_target.filter(source_target.cdc == 3).select(*df_t.columns)
    old_changes = lf.rename_cols(old_changes, '_x', '')
    old_changes = old_changes.withColumn('live_flag', F.lit(0))
    old_changes.filter(old_changes.live_flag == 0).write.partitionBy('live_flag').mode('append').parquet(tablename)
    logging.info('END - identify changes: {}'.format(datetime.now().strftime("%Y%m%d%H%M")))
    print('END - identify changes: {}'.format(datetime.now().strftime("%Y%m%d%H%M")))
    return output
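The function relies on lf.rename_cols, a helper that is not shown in this post. It is assumed to strip a suffix from every column name; a minimal sketch of such a helper (the real implementation may differ) could be:
def rename_cols(df, suffix, replacement=''):
    """Assumed behaviour of lf.rename_cols: replace `suffix` at the end of
    each column name with `replacement` (used above to strip the '_x' suffix)."""
    for c in df.columns:
        if c.endswith(suffix):
            df = df.withColumnRenamed(c, c[:len(c) - len(suffix)] + replacement)
    return df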
Synthetic data
def crn_data(spark, start=0, end=100):
    crn = spark.sql(f"""
        select
            i + start as rowno,
            lpad(cast(floor(rand() * 8000000) as varchar(7)), 8, '7') as ch_no,
            concat(Array("C1","C2","F1","V_C","V_P_C","C_P","F3")[floor(rand()*6+0)], floor(rand() * 20)) as ch_name,
            floor(rand() * 3 + 1) as ch_legalstatus,
            case
                when rand() > 0.30 then
                    concat(Array("C1","C2","F1","V_C","V_P_C","C_P","F3")[floor(rand()*6+0)], floor(rand() * 20))
                when rand() > 0.20 then
                    concat(Array(40021,60023,70089,99999,21000)[floor(rand()*5.0+0)], floor(rand() * 20))
                else
                    Array(40021,60023,70089,99999,21000)[floor(rand()*5.0)]
            end as rand_c
            ,Array(40021,60023,70089,99999,21000)[floor(rand()*5.0)] as sic2007
            ,source_file as source_file
            ,cast(1 as int) as live_flag
            ,current_timestamp() as load_timestamp
            ,current_date() as loaddate
        from (select {start} as start, {end} as end, concat('df',{start}) as source_file) r
        lateral view posexplode(split(space(end - start),' ')) pe as i,s
    """)
    return crn
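The lateral view posexplode(split(space(end - start), ' ')) line is only there to generate consecutive row numbers in pure SQL. A minimal DataFrame-API equivalent of that sequence step (assuming an existing spark session) would be:
# generate rowno = start .. end inclusive, matching the posexplode/space trick above
start, end = 0, 100
rows = spark.range(start, end + 1).withColumnRenamed('id', 'rowno')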
Example code for using the above functions
# initial load
df_target = crn_data(spark,1,5)
df_target.write.partitionBy('live_flag').mode('append').parquet('ch')
# target data
df1 = spark.sql("select * from parquet.`file:/content/ch/live_flag=1`")
# source data
df2 = crn_data(spark,5,8).drop('live_flag')
skip_rows = list(range(5,6))
# skip a few rows to simulate deaths (present in the target but not in the source)
df2 = df2.filter(~df2.rowno.isin(skip_rows))
change_cols = [
    'ch_name',
    'ch_legalstatus',
    'sic2007'
]
id_col = 'rowno'
output = identify_changes(df2,df1,id_col,change_cols,'ch','ch_changes')
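After the merge you can sanity-check both tables by reading them back (same relative 'ch' and 'ch_changes' paths as the writes above):
# current data plus history rows (live_flag 1/0)
spark.read.parquet('ch').orderBy('rowno', 'live_flag').show()
# change history - data_changes lists which tracked columns differed
spark.read.parquet('ch_changes').select('rowno', 'data_changes', 'cdc', 'loaddate').show(truncate=False)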