Saving data as a parquet table with a partition key

We will show how to apply the latest changes to a parquet table that uses a partition key.

Approach

  • Perform CDC (change data capture) and save the changes to a parquet table with a partition key.
  • Split the source dataframe into three sub-dataframes (see the sketch after this list):
    • births (records only present in the source)
    • amends (records present in both the source and the target, with tracked columns compared for changes)
    • deaths (records not present in the source but present in the target dataframe of current records)
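
For orientation, here is a minimal, self-contained sketch of that split (the id and name columns and the toy rows are invented for illustration; the real column names appear later in the article):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

# toy source (latest load) and target (current records)
source = spark.createDataFrame([(1, 'a'), (2, 'b2'), (4, 'd')], ['id', 'name'])
target = spark.createDataFrame([(1, 'a'), (2, 'b'), (3, 'c')], ['id', 'name'])

# suffix the target columns so both sides survive the outer join
t = target.select(*[F.col(c).alias(c + '_x') for c in target.columns])
joined = t.join(source, source['id'] == t['id_x'], 'outer')

births = joined.filter(F.col('id_x').isNull())                                # only in the source
deaths = joined.filter(F.col('id').isNull())                                  # only in the target
amends = joined.filter(F.col('id').isNotNull() & F.col('id_x').isNotNull())   # present in both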

Code snippets

CDC implementation

import logging
from datetime import datetime

from pyspark.sql import functions as F

import lf  # assumed local helper module providing rename_cols (a minimal sketch follows this function)


def identify_changes(df_source, df_target, id_col=None, change_cols=None, tablename='cdc', changetable='cdc_changes'):
  """Perform CDC and save the changes to a partitioned parquet table.

  partition key - live_flag (int, 0 = historic record, 1 = current/live record)
  id_col        - column used for the outer join
  change_cols   - columns whose changes are tracked
  tablename     - parquet table name for saving the data
  changetable   - table for the change history (append only, partition key = loaddate)

  The source dataframe is split into:
      - births (only present in the source)
      - deaths (compare the last load with the source and record any missing records)
      - amends (present in both, with the tracked columns compared for changes)
  The target dataframe holds the current records (only rows where live_flag=1 are read).
  """
  logging.info('START - identify changes: {}'.format(datetime.now().strftime("%Y%m%d%H%M")))
  print('START - identify changes: {}'.format(datetime.now().strftime("%Y%m%d%H%M")))

  if id_col is None:
    # default to the first column of the source dataframe
    id_col = df_source.columns[0]
  if change_cols is None:
    change_cols = df_source.columns

  # identify births, deaths and changes - outer join with a broadcast of the source
  # cdc codes: 2 = birth (only in the source), 99 = only in the target (no change/death), 3 = matched (potential amend)
  df_t = df_target.select(*(F.col(x).alias(x + '_x') for x in df_target.columns))
  source_target = (df_t.join(F.broadcast(df_source), df_source[id_col] == df_t[id_col + '_x'], 'outer')
                       .withColumn('cdc', F.when(F.col(id_col + '_x').isNull(), F.lit(2))
                                           .when(F.col(id_col).isNull(), F.lit(99))
                                           .otherwise(F.lit(3)))
                  )
  source_target = source_target.persist()

  # births - records only present in the source
  births = source_target.filter(F.col(id_col + '_x').isNull()).select(df_source.columns)
  # no changes - records only present in the target, kept as-is
  no_changes = source_target.filter(F.col(id_col).isNull()).select(df_t.columns)
  no_changes = lf.rename_cols(no_changes, '_x', '')  # strip the temporary '_x' suffix
  output = births.union(no_changes).drop('live_flag')

  # apply data changes - record, per matched row, which tracked columns differ between source and target
  conditions = [F.when(source_target[c] != source_target[c + '_x'], F.lit(c)).otherwise(F.lit("")) for c in change_cols if c not in [id_col]]
  select_expr = [
    # untracked columns: prefer the target value, fall back to the source
    *[F.coalesce(source_target[c + '_x'], source_target[c]).alias(c) for c in df_source.columns if c not in change_cols],
    # tracked columns: prefer the source value, fall back to the target
    *[F.coalesce(source_target[c], source_target[c + '_x']).alias(c) for c in change_cols if c not in [id_col]],
    F.array_remove(F.array(*conditions), "").alias("data_changes")
  ]
  changes = source_target.filter(source_target.cdc == 3).select(*select_expr, 'cdc')

  # save change history as a parquet table (append only, partitioned by loaddate)
  changes.write.partitionBy('loaddate').mode('append').parquet(changetable)
  # births and no-change records are appended with an empty data_changes entry
  (source_target.filter(source_target.cdc != 3)
    .select(*df_source.columns, F.array(F.lit("")).alias('data_changes'), 'cdc')
    .write.partitionBy('loaddate')
    .mode('append').parquet(changetable)
  )
  # save data - combine the amended records with the births and no-change records
  changes = changes.select(*df_source.columns).drop('live_flag')
  output = changes.union(output)

  # overwrite the live_flag=1 partition; dynamic partition overwrite replaces only the
  # partitions present in the output rather than truncating the whole table
  output = output.withColumn('live_flag', F.lit(1))
  output.filter(output.live_flag == 1).write.partitionBy('live_flag') \
    .mode('overwrite').option("partitionOverwriteMode", "dynamic").parquet(tablename)
  # keep history - the superseded versions of amended records are appended with live_flag=0
  old_changes = source_target.filter(source_target.cdc == 3).select(*df_t.columns)
  old_changes = lf.rename_cols(old_changes, '_x', '')
  old_changes = old_changes.withColumn('live_flag', F.lit(0))
  old_changes.filter(old_changes.live_flag == 0).write.partitionBy('live_flag').mode('append').parquet(tablename)

  logging.info('END - identify changes: {}'.format(datetime.now().strftime("%Y%m%d%H%M")))
  print('END - identify changes: {}'.format(datetime.now().strftime("%Y%m%d%H%M")))

  return output
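
The function above calls lf.rename_cols to strip the temporary '_x' suffix from the target-side columns. That helper is not shown in this article; a minimal sketch, assuming it simply replaces a suffix in every matching column name, could look like this:

from pyspark.sql import DataFrame

def rename_cols(df: DataFrame, old_suffix: str, new_suffix: str = '') -> DataFrame:
  """Rename every column ending in old_suffix so that it ends in new_suffix instead.
  Assumed behaviour, inferred from how lf.rename_cols is called above."""
  for c in df.columns:
    if c.endswith(old_suffix):
      df = df.withColumnRenamed(c, c[:len(c) - len(old_suffix)] + new_suffix)
  return df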

Synthetic data

def crn_data(spark, start=0, end=100):
  """Generate synthetic company-like records numbered start..end with random attribute values."""
  crn = spark.sql(f"""
    select
      i + start as rowno,
      lpad(cast(floor(rand() * 8000000) as varchar(7)), 8, '7') as ch_no,
      concat(Array("C1","C2","F1","V_C","V_P_C","C_P","F3")[floor(rand()*6+0)], floor(rand() * 20)) as ch_name,
      floor(rand() * 3 + 1) as ch_legalstatus,
      case
        when rand() > 0.30 then
          concat(Array("C1","C2","F1","V_C","V_P_C","C_P","F3")[floor(rand()*6+0)], floor(rand() * 20))
        when rand() > 0.20 then
          concat(Array(40021,60023,70089,99999,21000)[floor(rand()*5.0+0)], floor(rand() * 20))
        else
          Array(40021,60023,70089,99999,21000)[floor(rand()*5.0+1)]
      end as rand_c,
      Array(40021,60023,70089,99999,21000)[floor(rand()*5.0)] as sic2007,
      source_file as source_file,
      cast(1 as int) as live_flag,
      current_timestamp() as load_timestamp,
      current_date() as loaddate
    from (select {start} as start, {end} as end, concat('df', {start}) as source_file) r
    lateral view posexplode(split(space(end - start), ' ')) pe as i, s
  """)
  return crn
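
A quick way to preview the generated data (values are random, so the output differs on every run):

crn_data(spark, 0, 3).show(truncate=False)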

Example code using the functions above

# initial load - rownos 1 to 5 written as the live (live_flag=1) partition
df_target = crn_data(spark,1,5)
df_target.write.partitionBy('live_flag').mode('append').parquet('ch')

# target data - current records read back from the live partition
df1 = spark.sql("select * from parquet.`file:/content/ch/live_flag=1`")

# source data - rownos 5 to 8
df2 = crn_data(spark,5,8).drop('live_flag')
skip_rows = list(range(5,6))
# drop rowno 5 so it shows up as a record missing from the source
df2 = df2.filter(~df2.rowno.isin(skip_rows))

change_cols = [
  'ch_name',
  'ch_legalstatus',
  'sic2007'
]
id_col = 'rowno'
output = identify_changes(df2,df1,id_col,change_cols,'ch','ch_changes')
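
After running identify_changes, the live partition, the history partition and the change table (paths as used above) can be inspected directly:

# current records (live_flag=1) and superseded records kept for history (live_flag=0)
spark.read.parquet('ch').filter('live_flag = 1').show(truncate=False)
spark.read.parquet('ch').filter('live_flag = 0').show(truncate=False)

# append-only change history, partitioned by loaddate
spark.read.parquet('ch_changes').select('rowno', 'data_changes', 'cdc', 'loaddate').show(truncate=False)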