Logistic Regression

A model for binary classification

class date

April 26, 2023


The framework that we used to build a predictive model for regression followed four distinct steps:

  1. Decide on the mathematical form of the model: we used a linear model with potential for transformations and polynomials
  2. Select a metric that defines the “best” fit: we used the residual sum of squares (RSS)
  3. Estimate the coefficients of the model that are best using the training data: we found the coefficients that minimized RSS on the training data
  4. Evaluate predictive accuracy using a test data set: we calculated testing \(R^2\) using ~20% of the data that was withheld
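As a reminder of how the four steps above fit together, here is a minimal sketch in R. It uses simulated data rather than the graduation data, so the variable names (`x`, `y`, `dat`) and the true coefficients are assumptions made purely for illustration, and the testing \(R^2\) is computed with one common definition.

```r
# Simulated data standing in for a real data set (an assumption for illustration)
set.seed(20)
n <- 200
x <- runif(n, 0, 10)
y <- 3 + 2 * x + rnorm(n, sd = 2)
dat <- data.frame(x = x, y = y)

# Withhold roughly 20% of the rows as a test set
test_rows <- sample(n, size = round(0.2 * n))
train <- dat[-test_rows, ]
test  <- dat[test_rows, ]

# Steps 1-3: choose a linear form and fit it by minimizing RSS on the training data
fit <- lm(y ~ x, data = train)
rss_train <- sum(residuals(fit)^2)

# Step 4: testing R^2, computed only on the withheld rows
y_hat_test <- predict(fit, newdata = test)
r2_test <- 1 - sum((test$y - y_hat_test)^2) / sum((test$y - mean(test$y))^2)
r2_test
```

The same skeleton carries over to classification; what changes is the mathematical form in step 1 and the metrics used in steps 2 and 4.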

We used this process to build our first simple linear regression model to predict high school graduation rates, but the same four steps are used by Zillow to build Zestimate™, their deep learning model to predict house prices.

These are also the same four steps that we will use as we shift in these notes to the task of classification. In a classification task, we seek to predict a response variable that is categorical, very often a two-level categorical variable. Classification models are everywhere: they help determine whether or not a patient has a disease, whether or not an image contains a particular object (say, a person or a cat), and whether or not a customer will purchase an item.

The last time you checked your email inbox, you observed the results of yet another very useful classification model: a spam filter. There are many different types of classification models that could be used as a spam filter, but in this class we’ll focus on one of the most common, called logistic regression.

Example: Spam filters

Email spam, also referred to as junk email or simply spam, consists of unsolicited messages sent in bulk by email (spamming). The name comes from a Monty Python sketch in which the name of the canned pork product Spam is ubiquitous, unavoidable, and repetitive[1]. A spam filter is a classification model that determines whether or not a message is spam based on properties of that message. Every mainstream email client, including Gmail, has a built-in spam filter to ensure that the messages that get through to the user are genuine.

The Data

In order to build a spam filter, we’ll need a data set to train our model upon. This data set must have as the unit of observation a single email message, record whether or not the message was spam (the response), and record various features of the message that are associated with being spam (the predictors). Such a data set, with nearly 4000 observations, can be found in the email data frame in the openintro library.

# A tibble: 3,921 × 21
   spam  to_multiple from     cc sent_…¹ time                image attach dollar
   <fct> <fct>       <fct> <int> <fct>   <dttm>              <dbl>  <dbl>  <dbl>
 1 0     0           1         0 0       2011-12-31 22:16:41     0      0      0
 2 0     0           1         0 0       2011-12-31 23:03:59     0      0      0
 3 0     0           1         0 0       2012-01-01 08:00:32     0      0      4
 4 0     0           1         0 0       2012-01-01 01:09:49     0      0      0
 5 0     0           1         0 0       2012-01-01 02:00:01     0      0      0
 6 0     0           1         0 0       2012-01-01 02:04:46     0      0      0
 7 0     1           1         0 1       2012-01-01 09:55:06     0      0      0
 8 0     1           1         1 1       2012-01-01 10:45:21     1      1      0
 9 0     0           1         0 0       2012-01-01 13:08:59     0      0      0
10 0     0           1         0 0       2012-01-01 10:12:00     0      0      0
# … with 3,911 more rows, 12 more variables: winner <fct>, inherit <dbl>,
#   viagra <dbl>, password <dbl>, num_char <dbl>, line_breaks <int>,
#   format <fct>, re_subj <fct>, exclaim_subj <dbl>, urgent_subj <fct>,
#   exclaim_mess <dbl>, number <fct>, and abbreviated variable name ¹​sent_email

We see in the leftmost column the response variable, spam, coded as 0 when a message is not spam and 1 when it is. The first predictor, to_multiple, is a 1 if the message was sent to multiple people and 0 otherwise. cc records the number of people who are cc’ed on the email. winner (listed as a variable below the glimpse of the data frame) records whether or not the word “winner” showed up in the message. You may be able to intuit the remaining predictors from their names, but you can also read the help file that describes each one.

These variables seem like they might be useful in predicting whether or not an email is spam, but take a moment to consider: how were we able to get our hands on data like this?

This particular data set arose from selecting a single email account, saving every single message that came in to that address over the course of a month, processing each message to create values for the predictor variables, then visually classifying whether or not the message is spam. That’s to say: this data represents a human’s best effort to classify each message as spam or not[2]. Can we build a model that identifies the features that mark a message as spam so that we can automatically classify future messages?

Exploratory Data Analysis

Let’s see how well a few tempting predictors work at separating spam from not spam by performing some exploratory data analysis. I would expect messages containing the word “winner” to be more likely to be spam than those that do not. A stacked bar chart can answer that question.

library(tidyverse)
library(openintro)
ggplot(email, aes(x = winner, fill = spam)) +
  geom_bar(position = "fill")

Indeed, it looks like around 30% of emails with “winner” were spam, compared to roughly 10% of those without. At this point, we could consider a very simple spam filter: if the message contains “winner”, then classify it as spam.

Although this is tempting, it is still a pretty weak classifier. Most of the messages with “winner” are not spam, so calling them spam will result in most of them being misclassified.

So if “winner” isn’t the silver bullet predictor we need, let’s try another: num_char. This variable records the total number of characters present in the message – how long it is. I have no prior sense of whether spam would be more likely to consist of short or long emails, so let’s visualize them. This predictor is continuous, so let’s overlay two density plots to get a sense of the distribution between spam and not spam.

p_numchar <- ggplot(email, aes(x = num_char, fill = spam)) +
  geom_density(alpha = .4)

p_log_numchar <- ggplot(email, aes(x = log(num_char), fill = spam)) +
  geom_density(alpha = .4)

# Placing two plots side by side with `+` requires the patchwork package
library(patchwork)
p_numchar + p_log_numchar

The original plot on the left is very difficult to read because this variable is heavily right-skewed: there are a small number of very long messages that obscure much of the data in the plot. On the right is a more useful visualization: a plot of the log-transformed version of the same variable.

We see a reasonable separation here: spam messages tend to be shorter than non-spam messages. We could consider another very simple spam filter: if the log number of characters is less than 1.2, classify it as spam.

This simple filter suffers from the same problem as the first. Although these density plots have the same area, there are in fact far fewer overall instances of spam than not-spam. That means that there are far more not-spam messages with a log number of characters less than 1.2 than there are spam messages. This filter, like the first, would misclassify much of the training data.
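To see why equal-area densities can mislead when one class is much rarer, here is a small simulation. It does not use the email data: the group sizes, means, and standard deviations below are hypothetical values chosen only to mimic a rare "spam" group whose messages tend to be shorter.

```r
# Hypothetical, simulated data (not the email data set)
set.seed(20)
n_spam <- 400    # rare class
n_not  <- 3600   # common class

# Simulated log message lengths: spam tends to be shorter
log_len <- c(rnorm(n_spam, mean = 0.5, sd = 1.3),
             rnorm(n_not,  mean = 1.8, sd = 1.3))
is_spam <- rep(c(1, 0), times = c(n_spam, n_not))

# The simple rule: flag a message as spam if its log length is below 1.2
flagged <- log_len < 1.2

# Counts of flagged vs. unflagged messages within each class
table(flagged, is_spam)

# Proportion of flagged messages that really are spam
mean(is_spam[flagged])
```

Even though most of the simulated spam falls below the cutoff, the flagged group is still dominated by not-spam simply because there is so much more of it.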

What we need is a more general framework to fold the strength of these two variables - as well as many of the other ones in the data set - into a model that can produce a single, more accurate prediction. The model framework that we’ll use is logistic regression.

From Linear to Logistic Regression

Let’s start by taking our existing simple linear regression model from the past few notes and applying it in this classification setting. We can visualize the relationship between log_num_char and spam using an ordinary scatter plot.

This is a strange looking scatter plot - the dots can only take y values of 0 or 1 - but it does capture the overall trend observed in the density plots, that spam messages are shorter. Since we have a scatter plot, we can fit a simple linear regression model using the method of least squares (in gold).

\[ \hat{y} = b_0 + b_1 x \]

While this is doable, it leaves us with a bit of a conundrum. For a low value of log_num_char it’s possible that we would predict \(\hat{y} = 1\) and for a high value it’s possible that we’d predict \(\hat{y} = 0\). But what if log_num_char is somewhere in the middle? What does it mean if we predict that the value of spam is .71?

One approach to resolving this is to treat our prediction not as a value of \(y\), but as an estimate of the probability, \(\hat{p}\), that \(y = 1\). We can rewrite the model as:

\[ \hat{p} = b_0 + b_1 x \]

This resolves the conundrum of how to think about a prediction of .71. That is now the model’s determination of the probability that the message is spam. This tweak solves one problem, but it introduces another. How do we interpret predictions at very high values of log_num_char, where \(\hat{p}\) is negative? Surely a probability cannot be negative!
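The problem is easy to reproduce. The sketch below fits a least squares line to a simulated 0-1 response; the data are not the email data, and the sample size and "true" coefficients are assumptions chosen for illustration, with `x` playing the role of log_num_char.

```r
# Simulated 0-1 response whose probability of being 1 falls as x grows
set.seed(20)
n <- 2000
x <- rnorm(n, mean = 1.5, sd = 1.5)           # stand-in for log_num_char
p_true <- 1 / (1 + exp(-(-1.7 - 0.5 * x)))    # hypothetical true probabilities
y <- rbinom(n, size = 1, prob = p_true)

# Fit a straight line to the 0-1 response by least squares
line_fit <- lm(y ~ x)

# Read the fitted values as probabilities at a few values of x
new_x <- data.frame(x = c(-2, 0, 2, 6, 12))
p_hat_line <- predict(line_fit, newdata = new_x)
cbind(new_x, p_hat_line)
```

Because a line keeps falling at a constant rate, its "probabilities" eventually drop below zero once x is large enough.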

We can fix this by changing the mathematical form of the model used to predict the response. Instead of a line, we can use an alternative function that prevents predictions greater than 1 or less than 0. The most commonly used function is called the standard logistic function:

\[ f(z) = \frac{1}{1 + e^{-z}}\]

\(z\) can be any number on the real number line. As \(z\) gets large, \(f(z)\) approaches 1; as \(z\) becomes very negative, \(f(z)\) approaches 0; when \(z\) is 0, \(f(z)\) is .5.
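These properties are quick to check numerically. In the sketch below, `logistic()` is our own helper name (an assumption, not a built-in); base R provides the same function as `plogis()`.

```r
# The standard logistic function, written out from its formula
logistic <- function(z) 1 / (1 + exp(-z))

# Very negative z gives values near 0, z = 0 gives .5, large z gives values near 1
logistic(c(-10, -2, 0, 2, 10))

# Base R's plogis() computes the same function
all.equal(logistic(c(-2, 0, 2)), plogis(c(-2, 0, 2)))
```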

This is a very clever idea. It allows us to combine all of the information from our predictors into a single numerical score, \(z\), which can then be sent through the logistic function to estimate the probability that \(y = 1\). This method is called logistic regression.

Logistic Regression (for prediction)

A model to predict the probability that a 0-1 response variable \(y\) is 1 using the inverse logit of a linear combination of predictors \(x_1, x_2, \ldots, x_p\).

\[ \hat{p} = \frac{1}{1 + e^{-\left(b_0 + b_1 x_1 + \ldots + b_p x_p\right)}} \]

Can be used as a classification model by setting up a rule of the form: if \(\hat{p}_i\) > threshold, then \(\hat{y}_i = 1\).
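To make the definition concrete, the sketch below computes \(\hat{p}\) and \(\hat{y}\) by hand for three observations and two predictors. The coefficients in `b` and the predictor values in `X` are hypothetical numbers invented for illustration; they are not estimated from any data.

```r
# Hypothetical coefficients b0, b1, b2 (not estimated from any data)
b <- c(-1.0, 0.8, -0.5)

# Three hypothetical observations with two predictors, x1 and x2
X <- cbind(x1 = c(0, 2, 4), x2 = c(1, 0, 3))

# The linear combination of predictors: a single numerical score per observation
z <- as.vector(b[1] + X %*% b[-1])

# The inverse logit turns each score into an estimated probability
p_hat <- 1 / (1 + exp(-z))

# Classification rule with a threshold of .5
y_hat <- as.integer(p_hat > 0.5)

data.frame(z, p_hat, y_hat)
```

Changing the threshold changes which observations are classified as 1 without changing the estimated probabilities themselves.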

We can visualize the approach that logistic regression takes by sketching the predictions as a green s-shaped curve on top of our scatter plot.

Fitting Logistic Regression

The computational machinery for fitting logistic regression looks almost identical to what we used for linear least squares regression. The primary function we’ll use is glm().

m1 <- glm(spam ~ log(num_char), data = email, family = "binomial")

Call:  glm(formula = spam ~ log(num_char), family = "binomial", data = email)

  (Intercept)  log(num_char)  
      -1.7244        -0.5435  

Degrees of Freedom: 3920 Total (i.e. Null);  3919 Residual
Null Deviance:      2437 
Residual Deviance: 2190     AIC: 2194

These coefficients are a bit more challenging to interpret since they’re no longer linearly related to the response. The sign of the coefficient for log(num_char), however, is informative: it tells us that messages with more characters will be predicted to have a lower probability of being spam.
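One way to get a feel for these coefficients is to plug them into the logistic function at a few values of log(num_char). The sketch below copies the two estimates from the output above; the grid of log(num_char) values is an arbitrary choice for illustration.

```r
# Coefficient estimates copied from the glm() output above
b0 <- -1.7244
b1 <- -0.5435

# An arbitrary grid of values for log(num_char)
log_num_char <- c(-1, 0, 1, 2, 3)

# The linear part of the model (the log-odds) changes by b1 at every step
log_odds <- b0 + b1 * log_num_char

# The predicted probability changes by a different amount at each step
p_hat <- plogis(log_odds)

data.frame(log_num_char, log_odds, p_hat, change_in_p = c(NA, diff(p_hat)))
```

Each one-unit increase in log(num_char) lowers the predicted probability of spam, but by a shrinking amount, which is why the coefficient cannot be read as a constant change in probability.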

Let’s take a look at the predictions that this model makes on the data set that it was trained on (also called the fitted values). When using the predict() function on logistic regression models, there are several different types of predictions that it can return, so be sure to specify type = "response".

p_hat_m1 <- predict(m1, email, type = "response")

We can then move from values of \(\hat{p}_i\) to values of \(\hat{y}_i\) by checking whether each value of \(\hat{p}_i\) is greater than .5. Let’s combine these values into one data frame.

predictions <- tibble(y = email$spam) %>%
  mutate(p_hat_m1 = p_hat_m1,
         y_hat_m1 = p_hat_m1 > .5)
# A tibble: 3,921 × 3
   y     p_hat_m1 y_hat_m1
   <fct>    <dbl> <lgl>   
 1 0       0.0454 FALSE   
 2 0       0.0473 FALSE   
 3 0       0.0553 FALSE   
 4 0       0.0419 FALSE   
 5 0       0.137  FALSE   
 6 0       0.145  FALSE   
 7 0       0.0704 FALSE   
 8 0       0.0566 FALSE   
 9 0       0.0886 FALSE   
10 0       0.0951 FALSE   
# … with 3,911 more rows

These first 10 rows were all examples of not-spam (\(y = 0\)), and our model’s corresponding estimates of the probability that each one is spam, based on the value of log_num_char, are quite low, ranging from .04 to .14. Since these are all less than .5, every prediction (or fitted value) was 0, shown here as FALSE.

How well did the model do when it predicted very high values of \(\hat{p}_i\)? Let’s take a look by sorting the data frame in descending order.

predictions %>%
  arrange(desc(p_hat_m1))
# A tibble: 3,921 × 3
   y     p_hat_m1 y_hat_m1
   <fct>    <dbl> <lgl>   
 1 1        0.884 TRUE    
 2 1        0.884 TRUE    
 3 0        0.807 TRUE    
 4 0        0.807 TRUE    
 5 0        0.782 TRUE    
 6 0        0.782 TRUE    
 7 0        0.742 TRUE    
 8 0        0.726 TRUE    
 9 0        0.711 TRUE    
10 1        0.636 TRUE    
# … with 3,911 more rows

The model assigned two messages a probability of .884 of being spam, and indeed they were. But then the model made several errors. Messages 3 through 9 were all also predicted to be spam but in fact they were not.

It seems, on the surface, that this simple logistic regression model, though it is functional, is not very accurate. We’ll aim to increase its accuracy tomorrow in class.


In these notes we introduced the concept of classification using a logistic regression model. Logistic regression uses the logistic function to transform predictions into a probability that the response is 1. These probabilities can be used to classify y as 0 or 1 by checking whether they exceed a threshold (often .5).

These notes focus on step 1 of the predictive model building process: deciding on the mathematical form of the model. In class tomorrow we will take up the remaining three as we learn how to apply logistic regression as a classification model to complete the eighth and final lab of Stat 20.


  1. Definition from Wikipedia, along with the image, by freezelight/flickr.

  2. This data collection and manual processing was done by a graduate student in statistics at UCLA. One of the many humdrum but valuable tasks asked of graduate students . . .