from pyspark.sql import SparkSession
import pyspark.sql.functions as F
def main():
spark = SparkSession.builder.appName("BroadcastJoinMisuse").getOrCreate()
# Create a large table (should NOT be broadcast)
large_df = spark.range(1000000).toDF("id")
large_df = large_df.withColumn("value", F.rand() * 100)
# Create a small table (good candidate for broadcast)
small_df = spark.range(100).toDF("id")
small_df = small_df.withColumn("category", F.lit("A"))
# BAD: Forcing broadcast on large table
# This will cause driver OOM or very slow performance
result = large_df.join(
F.broadcast(large_df.alias("large2")),
large_df.id == F.col("large2.id"),
"inner"
)
print(f"Count: {result.count()}")
spark.stop()
if __name__ == "__main__":
main()