PySpark Basics
February 10, 2025
Topics covered:
- Reading data: spark.read.csv()
- Inspecting a DataFrame: printSchema(), describe(), and show()
- Selecting columns: select()
- Counting: count(), countDistinct(), and groupBy().count()
- Sorting: orderBy()
- Adding columns: withColumn()
- Dropping columns: drop()
- Renaming columns: withColumnRenamed()
- SQL-style expressions: selectExpr()
- Type conversion: cast()
- Filtering rows: filter()
- Handling missing values: na.drop(), na.fill(), etc.
- Removing duplicates: dropDuplicates()
- Grouped aggregation: groupBy().agg()
DataFrame
In PySpark, a DataFrame is a distributed collection of data organized into named columns.
Unlike a Pandas DataFrame, a Spark DataFrame is evaluated lazily: transformations are “planned” but not executed until an action (e.g., count(), collect()) triggers a computation on the cluster.
No dedicated row index is maintained as in Pandas; rows are conceptually identified by their values, not by a numeric position.
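As a minimal illustration of lazy evaluation (assuming a DataFrame df has already been loaded and has Name and Salary columns, as in the NBA data used later in these notes):

from pyspark.sql import functions as F

# Transformations only build an execution plan; nothing runs on the cluster yet
high_paid = df.filter(F.col("Salary") > 1000000).select("Name", "Salary")

# The action below (count) triggers the actual computation
print(high_paid.count())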
SparkSession Entry Point
The SparkSession entry point provides the functionality for data transformation with DataFrames and SQL.

from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

- from pyspark.sql import SparkSession imports the SparkSession class from PySpark's SQL module.
- SparkSession.builder constructs a SparkSession.
- .master("local[*]") runs Spark locally, using all available cores.
- .getOrCreate() returns an existing SparkSession if one exists, or creates a new one otherwise.

path = '/content/drive/MyDrive/lecture-data/cces.csv'
df = spark.read.csv(path,
inferSchema=True,
header=True)
df.show() # Displays the first 20 rows
- spark.read.csv(path, ...): reads the CSV file at path into a Spark DataFrame.
- inferSchema=True: if True, Spark will automatically detect (or “infer”) the data types of the columns in the CSV file.
- header=True: if True, the first row of the CSV file is used as column names.
path = 'https://bcdanl.github.io/data/df.csv'
df = spark.read.csv(path,
inferSchema=True,
header=True)
df.show()
Spark’s spark.read.csv() function relies on the Hadoop FileSystem API to access files, so it cannot read directly from an http(s) URL like the one above.
Instead, we can read the file with Pandas and then convert the Pandas DataFrame to a Spark DataFrame:
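A minimal sketch of that first step (pd.read_csv() can fetch http(s) URLs directly):

import pandas as pd

df_pd = pd.read_csv(path)   # read the remote CSV with Pandas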
df = spark.createDataFrame(df_pd)
- df.printSchema(): prints the schema (column names and data types).
- df.columns: returns the list of column names.
- df.dtypes: returns a list of tuples (columnName, dataType) (see the example below).
- df.count(): returns the total number of rows.
- df.describe(): returns basic statistics of numerical/string columns (count, mean, std, min, max).
- df.show(): by default, displays the first 20 rows.
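For instance, df.columns and df.dtypes return ordinary Python lists that can be printed directly (the column names in the comments are purely illustrative):

print(df.columns)   # e.g., ['Name', 'Team', 'Salary']
print(df.dtypes)    # e.g., [('Name', 'string'), ('Team', 'string'), ('Salary', 'int')]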
df.show(n, truncate, vertical) accepts three optional parameters:
- n (int): number of rows to display, e.g., df.show(5); the default df.show() displays 20 rows.
- truncate (bool or int): if True (default), long strings are truncated to 20 characters; if False, full column contents are displayed; an integer sets the truncation width, e.g., df.show(truncate=False), df.show(truncate=30).
- vertical (bool): if True, displays each row vertically (useful for wide tables), e.g., df.show(vertical=True).
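The three parameters can also be combined, for example:

# Show 3 rows, truncate strings to 25 characters, and print each row vertically
df.show(n=3, truncate=25, vertical=True)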
df.printSchema(), df.dtypes, and df.columns
- df.printSchema() prints the DataFrame schema in a tree format.
- df.dtypes returns a list of tuples representing each column’s name and data type.
- df.columns returns a list of column names.

df.describe().show()
- df.describe() computes summary statistics (e.g., count, mean, stddev, min, max) for the DataFrame’s numeric columns.
- .show() prints these statistics in a readable table format.

# Single column -> returns a DataFrame with one column
df.select("Name").show(5)
# Multiple columns -> pass several column names (a list also works)
df.select("Name", "Team", "Salary").show(5)
We use select() to choose variables; it returns a new DataFrame containing a projection of the specified variables.

# Counting how many total rows
nba_count = df.count()
# Count distinct values in one column
from pyspark.sql.functions import countDistinct
num_teams = df.select(countDistinct("Team")).collect()[0][0]
# GroupBy a column and count occurrences
df.groupBy("Team").count().show(5)
df.count() returns the number of rows in df.
Unlike Pandas, there is no direct .value_counts() or .nunique() in PySpark; instead, use groupBy().count(), countDistinct(), etc. (see the example below).

Let’s do Questions 1-3 in Classwork 5!
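For example, a Pandas-style value_counts() on the Team column can be approximated as follows (a minimal sketch using the NBA DataFrame from above):

from pyspark.sql.functions import desc, countDistinct

# Tally rows per team and sort by frequency, like value_counts()
df.groupBy("Team").count().orderBy(desc("count")).show(5)

# Number of distinct teams, like nunique()
df.select(countDistinct("Team")).show()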
# Sort by a single column ascending
df.orderBy("Name").show(5)
# Sort by descending
from pyspark.sql.functions import desc
df.orderBy(desc("Salary")).show(5)
# Sort by multiple columns
df.orderBy(["Team", desc("Salary")]).show(5)
orderBy() can accept column names and ascending/descending instructions.

nsmallest or nlargest
# nsmallest example:
df.orderBy("Salary").limit(5).show()
# nlargest example:
df.orderBy(desc("Salary")).limit(5).show()
PySpark has no nsmallest() or nlargest(), but we can use limit() after sorting.
Spark DataFrames do not use a row-based index, so there is no direct .loc[] or .iloc[]. We filter rows by condition (filter()) or use transformations (limit(), take(), collect()) to access row data.

# Example: filter by condition
df.filter("Team == 'New York Knicks'").show()
df.limit(5).show()   # New DataFrame with the first 5 rows
df.take(5)           # List of the first 5 Row objects
df.collect()         # List of all Row objects (pulls everything to the driver)
To access the first n rows, use df.limit(n) or df.take(n); take(n) returns a list of Row objects.
df.collect() returns all the records as a list of Row objects.

withColumn()
from pyspark.sql.functions import col

# Add a column "Salary_k" using a column expression col()
df = df.withColumn("Salary_k", col("Salary") / 1000)
drop()
df = df.drop("Salary_k") # remove a single column
df = df.drop("Salary_2x", "Salary_3x") # remove multiple columns
withColumnRenamed()
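A minimal sketch of withColumnRenamed(), assuming we rename the Salary column (the new name Salary_USD and the variable df_renamed are purely illustrative):

# Rename the "Salary" column to "Salary_USD"; all other columns are unchanged
df_renamed = df.withColumnRenamed("Salary", "Salary_USD")
df_renamed.printSchema()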
# Summaries for numeric columns
df.selectExpr(
"mean(Salary) as mean_salary",
"min(Salary) as min_salary",
"max(Salary) as max_salary",
"stddev_pop(Salary) as std_salary"
).show()
Summary statistics can be computed with SQL-like expression strings inside selectExpr().

from pyspark.sql import functions as F
# Pre-compute the average salary (pulls it back as a Python float)
salary_mean = df.select(F.avg("Salary").alias("mean_salary")).collect()[0]["mean_salary"]
df2 = (
df
.withColumn("Salary_2x", F.col("Salary") * 2) # Add Salary_2x
.withColumn(
"Name_w_Position", # Concatenate Name and Position
F.concat(F.col("Name"), F.lit(" ("), F.col("Position"), F.lit(")")))
.withColumn(
"Salary_minus_Mean", # Subtract mean salary
F.col("Salary") - F.lit(salary_mean))
)
Transformations are lazy; nothing is computed until an action (show(), collect(), count()) is called.
The .alias() method is a way to give a temporary (or alternate) name to a column.

cast() Method
The cast() method converts a column to a different data type.

to_date() Method
to_date() can be used with a given string format (e.g., “M/d/yy”) to convert a string column to a date. A sketch of both methods follows the list of data types below.

Spark data types:
- byte (8-bit integer)
- short (16-bit integer)
- int (32-bit integer)
- long (64-bit integer)
- float (32-bit floating point)
- double (64-bit floating point)
- decimal (arbitrary-precision numeric type)
- string (text data)
- boolean (Boolean values: True/False)
- date (represents a date: year, month, day)
- timestamp (represents a timestamp: date and time)
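A minimal sketch of cast() and to_date(); the tiny DataFrame below is made up for illustration, with a salary stored as a string and a date stored as text in “M/d/yy” format:

from pyspark.sql.functions import col, to_date

dates_df = spark.createDataFrame(
    [("Alice", "90000", "2/10/25"), ("Bob", "85000", "1/5/24")],
    ["Name", "Salary", "hire_date"],
)

dates_df = (
    dates_df
    .withColumn("Salary", col("Salary").cast("int"))              # string -> int
    .withColumn("hire_date", to_date(col("hire_date"), "M/d/yy")) # string -> date
)
dates_df.printSchema()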
import pandas as pd
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
df_pd = pd.read_csv("https://bcdanl.github.io/data/employment.csv")
df_pd = df_pd.where(pd.notnull(df_pd), None) # Convert NaN to None
df = spark.createDataFrame(df_pd)
from pyspark.sql.functions import col

df.filter(col("Salary") > 100000).show()
The isin() method
The isin() method keeps observations whose value matches one of the values in a given list.

The .between() method
With .between(), True denotes that an observation’s value falls within the specified interval (see the sketch below).

Let’s do Questions 2-6 in Classwork 6!
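A minimal sketch of both methods; the column names Team and Salary and the listed values are illustrative:

from pyspark.sql.functions import col

# Keep rows whose Team is one of the listed values
df.filter(col("Team").isin("New York Knicks", "Boston Celtics")).show(5)

# Keep rows whose Salary falls between 1,000,000 and 2,000,000 (inclusive)
df.filter(col("Salary").between(1000000, 2000000)).show(5)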
The isNull() and isNotNull() methods
The isNull() and isNotNull() methods test whether a value is missing (NULL) or not.

The .na.drop() method
The .na.drop() method removes observations that hold any NULL values.

The .na.drop() method with how
We can pass the how parameter an argument of "all" to remove observations in which all values are missing.
Note that the how parameter’s default argument is "any".

The .na.drop() method with subset
We can use the subset parameter to target observations with a missing value in a specific variable, e.g., the Gender and Team variables.

The .na.fill() method
We can use the value and subset parameters to fill a specific column’s NULLs with a specific value (see the sketch below).
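A minimal sketch of these methods; the columns Gender and Team come from the example above, and the fill value "Unknown" is illustrative:

from pyspark.sql.functions import col

# Keep only observations where Gender is not NULL
df.filter(col("Gender").isNotNull()).show(5)

# Drop observations containing any NULL value (how="any" is the default)
df.na.drop().show(5)

# Drop observations only if every value is NULL
df.na.drop(how="all").show(5)

# Drop observations with a NULL in Gender or Team
df.na.drop(subset=["Gender", "Team"]).show(5)

# Replace NULLs in Gender with "Unknown"
df.na.fill(value="Unknown", subset=["Gender"]).show(5)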
The distinct() method
The distinct() method returns a new DataFrame with duplicate rows removed, similar to SQL’s SELECT DISTINCT command. After distinct(), only unique observations remain.

The dropDuplicates() method
# Drop all rows that are exact duplicates across all columns
df_no_dups = df.dropDuplicates()
# Drop duplicates based on subset of columns
df_no_dups_subset = df.dropDuplicates(["Team"])
dropDuplicates() keeps the first occurrence of each distinct combination.

Let’s do Questions 7-8 in Classwork 6!
We use groupBy() (similar to Pandas’ groupby()) to aggregate, analyze, or transform data at a grouped level.
Calling groupBy() returns a GroupedData object, which can then be used with aggregation methods such as sum(), avg(), count(), etc.
Below, we create a DataFrame from a list (other data sources like CSV, Parquet, etc. work as well) and call groupBy("Type") on the DataFrame.

food_data = [
("Apple", "Fruit", 1.05),
("Onion", "Vegie", 1.00),
("Orange", "Fruit", 1.25),
("Tomato", "Vegie", 0.85),
("Watermelon", "Fruit", 4.15)
]
food_df = spark.createDataFrame(food_data, ["Item", "Type", "Price"])
# Group by "Type"
groups = food_df.groupBy("Type")
# Calculate the average Price for each Type
groups.avg("Price").show()
# Calculate the sum of the Price for each Type
groups.sum("Price").show()
# Count how many rows in each Type
groups.count().show()
The grouped results are lazy; nothing is computed until an action (e.g., .show()) is called.

from pyspark.sql.functions import min, max, mean
food_df.groupBy("Type").agg(
min("Price").alias("min_price"),
max("Price").alias("max_price"),
mean("Price").alias("mean_price")
).show()
We use .agg() to get multiple results at once.
In Pandas, .transform() is often used to add group-level statistics back onto the original DataFrame. In PySpark, we can achieve the same with a Window function: the aggregate is computed over each group’s partition and attached to every row.
from pyspark.sql.window import Window
from pyspark.sql.functions import avg, col
# Define a window partitioned by "Type"
w = Window.partitionBy("Type")
food_df_with_mean = food_df.withColumn(
"mean_price_by_type",
avg(col("Price")).over(w)
)
food_df_with_mean.show()
Let’s do Classwork 7!