Introduction: Finding the Best Line
Imagine you're a real estate agent trying to estimate house prices. You notice a pattern: bigger houses tend to cost more. If you plot square footage vs. price, the data roughly follows a straight line. Linear regression is about finding the best line that fits this relationship.
Key Insight: Linear regression assumes the relationship between the inputs (features) and the output is approximately linear. It is one of the simplest and most interpretable ML algorithms, and surprisingly powerful!
In this lesson, we'll explore linear regression from multiple angles:
- Geometric intuition (fitting a line)
- Algebraic approach (normal equation)
- Optimization approach (gradient descent)
- Probabilistic interpretation (maximum likelihood)
Learning Objectives
- Understand the linear regression model
- Derive the cost function (MSE)
- Solve using the normal equation
- Implement gradient descent from scratch
- Interpret coefficients and make predictions
- Recognize when linear regression is appropriate
1. The Linear Model
Simple Linear Regression (One Feature)
The simplest case: predict (y) from a single feature (x):
[ \hat{y} = w_0 + w_1 x ]
- (w_0): intercept (bias) – the predicted value when (x = 0)
- (w_1): slope (weight) – the change in (\hat{y}) for a one-unit increase in (x)
- (\hat{y}): predicted value
Geometric View: This is the equation of a line!
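To make the two parameters concrete, here is a tiny sketch with made-up numbers. Suppose a hypothetical fit of house price (in thousands of dollars) against size (in square feet) gave (w_0 = 50) and (w_1 = 0.2). These values are illustrative assumptions, not the result of fitting real data.

```python
# Hypothetical parameters for illustration only (not fitted to real data):
# price is in thousands of dollars, size is in square feet.
w0 = 50.0   # intercept: predicted price when size is 0
w1 = 0.2    # slope: predicted price change per extra square foot

def predict(x: float) -> float:
    """Simple linear regression prediction: y_hat = w0 + w1 * x."""
    return w0 + w1 * x

print(predict(1500))                  # 50 + 0.2 * 1500
print(predict(1501) - predict(1500))  # approximately w1, the slope
```

Increasing (x) by one unit moves the prediction by exactly the slope (up to floating-point rounding), which is what "change in (\hat{y}) per unit change in (x)" means.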
Multiple Linear Regression (Multiple Features)
Real-world problems have many features. For a house: square footage, bedrooms, age, location, etc.
[ \hat{y} = w_0 + w_1 x_1 + w_2 x_2 + \cdots + w_d x_d = \mathbf{w}^T \mathbf{x} ]
In vector notation (adding bias to (\mathbf{x})): [ \hat{y} = \mathbf{w}^T \mathbf{x} = \begin{bmatrix} w_0 & w_1 & w_2 & \cdots & w_d \end{bmatrix} \begin{bmatrix} 1 \\ x_1 \\ x_2 \\ \vdots \\ x_d \end{bmatrix} ]
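Key: The model is linear in the parameters (w_i), not necessarily in the features! The features themselves may be non-linear transformations of the raw inputs, such as (x^2) or (\log x).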
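A minimal NumPy sketch of the vector form, assuming NumPy is available. Prepending a column of 1s to the features lets one matrix-vector product `X @ w` compute (\hat{y}) for every sample, intercept included. The second half illustrates "linear in the parameters": using (x) and (x^2) as features still gives a model that is linear in (\mathbf{w}). The helper name `add_bias_column` and all numbers are hypothetical.

```python
import numpy as np

def add_bias_column(features: np.ndarray) -> np.ndarray:
    """Prepend a column of 1s so that w[0] acts as the intercept."""
    n = features.shape[0]
    return np.hstack([np.ones((n, 1)), features])

# Two hypothetical samples with d = 2 features each.
features = np.array([[1.0, 2.0],
                     [3.0, 4.0]])
X = add_bias_column(features)   # shape (2, 3): each row is [1, x1, x2]
w = np.array([0.5, 1.0, -2.0])  # [w0, w1, w2]
y_hat = X @ w                   # one prediction per row
print(y_hat)

# Non-linear features, still a linear model: the features are x and x**2,
# and y_hat = w0 + w1*x + w2*x**2 is linear in (w0, w1, w2).
x = np.array([0.0, 1.0, 2.0])
X_poly = add_bias_column(np.column_stack([x, x**2]))
w_poly = np.array([1.0, 0.0, 3.0])
print(X_poly @ w_poly)
```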
2. The Cost Function: Mean Squared Error
Measuring Goodness of Fit
How do we find the "best" line? We need a cost function (loss function) that quantifies how well our line fits the data.
Mean Squared Error (MSE): [ J(\mathbf{w}) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 = \frac{1}{n} \sum_{i=1}^{n} (y_i - \mathbf{w}^T \mathbf{x}_i)^2 ]
Or in matrix form: [ J(\mathbf{w}) = \frac{1}{n} \|\mathbf{y} - \mathbf{X}\mathbf{w}\|^2 ]
Where:
- (\mathbf{X}): (n \times (d+1)) matrix (each row is a sample, first column is all 1s for intercept)
- (\mathbf{y}): (n \times 1) vector of targets
- (\mathbf{w}): ((d+1) \times 1) vector of weights
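The sum form and the matrix form are the same quantity. This sketch (assuming NumPy; the data are made up) computes the MSE both ways on a tiny dataset whose first column is the intercept column of 1s.

```python
import numpy as np

def mse_loop(X: np.ndarray, y: np.ndarray, w: np.ndarray) -> float:
    """Sum form: (1/n) * sum_i (y_i - w^T x_i)^2."""
    n = X.shape[0]
    total = 0.0
    for i in range(n):
        residual = y[i] - w @ X[i]
        total += residual ** 2
    return total / n

def mse_matrix(X: np.ndarray, y: np.ndarray, w: np.ndarray) -> float:
    """Matrix form: (1/n) * ||y - X w||^2."""
    r = y - X @ w
    return float(r @ r) / X.shape[0]

# Hypothetical data: the first column of X is all 1s (intercept).
X = np.array([[1.0, 1.0],
              [1.0, 2.0],
              [1.0, 3.0]])
y = np.array([2.0, 3.0, 5.0])
w = np.array([0.5, 1.5])

print(mse_loop(X, y, w), mse_matrix(X, y, w))
```

The matrix version is what you would use in practice: it is shorter and lets NumPy vectorize the computation.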
3. Solution #1: The Normal Equation (Analytical)
Closed-Form Solution
For linear regression, we can solve for (\mathbf{w}^*) analytically using calculus!
Derivation: Set the gradient of (J(\mathbf{w})) to zero:
[ \nabla_{\mathbf{w}} J = \frac{2}{n} \mathbf{X}^T (\mathbf{X}\mathbf{w} - \mathbf{y}) = \mathbf{0} ]
Solving for (\mathbf{w}):
[ \mathbf{X}^T \mathbf{X} \mathbf{w} = \mathbf{X}^T \mathbf{y} ]
[ \boxed{\mathbf{w}^* = (\mathbf{X}^T \mathbf{X})^{-1} \mathbf{X}^T \mathbf{y}} ]
This is the normal equation! It gives the exact optimal solution in one step.
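A NumPy sketch of the normal equation on synthetic, noise-free data generated from arbitrarily chosen "true" weights, so we can check that it recovers them. Instead of forming the inverse explicitly, the code solves the linear system (\mathbf{X}^T \mathbf{X} \mathbf{w} = \mathbf{X}^T \mathbf{y}) with `np.linalg.solve`. This is mathematically equivalent when (\mathbf{X}^T \mathbf{X}) is invertible and is generally more numerically stable.

```python
import numpy as np

def normal_equation(X: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Solve X^T X w = X^T y for w (X already includes the column of 1s)."""
    return np.linalg.solve(X.T @ X, X.T @ y)

rng = np.random.default_rng(0)
n, d = 100, 3
features = rng.normal(size=(n, d))
X = np.hstack([np.ones((n, 1)), features])  # n x (d+1)
w_true = np.array([4.0, 2.0, -1.0, 0.5])    # arbitrary "true" weights
y = X @ w_true                              # noise-free targets

w_star = normal_equation(X, y)
print(w_star)  # should match w_true up to floating-point error
```

With noisy targets, `w_star` would no longer equal `w_true` exactly, but it would still be the exact minimizer of the MSE on that data.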
Pros and Cons of Normal Equation
Pros:
- ✅ Exact solution (no iterations)
- ✅ No hyperparameters to tune
- ✅ Guaranteed global optimum
Cons:
- ❌ Expensive for many features: forming (\mathbf{X}^T \mathbf{X}) costs (O(nd^2)) and inverting it costs (O(d^3))
- ❌ Doesn't scale to large (d) (millions of features)
- ❌ Requires (\mathbf{X}^T \mathbf{X}) to be invertible
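- ❌ No closed form for models that are non-linear in their parameters (e.g., logistic regression, neural networks)
When to use: Small to medium problems; as a rough rule of thumb, (d < 10{,}000) features.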
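To see the invertibility requirement fail, this sketch (assuming NumPy; the data are synthetic) duplicates a feature. That creates perfect multicollinearity, so (\mathbf{X}^T \mathbf{X}) is singular. One common workaround is to use a least-squares solver that tolerates rank deficiency. `np.linalg.lstsq` returns the minimum-norm least-squares solution, the same one the Moore-Penrose pseudo-inverse (`np.linalg.pinv`) gives.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 50
x = rng.normal(size=n)
# The third column is an exact copy of the second: perfect multicollinearity.
X = np.column_stack([np.ones(n), x, x])
y = 3.0 + 2.0 * x  # noise-free targets

# rank(X^T X) equals rank(X), which is 2 here, so the 3x3 X^T X is singular.
print(np.linalg.matrix_rank(X))

# lstsq returns (solution, residual sums, rank, singular values); for a
# rank-deficient X the solution is the minimum-norm least-squares one.
w_min_norm, _, rank, _ = np.linalg.lstsq(X, y, rcond=None)
print(w_min_norm)  # intercept 3; the slope of 2 is split evenly as 1 and 1
print(rank)
```

Any split of the slope between the two copies fits the data equally well. The minimum-norm choice is just one way to pick a single answer. Regularization, covered in the next lesson, is another.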
4. Solution #2: Gradient Descent (Iterative)
The Optimization Approach
Instead of solving analytically, we can iteratively improve (\mathbf{w}) by moving in the direction that reduces the cost.
Algorithm:
- Start with random weights (\mathbf{w})
- Compute gradient: (\nabla_{\mathbf{w}} J = \frac{2}{n} \mathbf{X}^T (\mathbf{X}\mathbf{w} - \mathbf{y}))
- Update weights: (\mathbf{w} \leftarrow \mathbf{w} - \alpha \nabla_{\mathbf{w}} J)
- Repeat until convergence
Where (\alpha) is the learning rate (step size).
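Before relying on the gradient formula in the update step, it is worth checking it numerically. This sketch (assuming NumPy; random synthetic data) compares (\frac{2}{n} \mathbf{X}^T (\mathbf{X}\mathbf{w} - \mathbf{y})) with central finite differences of the MSE. The helper names are illustrative.

```python
import numpy as np

def mse(X, y, w):
    """J(w) = (1/n) ||y - X w||^2."""
    r = y - X @ w
    return (r @ r) / X.shape[0]

def mse_gradient(X, y, w):
    """Analytic gradient from the text: (2/n) X^T (X w - y)."""
    return (2.0 / X.shape[0]) * X.T @ (X @ w - y)

def numerical_gradient(f, w, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    grad = np.zeros_like(w)
    for j in range(w.size):
        step = np.zeros_like(w)
        step[j] = eps
        grad[j] = (f(w + step) - f(w - step)) / (2 * eps)
    return grad

rng = np.random.default_rng(2)
X = np.hstack([np.ones((20, 1)), rng.normal(size=(20, 2))])
y = rng.normal(size=20)
w = rng.normal(size=3)

analytic = mse_gradient(X, y, w)
numeric = numerical_gradient(lambda v: mse(X, y, v), w)
print(np.max(np.abs(analytic - numeric)))  # tiny: the formula checks out
```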
Visualizing Gradient Descent
Before diving into code, let's build intuition by seeing how gradient descent works! The Gradient Descent Animator shows the optimization process step-by-step on different loss surfaces.
🗺️ Companion 3-D view from outside the platform: ML Visualized's linear-regression trainer shows the fitted line and the parameter-space descent path side by side. Seeing both at once is the moment "the slope of the loss is the update rule" stops being a slogan.
Implementing Gradient Descent from Scratch
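A minimal from-scratch sketch of the algorithm above, assuming NumPy and a design matrix that already includes the column of 1s. It starts from zeros rather than random weights, which works fine for this convex cost. It stops when the gradient norm drops below a tolerance; that stopping rule and the value `alpha=0.5` are choices made for this synthetic data, not universal settings. The result is compared against the normal equation.

```python
import numpy as np

def gradient_descent(X, y, alpha=0.1, n_iters=1000, tol=1e-10):
    """Batch gradient descent on J(w) = (1/n) ||y - X w||^2.

    Returns the final weights and the cost recorded at every iteration.
    """
    n, p = X.shape
    w = np.zeros(p)                  # zero initialization
    costs = []
    for _ in range(n_iters):
        residual = X @ w - y
        costs.append(float(residual @ residual) / n)
        grad = (2.0 / n) * (X.T @ residual)
        w = w - alpha * grad         # step against the gradient
        if np.linalg.norm(grad) < tol:
            break                    # gradient ~0: converged
    return w, costs

# Hypothetical data: y = 1 + 3x plus a little Gaussian noise.
rng = np.random.default_rng(0)
n = 200
x = rng.uniform(0, 1, size=n)
X = np.column_stack([np.ones(n), x])
y = 1.0 + 3.0 * x + rng.normal(scale=0.1, size=n)

w_gd, costs = gradient_descent(X, y, alpha=0.5, n_iters=5000)
w_exact = np.linalg.solve(X.T @ X, X.T @ y)  # normal equation
print(w_gd, w_exact)
```

Plotting `costs` against the iteration number gives the cost curve used to tune the learning rate.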
Choosing the Learning Rate
The learning rate (\alpha) is critical:
| (\alpha) | Effect | Result |
|---|---|---|
| Too small | Slow convergence | Takes forever to train |
| Just right | Smooth, fast convergence | ✅ Optimal |
| Too large | Oscillations, divergence | Never converges! |
Rule of thumb: Start with (\alpha = 0.01) and adjust based on the cost curve. Standardize the features first, because the largest stable step size depends on their scale.
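The three rows of the table can be reproduced on a small synthetic problem. This sketch (assuming NumPy; data and step sizes are illustrative) runs 50 gradient-descent steps with a tiny, a moderate, and an overly large (\alpha). For a quadratic cost like the MSE, gradient descent converges only if (\alpha) is below 2 divided by the largest eigenvalue of the Hessian (\frac{2}{n}\mathbf{X}^T\mathbf{X}). Here that limit is roughly 0.8, so (\alpha = 1.0) diverges.

```python
import numpy as np

def final_cost(X, y, alpha, n_iters=50):
    """Run n_iters of gradient descent from w = 0 and return the last MSE."""
    n = X.shape[0]
    w = np.zeros(X.shape[1])
    for _ in range(n_iters):
        w = w - alpha * (2.0 / n) * X.T @ (X @ w - y)
    r = X @ w - y
    return float(r @ r) / n

rng = np.random.default_rng(0)
n = 200
x = rng.uniform(0, 1, size=n)
X = np.column_stack([np.ones(n), x])
y = 1.0 + 3.0 * x + rng.normal(scale=0.1, size=n)

for alpha in [0.001, 0.5, 1.0]:  # too small, reasonable, too large
    print(alpha, final_cost(X, y, alpha))
```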
5. Probabilistic Interpretation (Maximum Likelihood)
The Statistical View
Linear regression can be derived from a probabilistic perspective:
Assumption: Targets are noisy observations of a linear function: [ y_i = \mathbf{w}^T \mathbf{x}_i + \epsilon_i, \quad \epsilon_i \sim \mathcal{N}(0, \sigma^2) ]
Where the (\epsilon_i) are independent Gaussian noise terms with mean 0 and variance (\sigma^2).
This means: [ y_i | \mathbf{x}_i, \mathbf{w} \sim \mathcal{N}(\mathbf{w}^T \mathbf{x}_i, \sigma^2) ]
Maximum Likelihood Estimation (MLE): Find (\mathbf{w}) that maximizes the probability of observing the data.
Result: MLE for linear regression with Gaussian noise is equivalent to minimizing MSE!
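The step behind this result, assuming the noise terms are independent so that the likelihood factorizes over samples, with (J(\mathbf{w})) the MSE defined earlier:

```latex
\begin{aligned}
p(\mathbf{y} \mid \mathbf{X}, \mathbf{w})
  &= \prod_{i=1}^{n} \frac{1}{\sqrt{2\pi\sigma^2}}
     \exp\!\left(-\frac{(y_i - \mathbf{w}^T \mathbf{x}_i)^2}{2\sigma^2}\right) \\
-\log p(\mathbf{y} \mid \mathbf{X}, \mathbf{w})
  &= \frac{n}{2}\log(2\pi\sigma^2)
     + \frac{1}{2\sigma^2} \sum_{i=1}^{n} (y_i - \mathbf{w}^T \mathbf{x}_i)^2 \\
  &= \frac{n}{2}\log(2\pi\sigma^2) + \frac{n}{2\sigma^2}\, J(\mathbf{w})
\end{aligned}
```

The first term does not depend on (\mathbf{w}), and (n / 2\sigma^2 > 0). Maximizing the likelihood (minimizing the negative log-likelihood) therefore selects the same (\mathbf{w}) as minimizing (J(\mathbf{w})).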
6. Model Interpretation and Diagnostics
Interpreting Coefficients
Each weight (w_j) tells us: "Holding all other features constant, a 1-unit increase in (x_j) changes the prediction (\hat{y}) by (w_j) units." This describes how the fitted model behaves, not necessarily a causal effect.
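A small sketch of this reading, assuming NumPy. The weights and the house below are hypothetical, and the feature order ([intercept, square feet, bedrooms, age in years]) is made up for illustration. Bumping one feature by 1 while leaving the others unchanged moves the prediction by that feature's weight.

```python
import numpy as np

# Hypothetical fitted weights: [intercept, sqft, bedrooms, age (years)]
w = np.array([50.0, 0.2, 10.0, -1.5])

house = np.array([1.0, 1500.0, 3.0, 20.0])  # leading 1 for the intercept

bigger = house.copy()
bigger[1] += 1.0                 # +1 square foot, everything else fixed
print(bigger @ w - house @ w)    # approximately w[1] = 0.2

older = house.copy()
older[3] += 1.0                  # +1 year of age, everything else fixed
print(older @ w - house @ w)     # approximately w[3] = -1.5
```

Because the weight of 0.2 is "per square foot" and 10.0 is "per bedroom", comparing raw magnitudes across features is only meaningful when the features share a scale, for example after standardization.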
Residual Plots
Always check the residuals (e_i = y_i - \hat{y}_i) to validate the model's assumptions:
Diagnostic Checklist:
- ✅ Residuals centered at 0
- ✅ No pattern in residual plot
- ✅ Residuals roughly constant variance (homoscedastic)
- ✅ Residuals approximately Gaussian
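Residual plots are the usual tool for this checklist. As a plot-free complement, this sketch (assuming NumPy; synthetic data) fits a straight line to two datasets: one that is truly linear and one with a hidden (x^2) term. It then prints numeric summaries of the residuals. Correlating the residuals with (x^2) is an ad hoc check chosen for this example, not a standard test. Plotting (e_i) against (\hat{y}_i) (e.g., with matplotlib) remains the primary diagnostic.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 300
x = rng.uniform(-2, 2, size=n)
X = np.column_stack([np.ones(n), x])

def residuals(X, y):
    """Fit by least squares and return e_i = y_i - y_hat_i."""
    w, *_ = np.linalg.lstsq(X, y, rcond=None)
    return y - X @ w

# Case 1: the data really are linear plus Gaussian noise.
y_linear = 1.0 + 2.0 * x + rng.normal(scale=0.3, size=n)
# Case 2: the data are curved, but we still fit a straight line.
y_curved = 1.0 + 2.0 * x + 1.5 * x**2 + rng.normal(scale=0.3, size=n)

for name, y in [("linear", y_linear), ("curved", y_curved)]:
    e = residuals(X, y)
    # With an intercept, least-squares residuals always average to ~0, so a
    # zero mean alone proves little; a strong correlation with x^2 reveals
    # the systematic curvature a residual plot would show as a U-shape.
    pattern = np.corrcoef(x**2, e)[0, 1]
    print(f"{name}: mean={e.mean():.3f}, corr(e, x^2)={pattern:.2f}")
```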
7. When to Use Linear Regression
Appropriate When:
- ✅ Relationship is approximately linear
- ✅ You need interpretability (coefficients have meaning)
- ✅ You have sufficient data (more samples than features)
- ✅ Features are relatively independent (low multicollinearity)
Not Appropriate When:
- ❌ Relationship is clearly non-linear (use polynomial regression or other models)
- ❌ Severe outliers (consider robust regression)
- ❌ More features than samples (use regularization – next lesson!)
- ❌ Complex interactions between features (consider tree-based methods)
Key Takeaways
✓ Linear Model: (\hat{y} = \mathbf{w}^T \mathbf{x}) – the prediction is a weighted sum of the features
✓ Cost Function: MSE measures average squared error
✓ Two Solution Methods:
- Normal Equation: (\mathbf{w}^* = (\mathbf{X}^T \mathbf{X})^{-1} \mathbf{X}^T \mathbf{y}) (exact, one-step)
- Gradient Descent: Iterative optimization (scalable)
✓ Probabilistic View: Minimizing MSE ≡ Maximum Likelihood with Gaussian noise
✓ Interpretation: Coefficients show the direction and size of each feature's effect; compare magnitudes as "importance" only when features are on comparable scales
✓ Diagnostics: Check residual plots to validate the linearity assumption
Practice Problems
Problem 1: Implement Normal Equation
Problem 2: Gradient Descent with Momentum
Enhance gradient descent with momentum: (v_t = \beta v_{t-1} + \alpha \nabla J), then (\mathbf{w} \leftarrow \mathbf{w} - v_t)
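One possible sketch for this problem, assuming NumPy and the same MSE gradient as above. It applies the update exactly as written: (v_t = \beta v_{t-1} + \alpha \nabla J), then (\mathbf{w} \leftarrow \mathbf{w} - v_t), starting from (v_0 = 0). The values (\alpha = 0.1) and (\beta = 0.9) are illustrative choices for this synthetic data.

```python
import numpy as np

def gd_momentum(X, y, alpha=0.1, beta=0.9, n_iters=500):
    """Gradient descent with momentum on J(w) = (1/n) ||y - X w||^2.

    v_t = beta * v_{t-1} + alpha * grad J(w);  w <- w - v_t
    """
    n, p = X.shape
    w = np.zeros(p)
    v = np.zeros(p)  # velocity, v_0 = 0
    for _ in range(n_iters):
        grad = (2.0 / n) * (X.T @ (X @ w - y))
        v = beta * v + alpha * grad  # accumulate a running direction
        w = w - v
    return w

rng = np.random.default_rng(0)
n = 200
x = rng.uniform(0, 1, size=n)
X = np.column_stack([np.ones(n), x])
y = 1.0 + 3.0 * x + rng.normal(scale=0.1, size=n)

w_mom = gd_momentum(X, y)
w_exact = np.linalg.solve(X.T @ X, X.T @ y)
print(w_mom, w_exact)
```

Setting `beta=0` recovers plain gradient descent, which is a convenient way to compare the two on the same data.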
Next Steps
You now understand linear regression deeply! Next lessons:
- Lesson 4: Logistic Regression – extending to classification
- Lesson 5: Regularization – handling overfitting and many features
Linear regression is the foundation for many algorithms. Master it, and more complex models will make sense!
Further Reading
Interactive Visualizations
- MLU-Explain: Linear Regression — Amazon's scroll-story with a live model you can fit by dragging points.
- Seeing Theory — Regression Analysis (Brown University) — intuitive OLS visualization with residuals you can drag.
- Setosa: Ordinary Least Squares Regression — interactive scatter-plot where you see the best-fit line update as you add points.
- Google ML Crash Course — Linear Regression — interactive exercises on loss, gradient descent, and learning rate.
- ML Visualized — Gradient Descent on Linear Regression — 3-D error-surface animation synced with the fitted line.
Video Tutorials
- StatQuest — Linear Regression Clearly Explained (Josh Starmer) — the 27-minute "Main Ideas" walkthrough; pair with the least-squares and R² follow-ups.
- 3Blue1Brown — Essence of Linear Algebra, Ch. 15: Abstract Vector Spaces — context for viewing the normal equation as a projection.
Papers & Articles
- A Modern Take on the Bias-Variance Tradeoff in Neural Networks — Neal et al., 2018. Starts from linear regression and shows where the classical picture breaks.
- Regression Shrinkage and Selection via the Lasso — Tibshirani, 1996. The bridge between linear regression and the next lesson.
Documentation & Books
- Book: The Elements of Statistical Learning — Hastie, Tibshirani, Friedman (Chapter 3) — free PDF.
- Book: An Introduction to Statistical Learning (2e, 2021) — James, Witten, Hastie, Tibshirani (free PDF).
- scikit-learn: Linear Models — the canonical API with working examples.
Remember: Linear regression is simple but powerful. Often, the simplest model that works is the best model to use!