IceColdCode

Linear Regression Using Matrix Notation

Hello, I just finished my second cup of coffee and figured I’d write down a quick introduction to linear regression before I forget it. Linear regression is a method for fitting a linear function to a set of data points. A simple linear function is:

$$y = w_0 + w_1 x$$

where $w_0$ is the intercept and $w_1$ is the slope.

In order to keep the introduction brief I will stick to this simple case. It often helps to look at a graph to be completely clear about what we are trying to do. Consider the following graph:

Fig 1. Two dots that we want to fit a line through

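We want to find a function whose straight line goes through the dots, or as close to them as possible. There are multiple ways to do this. From school you probably remember computing the slope and the intercept from two points $(x_1, y_1)$ and $(x_2, y_2)$:

$$w_1 = \frac{y_2 - y_1}{x_2 - x_1}, \qquad w_0 = y_1 - w_1 x_1$$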

Which gives us the following graph:

Fig 2. The same two dots as in fig 1, with a line given by “y = x + 1” passing through them.
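The school method for two dots (slope as rise over run, then solve for the intercept) fits in a few lines of Octave. Fig 1 does not state the coordinates of the dots, so this sketch assumes two hypothetical dots, (1, 2) and (3, 4), which both lie on y = x + 1:

```octave
% Hypothetical dots: fig 1 does not give the actual coordinates
x1 = 1; y1 = 2;
x2 = 3; y2 = 4;

% Slope: rise over run between the two dots
w1 = (y2 - y1) / (x2 - x1);
% Intercept: solve y1 = w0 + w1 * x1 for w0
w0 = y1 - w1 * x1;

printf("y = %g + %gx\n", w0, w1);
```

With these assumed dots the slope and intercept both come out as 1, i.e. the line y = x + 1 from fig 2.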

This method works for a simple example like this. But consider the case where we have many points and no perfect solution exists: no straight line passes through all of them. We can still do something, namely find the line that passes as closely as possible to all the points. We settle for the best solution rather than a perfect one. Let’s again look at a graph to visualize what we are trying to solve.

Fig 3. Dots that don’t lie on a straight line. This could be due to noise.

Given these dots, we want to find a straight line that passes as closely as possible to them. To do this we need to define what “as closely as possible” means. One common definition uses the squared vertical distance from our line to the dot at a given x-value. We square it because we don’t care whether the distance is negative or positive (whether the line passes over or under the true value). The important bit is that this number grows when the distance from our line to the dot is large. Ideally we would make this number small for every dot, but a single line generally cannot do that for all dots at once. What we can do is minimize the sum of these numbers over all dots.
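Written out, the quantity to minimize is $E(\mathbf{w}) = \sum_{n=1}^{N} (w_0 + w_1 x_n - t_n)^2$, where $(x_n, t_n)$ is the n-th dot. A minimal sketch that evaluates it for two candidate lines, using the fig 3 dots from the Octave script later in this post; `sumSquaredLoss` is a hypothetical helper name chosen for illustration:

```octave
% Dots from fig 3 (the same values used in the Octave script further down)
x = [1 2 3 4 5]';
t = [1 3 1 5 7]';

% Hypothetical helper: sum of squared vertical distances between
% the line w0 + w1*x and the dots
sumSquaredLoss = @(w0, w1) sum((w0 + w1 * x - t) .^ 2);

lossA = sumSquaredLoss(1, 1)    % candidate line y = 1 + x
lossB = sumSquaredLoss(3.4, 0)  % candidate flat line y = 3.4
```

The line y = 1 + x gets a loss of 11, while the flat line gets 27.2, so by this definition the first line passes more closely to the dots.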

Minimizing the Loss Function

The sum of the squared distances from our line to the dots is called a loss function. To find the straight line passing the dots as closely as possible we need to minimize this loss function. There are two common methods to minimize it. One is analytical: calculate the derivatives, set them to zero and solve the resulting equations. The other is a numerical technique called gradient descent. I will focus on the analytical solution here. The derivation is not very difficult; check the further reading section if you are interested. Let’s now jump straight to the solution; in vector notation it is quite beautiful:

$$\mathbf{w} = (X^T X)^{-1} X^T \mathbf{t}$$

where each row of $X$ is $[1, x_n]$ for one data point and $\mathbf{t}$ holds the target values $t_n$.

Equipped with this, let’s find the parameters, that is, calculate $\mathbf{w}$ for the data points in fig 3, and plot a straight line with these parameters to verify our results. The following Octave script does the calculation:


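1;  % keeps Octave from treating this file as a function file

% Closed-form least-squares solution: w = (X' * X)^-1 * X' * t
function result = findOptimalParameters (X, t)
  result = ((X' * X)^-1) * X' * t;
endfunction

% Each row is [1, x] for one dot: the 1 multiplies the intercept w0
X = [1 1; 1 2; 1 3; 1 4; 1 5];
% Target values: the y-coordinates of the dots in fig 3
t = [1 3 1 5 7]';

% No semicolon, so Octave prints the column vector [w0; w1]
opt = findOptimalParameters(X, t)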
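Forming the inverse of X' * X explicitly is fine for a small example like this, but in practice it is usually avoided for numerical reasons. As a cross-check, Octave's backslash operator and the built-in `polyfit` should give the same parameters (the variable names here are illustrative):

```octave
X = [1 1; 1 2; 1 3; 1 4; 1 5];
t = [1 3 1 5 7]';

% Backslash solves the least-squares problem min ||X*w - t|| directly,
% without forming the inverse of X'*X
wBackslash = X \ t

% polyfit fits a degree-1 polynomial to the x-values (second column of X);
% it returns the coefficients highest power first, i.e. [w1 w0]
p = polyfit(X(:, 2), t, 1)
```

Both should agree with `opt` from the script above: an intercept of -0.8 and a slope of 1.4.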

Which gives us the following graph:

Fig 4. The same dots as in fig 3 and the fitted line “y = -0.8 + 1.4x”.
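For comparison, the same parameters can also be approached numerically with gradient descent, the other method mentioned earlier. This is a minimal sketch, not a tuned implementation: the starting guess, the step size `alpha = 0.01` and the 5000 iterations are assumptions picked by hand for this small data set (a step that is too large makes the iteration diverge):

```octave
X = [1 1; 1 2; 1 3; 1 4; 1 5];
t = [1 3 1 5 7]';

w = [0; 0];        % starting guess for [w0; w1]
alpha = 0.01;      % step size, chosen by hand for this data
for iter = 1:5000
  % Gradient of the sum of squared errors ||X*w - t||^2 with respect to w
  grad = 2 * X' * (X * w - t);
  % Take a small step downhill
  w = w - alpha * grad;
end

w
```

It should end up at approximately [-0.8; 1.4], the line in fig 4, but it needs a loop and a hand-tuned step size.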

Note the conciseness of the solution. The expression $\mathbf{w} = (X^T X)^{-1} X^T \mathbf{t}$ is what we get when we take the derivative of the loss function expressed in matrix notation, set it to zero and solve for $\mathbf{w}$.
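The steps in brief, with the loss written in matrix notation as $E(\mathbf{w}) = \|X\mathbf{w} - \mathbf{t}\|^2$ (X and t as in the script above). The last step assumes $X^T X$ is invertible, which for this two-parameter model holds as long as not all x-values are equal:

```latex
E(\mathbf{w}) = (X\mathbf{w} - \mathbf{t})^T (X\mathbf{w} - \mathbf{t})
             = \mathbf{w}^T X^T X \mathbf{w} - 2\,\mathbf{w}^T X^T \mathbf{t} + \mathbf{t}^T \mathbf{t}

\nabla_{\mathbf{w}} E = 2 X^T X \mathbf{w} - 2 X^T \mathbf{t}
                      = 2 X^T (X\mathbf{w} - \mathbf{t}) = \mathbf{0}

X^T X \mathbf{w} = X^T \mathbf{t}

\mathbf{w} = (X^T X)^{-1} X^T \mathbf{t}
```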

Further Reading