Add or remove items from array using PySpark

In this article, we will use Hive and PySpark to manipulate a complex datatype, i.e. array<string>. We show how to add items to or remove items from an array using PySpark.

We will use datasets consisting of three kinds of units: paye, crn and vat units. For sample data see – https://broadoakdata.uk/synthetic-data-creation-linking-records/

+-----+-------+----+
|   id|unitref|type|
+-----+-------+----+
|L2001|paye123|   p|
|L2006| vat123|   v|
|L2234| vat223|   v|
|L2234|paye223|   p|
|L3235|paye345|   P|
|L2007|  c0023|   c|
|L2238|  c7223|   c|
|L3239|  c9345|   c|
+-----+-------+----+

We need to link and unlink a few units from the above dataset. Tasks to be completed:

  • link records which match crn
  • remove (unlink) the matched units from their original ids, which do not have any crn
+-----+---------+-----+------+
| s_id|s_unitref|  crn|s_type|
+-----+---------+-----+------+
|L2001|  paye123|c0023|     p|
|L2006|   vat123|c7223|     v|
|L2234|   vat223| null|     v|
|L3235|  paye345| null|     P|
+-----+---------+-----+------+

The final output should look like below:

+-----+-------+--------+---------+
|   id|   crns|    vats|    payes|
+-----+-------+--------+---------+
|L2001|     []|      []|       []|
|L2006|     []|      []|       []|
|L2234|     []|[vat223]|[paye223]|
|L3235|     []|      []|       []|
|L2007|[c0023]|      []|[paye123]|
|L2238|[c7223]|[vat123]|       []|
|L3239|[c9345]|      []|       []|
+-----+-------+--------+---------+

Approach for adding and removing items from array

  • units need linking and unlinking
  • dataframe with all units in the database
  • store units as set – group by id
  • identify units with crn
  • dataframe as set with crn
  • join two dataframes for linking
  • perform linking – add unit which matches crn
  • join the two sets on id to find the units which need to be unlinked
  • unlink – remove units

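Code snippets (they assume an active SparkSession named `spark` and the import `from pyspark.sql import functions as F`)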

# Units that need linking and unlinking.
# Each vp unit is left-joined to the crn lookup (ogd) on unitref and type,
# so a unit without a matching ogd row comes back with a null crn.
vp_query = """
with vp(id,unitref,type) as (
select 'L2001','paye123','p'
union all
select 'L2006','vat123','v'
union all
select 'L2234','vat223','v'
union all
select 'L3235','paye345','P'
),
-- crn units
ogd(crn,unitref,type) as (
select 'c0023','paye123','p'
union all
select 'c7223','vat123','v'
)
select 
vp.id as s_id
,vp.unitref as s_unitref
,o.crn as crn
,vp.type as s_type
from vp
left join ogd o on vp.unitref=o.unitref and vp.type=o.type
"""
df_vp = spark.sql(vp_query).cache()
# count() is an action, so it forces the cached DataFrame to be computed
# before it is reused by the later steps.
df_vp.count()
df_vp.show()

# All units currently held in the database, one row per (id, unitref, type).
lunit_query = """
with lunit(id,unitref,type) as (
select 'L2001','paye123','p'
union all
select 'L2006','vat123','v'
union all
select 'L2234','vat223','v'
union all
select 'L2234','paye223','p'
union all
select 'L3235','paye345','P'
union all
select 'L2007','c0023','c'
union all
select 'L2238','c7223','c'
union all
select 'L3239','c9345','c'
)
select * from lunit
"""
df = spark.sql(lunit_query).cache()
# count() is an action, so it forces the cached DataFrame to be computed
# before it is reused by the later steps.
df.count()
df.show()

# Store the units of every id as sets: one array column per unit type.
# when() without otherwise() yields null for non-matching rows, and
# collect_set() skips nulls, so each array only holds unitrefs of its own type.
# NOTE(review): type is compared against the lowercase codes only, so a row
# typed 'P' lands in none of the arrays - confirm that this is intended.
unit_arrays = [
    F.collect_set(F.when(F.col("type") == code, F.col("unitref"))).alias(name)
    for code, name in (("c", "crns"), ("v", "vats"), ("p", "payes"))
]
df_set = df.groupBy("id").agg(*unit_arrays)
df_set.show()

# Identify units with a crn: pair every crn row in df (type 'c') with the
# df_vp units whose crn equals that row's unitref.
crn_match = (df["unitref"] == df_vp["crn"]) & (df["type"] == "c")
df_crn = df.join(df_vp, on=crn_match, how="inner")
df_crn.show()

# Same set representation for the crn matches: per (id, s_id, crn) collect the
# crn itself and the vat / paye units that have to move to that id.
# Each entry is (type column, type code, unitref column, output column).
linked_arrays = [
    F.collect_set(F.when(F.col(kind_col) == code, F.col(ref_col))).alias(out)
    for kind_col, code, ref_col, out in (
        ("type", "c", "unitref", "s_crns"),
        ("s_type", "v", "s_unitref", "s_vats"),
        ("s_type", "p", "s_unitref", "s_payes"),
    )
]
df_crn_set = df_crn.groupBy("id", "s_id", "crn").agg(*linked_arrays)
df_crn_set.show()

# Attach the crn matches to the per-id sets; ids without a match keep their
# row and get nulls in the columns coming from df_crn_set.
df_crn_join = df_set.join(df_crn_set, on="id", how="left")
df_crn_join.show()

# Perform linking - add the units which match a crn to that crn's id.
# Only rows with a crn are merged; all other rows keep their arrays unchanged.
has_crn = F.col("crn").isNotNull()
payes_linked = F.when(
    has_crn, F.array_union(F.col("payes"), F.col("s_payes"))
).otherwise(F.col("payes"))
vats_linked = F.when(
    has_crn, F.array_union(F.col("vats"), F.col("s_vats"))
).otherwise(F.col("vats"))
df_link = (
    df_crn_join
    .withColumn("payes", payes_linked)
    .withColumn("vats", vats_linked)
)
df_link.show()

# Join the linked sets with the moved units on the units' original id (s_id),
# so every id that has to give units away sees them in s_vats / s_payes.
current_units = df_link.select("id", "crns", "vats", "payes")
moved_units = df_crn_set.select("s_id", "s_vats", "s_payes")
df_unlink_set = current_units.join(
    moved_units,
    on=df_link["id"] == df_crn_set["s_id"],
    how="left",
)
df_unlink_set.show()

# Unlink - remove the moved units from the id they were originally stored on.
# Only rows where id equals s_id are changed; all other rows keep their arrays.
is_source_id = F.col("id") == F.col("s_id")
payes_unlinked = F.when(
    is_source_id, F.array_except(F.col("payes"), F.col("s_payes"))
).otherwise(F.col("payes"))
vats_unlinked = F.when(
    is_source_id, F.array_except(F.col("vats"), F.col("s_vats"))
).otherwise(F.col("vats"))
df_unlink = (
    df_unlink_set
    .withColumn("payes", payes_unlinked)
    .withColumn("vats", vats_unlinked)
    # helper columns from the join are no longer needed in the final result
    .drop("s_id", "s_vats", "s_payes")
)
df_unlink.show()

Screenshot