PySpark Basics
February 10, 2025
Topics covered: reading CSV files (spark.read.csv()), printSchema() and describe(), select(), groupBy().count(), countDistinct(), and count(), orderBy(), withColumn(), drop(), withColumnRenamed(), selectExpr(), cast(), filter(), missing values (na.drop(), na.fill(), etc.), duplicates (dropDuplicates()), and groupBy().agg().

DataFrame
In PySpark, a DataFrame is a distributed collection of data organized into named columns.
Unlike a Pandas DataFrame, a Spark DataFrame is evaluated lazily: transformations are “planned” but not executed until an action (e.g., count(), collect()) triggers a computation on the cluster.
No dedicated row index is maintained like in Pandas; rows are conceptually identified by their values, not by a numeric position.
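As a minimal sketch of this laziness (assuming a DataFrame df has already been created as in the sections below, with a numeric Salary column like the NBA data used later):

# Transformations only record a plan; no Spark job runs at this line
high_paid = df.filter(df["Salary"] > 100000)
# Actions trigger the actual computation on the cluster
print(high_paid.count())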
SparkSession Entry Point
The SparkSession entry point provides the functionality for data transformation with DataFrames and SQL.

from pyspark.sql import SparkSession
spark = (
    SparkSession
    .builder
    .master("local[*]")
    .getOrCreate()
)

- from pyspark.sql import SparkSession: imports the SparkSession class from PySpark’s SQL module.
- SparkSession.builder: starts building a SparkSession.
- .master("local[*]"): runs Spark locally, using all available CPU cores.
- .getOrCreate(): returns an existing SparkSession if one exists, or creates a new one otherwise.

path = '/content/drive/MyDrive/lecture-data/cces.csv'
df = spark.read.csv(path,
                    inferSchema=True,
                    header=True)
df.show()  # Displays the first 20 rows

- spark.read.csv(path, ...): reads the CSV file at path into a Spark DataFrame.
- inferSchema=True: Spark will automatically detect (or “infer”) the data types of the columns in the CSV file.
- header=True: the first row of the CSV file is treated as column names rather than data.
path = 'https://bcdanl.github.io/data/df.csv'
df = spark.read.csv(path,
                    inferSchema=True,
                    header=True)
df.show()

Spark’s spark.read.csv() function relies on the Hadoop FileSystem API to access files, so it cannot read directly from an HTTP(S) URL like the one above.
Instead, read the file with Pandas and then convert the Pandas DataFrame to a Spark DataFrame:
df = spark.createDataFrame(df_pd)

- df.printSchema(): prints the schema (column names and data types).
- df.columns: returns the list of columns.
- df.dtypes: returns a list of tuples (columnName, dataType).
- df.count(): returns the total number of rows.
- df.describe(): returns basic statistics of numerical/string columns (mean, count, std, min, max).

df.show()
By default, displays the first 20 rows.
df.show(n, truncate, vertical) accepts three optional parameters:
- n (int): the number of rows to display, e.g., df.show(5); the default df.show() displays 20 rows.
- truncate (bool or int): if True (the default), long strings are truncated to 20 characters; if False, full column contents are displayed; an integer truncates to that many characters, e.g., df.show(truncate=False) or df.show(truncate=30).
- vertical (bool): if True, displays each row vertically (useful for wide tables), e.g., df.show(vertical=True).

df.printSchema(), df.dtypes, and df.columns
df.printSchema() prints the DataFrame schema in a tree format.
df.dtypes returns a list of tuples representing each column’s name and data type.
df.columns returns a list of column names.
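A quick sketch of these inspection methods, assuming the df loaded above (the exact output depends on the data):

# Inspect structure and size
df.printSchema()     # column names and data types, shown as a tree
print(df.columns)    # plain Python list of column names
print(df.dtypes)     # list of (columnName, dataType) tuples
print(df.count())    # total number of rows (count() is an action)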
df.describe().show()
df.describe() computes summary statistics (e.g., count, mean, stddev, min, max) for the DataFrame’s numeric columns, and .show() prints these statistics in a readable table format.

# Single column -> returns a DataFrame with one column
df.select("Name").show(5)
# Multiple columns -> pass a list-like of column names
df.select("Name", "Team", "Salary").show(5)
Use select() to choose variables; it returns a DataFrame projection of the specified variables.

# Counting how many total rows
nba_count = df.count()
# Count distinct values in one column
from pyspark.sql.functions import countDistinct
num_teams = df.select(countDistinct("Team")).collect()[0][0]
# GroupBy a column and count occurrences
df.groupBy("Team").count().show(5)
df.count() returns the number of rows in df.
Unlike Pandas, there is no direct .value_counts() or .nunique() in PySpark.
Instead, use groupBy().count(), countDistinct(), etc.
Let’s do Questions 1-3 in Classwork 5!
# Sort by a single column ascending
df.orderBy("Name").show(5)
# Sort by descending
from pyspark.sql.functions import desc
df.orderBy(desc("Salary")).show(5)
# Sort by multiple columns
df.orderBy(["Team", desc("Salary")]).show(5)
orderBy() can accept column names and ascending/descending instructions.

nsmallest or nlargest
# nsmallest example:
df.orderBy("Salary").limit(5).show()
# nlargest example:
df.orderBy(desc("Salary")).limit(5).show()
PySpark has no nsmallest() or nlargest(), but we can use limit() after sorting.
Spark DataFrames do not use a row-based index, so there is no direct .loc[] or .iloc[]; filter rows by a condition (.filter()) or use transformations (limit(), take(), collect()) to access row data.

# Example: filter by condition
df.filter("Team == 'New York Knicks'").show()
df.limit(5).show()
df.take(5)
df.collect()
To access the first n rows, use df.limit(n) (a DataFrame) or df.take(n), which returns a list of Row objects.
df.collect(): returns all the records as a list of Row.

withColumn()
# Add a column "Salary_k" using a column expression col()
from pyspark.sql.functions import col
df = df.withColumn("Salary_k", col("Salary") / 1000)

drop()
df = df.drop("Salary_k") # remove a single column
df = df.drop("Salary_2x", "Salary_3x") # remove multiple columns

withColumnRenamed()
The withColumnRenamed() method renames an existing column.
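A minimal sketch of withColumnRenamed(); the new name here is purely illustrative:

# Rename the "Name" column to "Player" (returns a new DataFrame; df itself is unchanged)
df_renamed = df.withColumnRenamed("Name", "Player")
df_renamed.select("Player", "Team").show(5)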
selectExpr()
# Summaries for numeric columns
df.selectExpr(
    "mean(Salary) as mean_salary",
    "min(Salary) as min_salary",
    "max(Salary) as max_salary",
    "stddev_pop(Salary) as std_salary"
).show()
These summaries are computed with selectExpr(), which accepts SQL-style expressions.

from pyspark.sql import functions as F
# Pre-compute the average salary (pulls it back as a Python float)
salary_mean = df.select(F.avg("Salary").alias("mean_salary")).collect()[0]["mean_salary"]
df2 = (
    df
    .withColumn("Salary_2x", F.col("Salary") * 2)  # Add Salary_2x
    .withColumn(
        "Name_w_Position",  # Concatenate Name and Position
        F.concat(F.col("Name"), F.lit(" ("), F.col("Position"), F.lit(")")))
    .withColumn(
        "Salary_minus_Mean",  # Subtract mean salary
        F.col("Salary") - F.lit(salary_mean))
)
These transformations are lazy; nothing is computed until an action (e.g., show(), collect(), count()) is called.
The .alias() method is a way to give a temporary (or alternate) name to a column.

cast() Method
cast() converts a column to a different data type.

to_date() Method
to_date() can be used with a given string format (e.g., “M/d/yy”) to parse strings into dates.

Spark data types include:
- byte (8-bit)
- short (16-bit)
- int (32-bit)
- long (64-bit)
- float (32-bit floating point)
- double (64-bit floating point)
- decimal (arbitrary precision numeric type)
- string (text data)
- boolean (Boolean values, True/False)
- date (represents a date: year, month, day)
- timestamp (represents a timestamp: date and time)
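A minimal sketch of cast() and to_date(); the "Birthday" column and the "M/d/yy" format are assumptions for illustration:

from pyspark.sql.functions import col, to_date

df_cast = (
    df
    # Convert Salary to a 64-bit floating point column
    .withColumn("Salary", col("Salary").cast("double"))
    # Parse a string like "2/10/25" into a proper date column
    .withColumn("Birthday", to_date(col("Birthday"), "M/d/yy"))
)
df_cast.printSchema()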
Example: reading a CSV with Pandas and converting it to a Spark DataFrame:
import pandas as pd
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
df_pd = pd.read_csv("https://bcdanl.github.io/data/employment.csv")
df_pd = df_pd.where(pd.notnull(df_pd), None) # Convert NaN to None
df = spark.createDataFrame(df_pd)
df.filter(col("Salary") > 100000).show()

isin() method
The isin() method returns True for observations whose value matches any value in a given list; it is typically used inside filter() (see the sketch below).

between() method
between() returns True when an observation’s value falls within the specified interval.

Let’s do Questions 2-6 in Classwork 6!
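A minimal sketch of isin() and between() inside filter(); the column names and values follow the earlier NBA-style examples and are assumptions here:

# Keep rows whose Team matches any value in the list
df.filter(col("Team").isin("New York Knicks", "Boston Celtics")).show(5)

# Keep rows whose Salary lies within the closed interval [1000000, 5000000]
df.filter(col("Salary").between(1000000, 5000000)).show(5)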
isNull() and isNotNull() methods
The isNull() and isNotNull() methods test whether a value is null (missing) or not.

.na.drop() method
The .na.drop() method removes observations that hold any NULL values.

.na.drop() method with how
We can pass the how parameter an argument of "all" to remove observations in which all values are missing.
Note that the how parameter’s default argument is "any".

.na.drop() method with subset
We can use the subset parameter to target observations with a missing value in a specific variable, e.g., the Gender and Team variables.
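A minimal sketch of these null-handling tools; the Gender and Team columns from the example above are assumed to exist in df:

# Keep only rows where Team is missing / not missing
df.filter(col("Team").isNull()).show(5)
df.filter(col("Team").isNotNull()).show(5)

# Drop rows containing any NULL, or only rows in which every value is NULL
df.na.drop().show(5)
df.na.drop(how="all").show(5)

# Drop rows that are missing Gender or Team
df.na.drop(subset=["Gender", "Team"]).show(5)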
.na.fill() method
The .na.fill() method takes value and subset parameters to fill a specific column’s NULLs with a specific value.

distinct()
The distinct() method returns a new DataFrame with duplicate rows removed, like SQL’s SELECT DISTINCT command. After distinct(), only unique observations remain.
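A minimal sketch of .na.fill() and distinct(); the fill value and column name are illustrative:

# Replace missing Gender values with the string "Unknown"
df.na.fill("Unknown", subset=["Gender"]).show(5)

# Keep only unique rows (exact duplicates across all columns are removed)
df.distinct().show(5)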
dropDuplicates() method
# Drop all rows that are exact duplicates across all columns
df_no_dups = df.dropDuplicates()
# Drop duplicates based on a subset of columns
df_no_dups_subset = df.dropDuplicates(["Team"])
dropDuplicates() keeps the first occurrence of each distinct combination.
Let’s do Questions 7-8 in Classwork 6!
Use groupBy() (similar to Pandas’ groupby()) to aggregate, analyze, or transform data at a grouped level.
Calling groupBy() returns a GroupedData object, which can then be used with aggregation methods such as sum(), avg(), count(), etc.
Below, we create a DataFrame from a list (other data sources like CSV, Parquet, etc. work as well) and call groupBy("Type") on the DataFrame.

food_data = [
    ("Apple", "Fruit", 1.05),
    ("Onion", "Vegie", 1.00),
    ("Orange", "Fruit", 1.25),
    ("Tomato", "Vegie", 0.85),
    ("Watermelon", "Fruit", 4.15)
]
food_df = spark.createDataFrame(food_data, ["Item", "Type", "Price"])
# Group by "Type"
groups = food_df.groupBy("Type")

# Calculate the average Price for each Type
groups.avg("Price").show()
# Calculate the sum of the Price for each Type
groups.sum("Price").show()
# Count how many rows in each Type
groups.count().show()
These grouped aggregations are also lazy; nothing runs until an action (e.g., .show()) is called.

from pyspark.sql.functions import min, max, mean
food_df.groupBy("Type").agg(
    min("Price").alias("min_price"),
    max("Price").alias("max_price"),
    mean("Price").alias("mean_price")
).show()
Use .agg() to get multiple results at once.
In Pandas, .transform() is often used to add group-level statistics back onto the original DataFrame; in PySpark, we can use a Window function or join with the aggregated DataFrame instead.

from pyspark.sql.window import Window
from pyspark.sql.functions import avg, col
# Define a window partitioned by "Type"
w = Window.partitionBy("Type")
food_df_with_mean = food_df.withColumn(
    "mean_price_by_type",
    avg(col("Price")).over(w)
)
food_df_with_mean.show()
Let’s do Classwork 7!