Logistic Regression

Newborn Babies

Author
Affiliation

Byeong-Hak Choe

SUNY Geneseo

Published

February 11, 2026

Modified

February 16, 2026

R Packages and Settings

library(tidyverse)
library(broom)
library(stargazer)
library(margins)
library(yardstick)
library(WVPlots)
library(DT)
library(rmarkdown)
library(hrbrthemes)
library(ggthemes)

# Make theme_ipsum() the default ggplot theme for every plot below.
theme_set(theme_ipsum())

# Define global wrappers named after ggplot2's default discrete colour/fill
# scales that forward all arguments to the colorblind palette scales.
scale_colour_discrete <- function(...) {
  scale_color_colorblind(...)
}
scale_fill_discrete <- function(...) {
  scale_fill_colorblind(...)
}

Data: NatalRiskData

# Read the natality risk data from the course website and show it as a
# paged table.
df <- read_csv(file = "https://bcdanl.github.io/data/NatalRiskData.csv")
df |>
  paged_table()
Variable Type Description
atRisk Bool 1 if Apgar < 7, 0 otherwise
PWGT Num Prepregnancy weight
UPREVIS Int Prenatal visits
CIG_REC Bool 1 if smoker, 0 otherwise
GESTREC3 Cat < 37 weeks or ≥ 37 weeks
DPLURAL Cat Single / Twin / Triplet+
ULD_MECO Bool 1 if heavy meconium
ULD_PRECIP Bool 1 if labor < 3 hours
ULD_BREECH Bool 1 if breech birth
URF_DIAB Bool 1 if diabetic
URF_CHYPER Bool 1 if chronic hypertension
URF_PHYPER Bool 1 if pregnancy hypertension
URF_ECLAM Bool 1 if eclampsia

Quick Checks (Levels you should see)

# Tabulate the gestation-length categories present in the raw data.
count(df, GESTREC3)
# A tibble: 2 × 2
  GESTREC3        n
  <chr>       <int>
1 < 37 weeks   3005
2 >= 37 weeks 23308
# Tabulate the plurality categories present in the raw data.
count(df, DPLURAL)
# A tibble: 3 × 2
  DPLURAL               n
  <chr>             <int>
1 single            25440
2 triplet or higher    44
3 twin                829
# Tabulate the outcome to see how rare the at-risk class is.
count(df, atRisk)
# A tibble: 2 × 2
  atRisk     n
   <dbl> <int>
1      0 25831
2      1   482

Pre-process: Factor Levels (define levels BEFORE splitting)

This avoids predict() errors like “factor has new levels …”.

# Convert the two categorical predictors to factors with their complete,
# explicit level sets BEFORE the train/test split, so that both subsets carry
# identical levels.
#
# The first entry of `levels` becomes the first factor level, which is the
# reference (baseline) category under R's default treatment contrasts. Listing
# ">= 37 weeks" and "single" first therefore already selects the baselines;
# a follow-up relevel() to the same category would be a no-op.
df <- df |>
  mutate(
    GESTREC3 = factor(GESTREC3, levels = c(">= 37 weeks", "< 37 weeks")),
    DPLURAL  = factor(DPLURAL, levels = c("single", "twin", "triplet or higher"))
  )

Train/Test Split (runif(n()))

# Fix the RNG state so the split is reproducible.
set.seed(1234)

# Attach one uniform(0, 1) draw per row; the draw decides the row's subset.
df_split <- mutate(df, rnd = runif(n()))

# Draws above 0.3 go to training, the rest to test; the helper column is
# dropped from both subsets.
dtrain <- df_split |>
  filter(rnd > 0.3) |>
  select(-rnd)
dtest <- df_split |>
  filter(rnd <= 0.3) |>
  select(-rnd)

Fit Logistic Regression (glm())

Model 1: baseline specification

# Logistic regression of atRisk on the maternal, labour and pregnancy
# predictors, estimated on the training subset only.
model <- glm(
  formula = atRisk ~
    PWGT + UPREVIS + CIG_REC +
    ULD_MECO + ULD_PRECIP + ULD_BREECH +
    URF_DIAB + URF_CHYPER + URF_PHYPER + URF_ECLAM +
    GESTREC3 + DPLURAL,
  family = binomial(link = "logit"),
  data = dtrain
)

Regression Table (stargazer)

# Render the fitted model as an HTML regression table with 3-digit rounding.
stargazer(
  model,
  title = "Logistic regression (logit): atRisk",
  type = "html",
  digits = 3
)
Logistic regression (logit): atRisk
Dependent variable:
atRisk
PWGT 0.004***
(0.001)
UPREVIS -0.072***
(0.014)
CIG_REC 0.379**
(0.162)
ULD_MECO 1.034***
(0.189)
ULD_PRECIP 0.419
(0.273)
ULD_BREECH 0.830***
(0.158)
URF_DIAB -0.096
(0.238)
URF_CHYPER 0.0003
(0.433)
URF_PHYPER 0.106
(0.222)
URF_ECLAM -0.275
(1.052)
GESTREC3< 37 weeks 1.455***
(0.127)
DPLURALtwin 0.154
(0.223)
DPLURALtriplet or higher 1.252**
(0.528)
Constant -4.440***
(0.258)
Observations 18,584
Log Likelihood -1,551.710
Akaike Inf. Crit. 3,131.420
Note: *p<0.1; **p<0.05; ***p<0.01
  • Regression table with summary()
model |> summary()

Call:
glm(formula = atRisk ~ PWGT + UPREVIS + CIG_REC + ULD_MECO + 
    ULD_PRECIP + ULD_BREECH + URF_DIAB + URF_CHYPER + URF_PHYPER + 
    URF_ECLAM + GESTREC3 + DPLURAL, family = binomial(link = "logit"), 
    data = dtrain)

Coefficients:
                           Estimate Std. Error z value Pr(>|z|)    
(Intercept)              -4.4401038  0.2582054 -17.196  < 2e-16 ***
PWGT                      0.0040278  0.0012848   3.135  0.00172 ** 
UPREVIS                  -0.0718017  0.0137076  -5.238 1.62e-07 ***
CIG_REC                   0.3790360  0.1617327   2.344  0.01910 *  
ULD_MECO                  1.0342848  0.1891973   5.467 4.58e-08 ***
ULD_PRECIP                0.4186892  0.2728621   1.534  0.12492    
ULD_BREECH                0.8295577  0.1576608   5.262 1.43e-07 ***
URF_DIAB                 -0.0961558  0.2378135  -0.404  0.68597    
URF_CHYPER                0.0002799  0.4327851   0.001  0.99948    
URF_PHYPER                0.1056108  0.2218410   0.476  0.63403    
URF_ECLAM                -0.2748580  1.0518845  -0.261  0.79386    
GESTREC3< 37 weeks        1.4550072  0.1266663  11.487  < 2e-16 ***
DPLURALtwin               0.1538631  0.2230274   0.690  0.49027    
DPLURALtriplet or higher  1.2524322  0.5280166   2.372  0.01769 *  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 3394.5  on 18583  degrees of freedom
Residual deviance: 3103.4  on 18570  degrees of freedom
AIC: 3131.4

Number of Fisher Scoring iterations: 7

Beta Estimates (tidy())

# Coefficient tables with confidence intervals at three levels.
# `TRUE` is spelled out: `T` is an ordinary variable that can be reassigned.

# 95% interval (tidy()'s default conf.level).
model_betas <- tidy(model, conf.int = TRUE)

# 90% interval.
model_betas_90ci <- tidy(model, conf.int = TRUE, conf.level = 0.90)

# 99% interval.
model_betas_99ci <- tidy(model, conf.int = TRUE, conf.level = 0.99)

Coefficient Plots

# Stack the 90/95/99% coefficient tables and tag each row with its CI level.
# NOTE(review): the object holds coefficient intervals, not months; the name
# `month_ci` is kept unchanged so existing references keep working.
month_ci <- bind_rows(
  model_betas_90ci |> mutate(ci = "90%"),
  model_betas      |> mutate(ci = "95%"),
  model_betas_99ci |> mutate(ci = "99%")
) |>
  mutate(ci = factor(ci, levels = c("90%", "95%", "99%")))

# Draw a dodged point-range plot of coefficient estimates, one range per CI
# level, with a dashed reference line at zero.
#
# `data` must contain the columns term, estimate, conf.low, conf.high and ci.
# Returns the ggplot object.
plot_coef_ci <- function(data) {
  ggplot(
    data = data,
    aes(
      y = term,
      x = estimate,
      xmin = conf.low,
      xmax = conf.high,
      color = ci
    )
  ) +
    geom_vline(xintercept = 0, color = "maroon", linetype = 2) +
    geom_pointrange(
      position = position_dodge(width = 0.6)
    ) +
    labs(
      color = "CI level",
      y = ""
    )
}

# All slope coefficients (intercept excluded).
plot_coef_ci(
  month_ci |>
    filter(!str_detect(term, "Intercept"))
)

# The two numeric predictors on their own plot.
plot_coef_ci(
  month_ci |>
    filter(term %in% c("UPREVIS", "PWGT"))
)

# The smoking indicator alone.
plot_coef_ci(
  month_ci |>
    filter(term %in% c("CIG_REC"))
)

Model Fit (glance())

# One-row model-fit summary from glance(), shown as a paged table.
paged_table(glance(model))

Marginal Effects (margins())

Average Marginal Effects

# Marginal effects for every predictor in the fitted model.
me <- margins(model)

# Summarise them into one row per term; reused by the AME plots later on.
sum_me <- me |> summary()
sum_me
                   factor     AME     SE       z      p   lower   upper
                  CIG_REC  0.0066 0.0028  2.3303 0.0198  0.0011  0.0122
 DPLURALtriplet or higher  0.0393 0.0266  1.4732 0.1407 -0.0130  0.0915
              DPLURALtwin  0.0028 0.0044  0.6494 0.5161 -0.0057  0.0114
       GESTREC3< 37 weeks  0.0391 0.0051  7.7424 0.0000  0.0292  0.0490
                     PWGT  0.0001 0.0000  3.1037 0.0019  0.0000  0.0001
               ULD_BREECH  0.0145 0.0028  5.1292 0.0000  0.0089  0.0200
                 ULD_MECO  0.0180 0.0034  5.3008 0.0000  0.0114  0.0247
               ULD_PRECIP  0.0073 0.0048  1.5315 0.1256 -0.0020  0.0166
                  UPREVIS -0.0013 0.0002 -5.0975 0.0000 -0.0017 -0.0008
               URF_CHYPER  0.0000 0.0075  0.0006 0.9995 -0.0148  0.0148
                 URF_DIAB -0.0017 0.0041 -0.4042 0.6860 -0.0098  0.0065
                URF_ECLAM -0.0048 0.0183 -0.2613 0.7939 -0.0407  0.0312
               URF_PHYPER  0.0018 0.0039  0.4760 0.6341 -0.0057  0.0094

Marginal Effects for Selected Variables

# Same summary, restricted to three predictors of interest.
model |>
  margins(variables = c("PWGT", "UPREVIS", "CIG_REC")) |>
  summary()
  factor     AME     SE       z      p   lower   upper
 CIG_REC  0.0066 0.0028  2.3303 0.0198  0.0011  0.0122
    PWGT  0.0001 0.0000  3.1037 0.0019  0.0000  0.0001
 UPREVIS -0.0013 0.0002 -5.0975 0.0000 -0.0017 -0.0008

Interpreting AMEs from margins() (Outcome: atRisk = 1)

AME = average change in \(\Pr(\text{atRisk}=1)\) when the predictor increases by 1 unit
(holding other variables at their observed values, then averaging across the sample).

1) CIG_REC (smoker indicator: 1 = smoker, 0 = non-smoker)

  • AME = 0.0066 (p = 0.0198; 95% CI [0.0011, 0.0122])
  • Interpretation: on average, changing CIG_REC from 0 → 1 (non-smoker → smoker) increases \(\Pr(\text{atRisk}=1)\) by 0.0066.
  • In percentage points: \(0.0066 \times 100 = 0.66\), i.e., +0.66 pp (95% CI: +0.11 to +1.22 pp)

Plain-English: babies of smokers have an estimated 0.66 percentage point higher probability of being atRisk, on average, compared to babies of non-smokers (with the other covariates held at their observed values and the effect averaged across the sample).

2) PWGT = Prepregnancy weight (numeric)

  • AME = 0.0001 (p = 0.0019; 95% CI [0.0000, 0.0001])
  • Interpretation: a 1-unit increase in prepregnancy weight changes \(\Pr(\text{atRisk}=1)\) by +0.0001 on average.
  • In percentage points: +0.01 pp per 1 unit

Rescale (PWGT is in pounds):

  • +10 lb: \(10 \times 0.0001 = 0.0010\), i.e., +0.10 pp
  • +50 lb: \(50 \times 0.0001 = 0.0050\), i.e., +0.50 pp
  • +100 lb: \(100 \times 0.0001 = 0.0100\), i.e., +1.00 pp

3) UPREVIS = Prenatal visits (integer)

  • AME = -0.0013 (p < 0.001; 95% CI [-0.0017, -0.0008])
  • Interpretation: each additional prenatal visit changes \(\Pr(\text{atRisk}=1)\) by −0.0013 on average.
  • In percentage points: \(-0.0013 \times 100 = -0.13\), i.e., −0.13 pp per visit

Rescale:

  • +5 visits: −0.65 pp
  • +10 visits: −1.30 pp
  • +15 visits: −1.95 pp


Marginal Effect Plots

# Tidy the margins summary: one row per term with its AME and standard error.
df_ame <- sum_me |>
  as_tibble() |>
  rename(
    term = factor,
    ame  = AME,
    se   = SE
  )

# Each CI label is stored next to its own two-sided normal critical value,
# computed with qnorm() instead of hard-coded rounded constants.
ci_levels <- tibble(
  level = c("90%", "95%", "99%"),
  zcrit = qnorm(1 - (1 - c(0.90, 0.95, 0.99)) / 2)
)

# Repeat every term once per CI level. Crossing with the rows of `ci_levels`
# (not with `level` and `zcrit` separately) keeps each label paired with its
# own critical value, so no follow-up join is needed.
df_ame_ci <- df_ame |>
  tidyr::crossing(ci_levels) |>
  mutate(
    conf.low  = ame - zcrit * se,
    conf.high = ame + zcrit * se
  )

# Draw a horizontal point-range plot of average marginal effects, one range
# per CI level, with terms ordered by their AME and a dashed line at zero.
#
# `data` must contain the columns term, ame, conf.low, conf.high and level.
# `title` is the plot title. Returns the ggplot object.
plot_ame_ci <- function(data, title = "Average Marginal Effects") {
  ggplot(
    data,
    aes(x = fct_reorder(term, ame), y = ame)
  ) +
    geom_hline(yintercept = 0, linetype = "dashed") +
    geom_pointrange(
      aes(ymin = conf.low, ymax = conf.high, color = level),
      position = position_dodge(width = 0.6)
    ) +
    coord_flip() +
    labs(
      x = NULL,
      y = "Average marginal effect (AME)",
      color = "CI level",
      title = title
    )
}

# Every term.
plot_ame_ci(
  df_ame_ci,
  title = "Average Marginal Effects (90/95/99% CIs)"
)

# The two numeric predictors on their own plot.
plot_ame_ci(
  df_ame_ci |>
    filter(term %in% c("PWGT", "UPREVIS"))
)

# The smoking indicator alone.
plot_ame_ci(
  df_ame_ci |>
    filter(term %in% c("CIG_REC"))
)

Classification

Double density plot (threshold intuition)

# Probability cutoff marked on the plot.
threshold <- 0.02

# Alternative kept for reference: score the test set instead.
# dtest <- dtest |>
#   mutate(.fitted = predict(model,
#                            newdata = dtest,
#                            type = "response"))

# augment() is called without `newdata`, so the densities below describe the
# data the model was fitted on, split by the actual outcome.
ggplot(
  data = augment(model, type.predict = "response"),
  mapping = aes(x = .fitted, fill = factor(atRisk))
) +
  geom_density(alpha = 0.35) +
  geom_vline(xintercept = threshold, linetype = "dashed") +
  labs(
    title = "Training set: predicted probabilities by actual class",
    x = "Predicted probability",
    y = "Density",
    fill = "Actual class"
  )

Confusion Matrix

# Probability cutoff: predicted probabilities above it are labelled "at-risk".
threshold <- 0.02

# Score the held-out test set and attach actual / predicted class factors
# that share the same level order ("not at-risk" first).
df_cm <- model |>
  augment(newdata = dtest,
          type.predict = "response") |>
  mutate(
    actual = factor(atRisk,
                    levels = c(0, 1),
                    labels = c("not at-risk", "at-risk")),
    pred   = factor(if_else(.fitted > threshold, 1, 0),
                    levels = c(0, 1),
                    labels = c("not at-risk", "at-risk"))
  )

# The result is stored as `conf_mat_test`, matching the `conf_mat_moreRisk` /
# `conf_mat_lessRisk` objects built later, so that no variable named
# `conf_mat` masks the yardstick function of the same name.
conf_mat_test <- conf_mat(df_cm,
                          truth = actual,
                          estimate = pred)

conf_mat_test
             Truth
Prediction    not at-risk at-risk
  not at-risk        6107      69
  at-risk            1480      73
# Tidy counts table: one row per actual class, one column per predicted class.
# count() returns an ungrouped result, so no group_by()/ungroup() pair is
# needed and no grouping message is emitted.
cm_counts <- df_cm |>
  count(actual, pred) |>
  pivot_wider(names_from = pred,
              values_from = n,
              values_fill = 0)

Manual metrics (accuracy, precision, recall, specificity)

# Confusion-matrix cells counted directly from the labelled test predictions.
TP <- with(df_cm, sum(actual == "at-risk" & pred == "at-risk"))
TN <- with(df_cm, sum(actual == "not at-risk" & pred == "not at-risk"))
FP <- with(df_cm, sum(actual == "not at-risk" & pred == "at-risk"))
FN <- with(df_cm, sum(actual == "at-risk" & pred == "not at-risk"))

# Rates derived from the four cells.
accuracy    <- (TP + TN) / (TP + TN + FP + FN)
precision   <- TP / (TP + FP)
recall      <- TP / (TP + FN)   # a.k.a. sensitivity
specificity <- TN / (TN + FP)

# Share of at-risk cases in the test set, and precision relative to it.
base_rate  <- mean(dtest$atRisk)
enrichment <- precision / base_rate

# Collect the metrics into a two-column table.
df_class_metric <- data.frame(
  metric = c(
    "Accuracy",
    "Precision",
    "Recall (Sensitivity)",
    "Specificity",
    "Base rate",
    "Enrichment"
  ),
  value = c(
    accuracy,
    precision,
    recall,
    specificity,
    base_rate,
    enrichment
  )
)

datatable(df_class_metric)

Precision/Recall/Enrichment Curves over Thresholds (using WVPlots package)

# Threshold curves for the test-set scores, limited to thresholds in [0, 0.1].
plt <- PRTPlot(
  frame = df_cm,
  predVar = ".fitted",
  truthVar = "atRisk",
  truthTarget = 1,
  title = "Enrichment vs. recall with threshold for natality model",
  plotvars = c("enrichment", "precision", "recall", "specificity", "false_positive_rate"),
  thresholdrange = c(0, .1)
)

# Mark the working threshold on the curves.
plt +
  geom_vline(
    xintercept = threshold,
    linetype = 2,
    color = "maroon"
  )

Precision/Recall/Enrichment Curves over Thresholds (Manual)

# True labels and predicted probabilities for the test set.
scores <- df_cm |>
  transmute(
    y_true  = atRisk,
    y_score = .fitted
  )

base_rate <- mean(scores$y_true)
threshold_grid <- seq(0, 1, by = 0.001)

# Compute precision, recall, specificity and enrichment at one cutoff.
#
# All intermediate values live inside the function, so sweeping the grid does
# not overwrite the global TP/FP/FN/TN, precision, recall, specificity or
# enrichment values computed at the working threshold (later code still reads
# `specificity`).
#
# cutoff     scalar probability threshold; scores strictly above it are
#            predicted positive
# y_true     0/1 vector of actual outcomes
# y_score    predicted probabilities, same length as y_true
# base_rate  share of positives, used as the enrichment denominator
#
# Returns a one-row tibble; a metric is NA when its denominator is zero.
metrics_at_cutoff <- function(cutoff, y_true, y_score, base_rate) {
  flagged <- y_score > cutoff

  tp <- sum(y_true == 1 & flagged)
  fp <- sum(y_true == 0 & flagged)
  fn <- sum(y_true == 1 & !flagged)
  tn <- sum(y_true == 0 & !flagged)

  precision_c   <- if (tp + fp == 0) NA_real_ else tp / (tp + fp)
  recall_c      <- if (tp + fn == 0) NA_real_ else tp / (tp + fn)
  specificity_c <- if (tn + fp == 0) NA_real_ else tn / (tn + fp)

  tibble(
    threshold   = cutoff,
    precision   = precision_c,
    recall      = recall_c,
    specificity = specificity_c,
    enrichment  = precision_c / base_rate
  )
}

# One row of metrics per threshold on the grid.
curve_df <- threshold_grid |>
  map(\(cutoff) metrics_at_cutoff(
    cutoff,
    y_true = scores$y_true,
    y_score = scores$y_score,
    base_rate = base_rate
  )) |>
  bind_rows()

# `threshold` in geom_vline() is outside aes(), so it refers to the global
# working threshold, not to the `threshold` column of curve_df.
curve_df |>
  pivot_longer(cols = c(precision, recall, specificity, enrichment),
               names_to = "metric", values_to = "value") |>
  filter(metric %in% c("enrichment", "recall")) |>
  ggplot(aes(x = threshold, y = value, color = metric)) +
  geom_line(linewidth = 0.8) +
  geom_vline(xintercept = threshold, color = "maroon", linetype = 2) +
  labs(
    title = "Enrichment-Recall Plot",
    x = "Threshold",
    y = "Metric value"
  ) +
  scale_x_continuous(limits = c(0, 0.1)) +
  scale_y_continuous(limits = c(0, 6))

ROC and AUC (using WVPlots package)

# ROC curve for the test-set scores.
roc <- ROCPlot(df_cm,
               xvar = ".fitted",
               truthVar = "atRisk",
               truthTarget = 1,
               title = "Classifier performance")

# False positive rate at the working threshold, computed directly from the
# scored test set: the share of actual negatives whose score exceeds
# `threshold`. The global `specificity` is deliberately not used here because
# the threshold sweep in the manual-curves section reassigns it; after its
# final iteration (threshold = 1) it equals 1, which would place the marker
# at x = 0.
fpr_at_threshold <- mean(df_cm$.fitted[df_cm$atRisk == 0] > threshold)

# ROC with a vertical line at the working threshold's false positive rate
roc +
  geom_vline(xintercept = fpr_at_threshold,
             color = "maroon", linetype = 2)

# AUC, read from the `auc` value stored in the plot's environment.
# NOTE(review): this relies on an internal of the ROCPlot() result — confirm
# it is stable across WVPlots versions.
roc$plot_env$auc
[1] 0.6914125

Classifier Performance Can Shift with Base Rates

Performance of Classifier (NY → MA thought experiment)

  • Suppose you trained a classifier on NY hospital data with acceptable precision/recall.
  • Now you apply the same classifier to MA hospitals.
    • Will it perform as well?
  • Even if the relationship between features and risk is similar, the base rate (share of at-risk babies) in MA may differ.
    • This can change precision a lot, even when recall stays similar.

Create MA-like test sets with different at-risk rates (R)

set.seed(23464)

# Pick 1,000 distinct row positions of the test set to pull out.
sample_indices <- sample.int(nrow(dtest), size = 1000, replace = FALSE)

# The pulled-out rows, and the remainder that plays the role of "NY hospitals".
separated <- dtest |> slice(sample_indices)
dtest_NY  <- dtest |> slice(-sample_indices)

# Partition the pulled-out rows by outcome.
at_risk_sample     <- filter(separated, atRisk == 1)
not_at_risk_sample <- filter(separated, atRisk == 0)

# Two "MA" test sets: NY plus only the at-risk rows (higher prevalence), and
# NY plus only the not-at-risk rows (lower prevalence).
dtest_MA_moreRisk <- bind_rows(dtest_NY, at_risk_sample)
dtest_MA_lessRisk <- bind_rows(dtest_NY, not_at_risk_sample)

# Row counts and at-risk shares of every data set, one row each.
tribble(
  ~dataset,         ~n,                      ~at_risk_rate,
  "Original test",  nrow(dtest),             mean(dtest$atRisk),
  "Separated",      nrow(separated),         mean(separated$atRisk),
  "NY hospitals",   nrow(dtest_NY),          mean(dtest_NY$atRisk),
  "MA (more risk)", nrow(dtest_MA_moreRisk), mean(dtest_MA_moreRisk$atRisk),
  "MA (less risk)", nrow(dtest_MA_lessRisk), mean(dtest_MA_lessRisk$atRisk)
)
# A tibble: 5 × 3
  dataset            n at_risk_rate
  <chr>          <int>        <dbl>
1 Original test   7729       0.0184
2 Separated       1000       0.0320
3 NY hospitals    6729       0.0163
4 MA (more risk)  6761       0.0210
5 MA (less risk)  7697       0.0143

Evaluate the same classifier on each test set (fixed threshold)

# Score the higher-prevalence test set with the same model and threshold.
df_cm_MA_moreRisk <- augment(
  model,
  newdata = dtest_MA_moreRisk,
  type.predict = "response"
) |>
  mutate(
    actual = factor(
      atRisk,
      levels = c(0, 1),
      labels = c("not at-risk", "at-risk")
    ),
    pred = factor(
      if_else(.fitted > threshold, 1, 0),
      levels = c(0, 1),
      labels = c("not at-risk", "at-risk")
    )
  )

# Cross-tabulate predicted vs. actual classes.
conf_mat_moreRisk <- df_cm_MA_moreRisk |>
  conf_mat(truth = actual, estimate = pred)

conf_mat_moreRisk
             Truth
Prediction    not at-risk at-risk
  not at-risk        5318      69
  at-risk            1301      73
# Confusion-matrix cells for the higher-prevalence set, counted by filtering.
TP <- df_cm_MA_moreRisk |>
  filter(actual == "at-risk", pred == "at-risk") |>
  nrow()
TN <- df_cm_MA_moreRisk |>
  filter(actual == "not at-risk", pred == "not at-risk") |>
  nrow()
FP <- df_cm_MA_moreRisk |>
  filter(actual == "not at-risk", pred == "at-risk") |>
  nrow()
FN <- df_cm_MA_moreRisk |>
  filter(actual == "at-risk", pred == "not at-risk") |>
  nrow()

# Rates derived from the four cells.
accuracy    <- (TP + TN) / (TP + TN + FP + FN)
precision   <- TP / (TP + FP)
recall      <- TP / (TP + FN)   # a.k.a. sensitivity
specificity <- TN / (TN + FP)

# Prevalence in this test set, and precision relative to it.
base_rate  <- mean(dtest_MA_moreRisk$atRisk)
enrichment <- precision / base_rate

df_class_metric_MA_moreRisk <- data.frame(
  metric = c(
    "Accuracy",
    "Precision",
    "Recall (Sensitivity)",
    "Specificity",
    "Base rate",
    "Enrichment"
  ),
  value = c(
    accuracy,
    precision,
    recall,
    specificity,
    base_rate,
    enrichment
  )
)

rmarkdown::paged_table(df_class_metric_MA_moreRisk)
# Score the lower-prevalence test set with the same model and threshold.
df_cm_MA_lessRisk <- augment(
  model,
  newdata = dtest_MA_lessRisk,
  type.predict = "response"
) |>
  mutate(
    actual = factor(
      atRisk,
      levels = c(0, 1),
      labels = c("not at-risk", "at-risk")
    ),
    pred = factor(
      if_else(.fitted > threshold, 1, 0),
      levels = c(0, 1),
      labels = c("not at-risk", "at-risk")
    )
  )

# Cross-tabulate predicted vs. actual classes.
conf_mat_lessRisk <- df_cm_MA_lessRisk |>
  conf_mat(truth = actual, estimate = pred)

conf_mat_lessRisk
             Truth
Prediction    not at-risk at-risk
  not at-risk        6107      54
  at-risk            1480      56
# Confusion-matrix cells for the lower-prevalence set.
TP <- with(df_cm_MA_lessRisk, sum(actual == "at-risk" & pred == "at-risk"))
TN <- with(df_cm_MA_lessRisk, sum(actual == "not at-risk" & pred == "not at-risk"))
FP <- with(df_cm_MA_lessRisk, sum(actual == "not at-risk" & pred == "at-risk"))
FN <- with(df_cm_MA_lessRisk, sum(actual == "at-risk" & pred == "not at-risk"))

# Rates derived from the four cells.
accuracy    <- (TP + TN) / (TP + TN + FP + FN)
precision   <- TP / (TP + FP)
recall      <- TP / (TP + FN)   # a.k.a. sensitivity
specificity <- TN / (TN + FP)

# Prevalence in this test set, and precision relative to it.
base_rate  <- mean(dtest_MA_lessRisk$atRisk)
enrichment <- precision / base_rate

df_class_metric_MA_lessRisk <- data.frame(
  metric = c(
    "Accuracy",
    "Precision",
    "Recall (Sensitivity)",
    "Specificity",
    "Base rate",
    "Enrichment"
  ),
  value = c(
    accuracy,
    precision,
    recall,
    specificity,
    base_rate,
    enrichment
  )
)
# Show the three metric tables in turn: original test set, higher-prevalence
# set, lower-prevalence set.
rmarkdown::paged_table(df_class_metric)
rmarkdown::paged_table(df_class_metric_MA_moreRisk)
rmarkdown::paged_table(df_class_metric_MA_lessRisk)
Back to top