
STA 9890 - Unsupervised Learning: II. Clustering
\[\newcommand{\R}{\mathbb{R}} \newcommand{\E}{\mathbb{E}} \newcommand{\V}{\mathbb{V}} \newcommand{\P}{\mathbb{P}} \newcommand{\C}{\mathbb{C}} \newcommand{\bbeta}{\mathbb{\beta}} \newcommand{\bone}{\mathbb{1}} \newcommand{\bzero}{\mathbb{0}} \newcommand{\ba}{\mathbf{a}} \newcommand{\bb}{\mathbf{b}} \newcommand{\bc}{\mathbf{c}} \newcommand{\br}{\mathbf{r}} \newcommand{\bw}{\mathbf{w}} \newcommand{\bx}{\mathbf{x}} \newcommand{\by}{\mathbf{y}} \newcommand{\bz}{\mathbf{z}} \newcommand{\bX}{\mathbf{X}} \newcommand{\bA}{\mathbf{A}} \newcommand{\bB}{\mathbf{B}} \newcommand{\bC}{\mathbb{C}} \newcommand{\bD}{\mathbb{D}} \newcommand{\bU}{\mathbb{U}} \newcommand{\bV}{\mathbb{V}} \newcommand{\bI}{\mathbb{I}} \newcommand{\bH}{\mathbb{H}} \newcommand{\bW}{\mathbb{W}} \newcommand{\bK}{\mathbb{K}} \newcommand{\argmin}{\text{arg\,min}} \newcommand{\argmax}{\text{arg\,max}} \newcommand{\MSE}{\text{MSE}} \newcommand{\Fcal}{\mathcal{F}} \newcommand{\Dcal}{\mathcal{D}} \newcommand{\Ycal}{\mathcal{Y}} \]
For our first topic in unsupervised learning, we turn to clustering, the task of dividing observations into meaningful groups. Before we get to methods, what are some reasons why we might want to perform clustering on a given data set?
- Simplification of structure: by dividing a complex population into several smaller and more homogeneous subpopulations, we hope to make it easier to understand some underlying phenomenon.
This is common in things like medical studies: rather than dealing with cancer as one large and undifferentiated phenomenon, we have divided cancers into types and subtypes of related disease. Note that these are certainly not all identical within a type - the random genetic abnormalities caused by excessive UV exposure in my skin aren’t necessarily the same random genetic abnormalities caused by excessive UV exposure in your skin - but they are still much more homogeneous than two totally different cancers, e.g., pancreatic and brain. By recognizing the subgroups, we get a simpler task: it is easier to understand four or five groups separately than to build one “super-model” that covers them all.
Statistically, you can think of this as being a special case of one of our favorite properties of variance:
\[\E\left[\V[Y | X ]\right] \leq \V[Y] \text{ for any random variables } X, Y\]
but we do not necessarily know the relevant \(X\) a priori. (Compare this to regression, where we typically have a model - often implicit - about \(X\) and the relationship between \(X\) and \(Y\)).
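This is one way of reading the law of total variance: conditioning on a relevant \(X\) can only reduce the variance that remains, on average,
\[\V[Y] = \E\left[\V[Y | X]\right] + \V\left[\E[Y | X]\right] \geq \E\left[\V[Y | X]\right].\]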
- Data compression: when dealing with large data sets, it may be too expensive to treat all points separately. If we can identify repeated (or nearly repeated) points, we can combine them into a single point with a weight of “2” in our algorithm, earning us some savings in computational time and in memory usage.
- Denoising: when faced with noisy observations, we might get a better estimate of the ‘true’ underlying signal by averaging several similar observations. Obviously, we don’t want to accidentally mix in signals that are fundamentally different, but slight differences might be ok in a “bias-variance” sort of way.
In each of these tasks, we’re faced with a bit of a trade-off in our task:
- If we divide our data into more clusters, we have more homogeneity within each group, but we get less of the benefit of clustering because we still have many groups to deal with.
- If we divide our data into fewer clusters, we get more simplicity (fewer clusters) but our simplified (clustered) representation is lossier.
In our discussion of clustering, this is the principal “tuning parameter” we will have to wrestle with. (Sometimes this will be an explicit “how many clusters?” hyperparameter, while other times we will have a “how much combining?” parameter which only implicitly determines the number of clusters when we apply it to a data set.) As discussed last week, there’s no simple “just holdout”-type heuristic we can use to assess the performance of clustering, and our validation strategies typically need to be a bit more problem-specific.
If your regression training had an economics-flavored background, you may have heard the term clustering used in a regression context, referring to groups of related observations that should not be modeled as independent observations (e.g., standardized test performance of different classes within the same school). This is not that type of clustering!
In that supervised context, we know the underlying group structure and use it to inform our regression modeling. In the unsupervised context we consider, we do not know the underlying group structure (though we assume one exists!) and we instead seek to discover it.
Distance-Based Clustering
We will build up our basic clustering framework from a distance-based perspective. If two points \(\bx_1, \bx_2\) are relatively close (in a distance sense), they are more suitable candidates for clustering than two points \(\bx_3, \bx_4\) which are very far apart. This will not always be the case - if \(\bx_4\) is very far away from anything else, \(\bx_3\) may still be its “nearest-neighbor” - but it gives us a good place to start.
We will begin with clustering based on Euclidean (\(\ell_2\)) distances, but most of the methods we consider can be “kernelized” - allowing for other distances in a feature-expanded space - or “sparsified” - allowing the method to identify and cluster based only on distances in a few relevant features. Because we start with \(\ell_2\) distance, these methods are going to perform their best when applied to data with nice isotropic (circular) structure.
This type of data will be relatively easy to cluster:
Of course, we do not generally get our data so nicely labeled and instead are faced with something like:

While the cluster structure is less clear, a natural first pass would be to draw a circle around the group on the left and another around the group on the right:

Clearly not a perfect clustering - we have a few uncircled points to deal with and the choice of the radius is debatable - but a darn good first start. Notably, the points in each circle are generally closer to the center of the circle (and hence to each other) than they are to points in the other circle.
Compare this to a situation in which there is strong correlation between the two coordinates, even though the two clusters are still centered in the same places.

The same “draw a circle” strategy gives us

This still isn’t a terrible outcome - it’s just not that hard a problem! - but it’s clearly harder than the previous problem, even though the clusters themselves aren’t any further apart and the variance in each direction isn’t changed. If we were able to represent this data in a more ‘diagonal’ way that captured the structures of each cluster, we might do better. (Hint, hint: PCA)
With that aside complete, our basic clustering problem is thus:
- Divide all points into one of \(K\) groups
- Minimize the variance within the groups
- Maximize the distance between the groups
Here, we can parameterize each group by its center point (mean), which is known as the centroid.1 We can modify the above to:
- Divide all points into one of \(K\) groups
- Minimize the distance from each point to its nearest centroid
- Maximize the distance between the centroids
So let’s write this as an optimization problem:
\[\argmin_{\substack{c_1, c_2, \dots, c_n \in \{1, 2, \dots, K\} \\ C_1, C_2, \dots, C_K \in \R^p}} \sum_i \|X_i - C_{c_i}\|_2^2\]
That’s a mess! But what does it mean?
- We want to find labels \(c_1, \dots, c_n\). We have one label for each of our \(n\) points and, rather than being arbitrary numbers, we only get to pick the integers \(1\) to \(K\).
- Additionally, we want to estimate \(K\) centroids \(C_1, \dots, C_K\).
- Our objective function is the sum of squared distances from each point \(X_i\) to its corresponding centroid \(C_{c_i}\). For instance, if our third point were mapped to centroid \(2\), the corresponding term in the objective would be \(\|X_3 - C_2\|_2^2\).
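To make the notation concrete, here is a small sketch of how we could evaluate this objective in R; the names X, labels, and centroids are illustrative, standing for an \(n \times p\) data matrix, a vector of labels in \(\{1, \dots, K\}\), and a \(K \times p\) matrix of centroids.
# Sketch only: within-cluster sum of squares for a given labeling and set of centroids.
# X: n x p numeric matrix; labels: integer vector in 1..K; centroids: K x p matrix.
kmeans_objective <- function(X, labels, centroids) {
  sum(rowSums((X - centroids[labels, , drop = FALSE])^2))
}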
So not too complex once we take it apart, but we will have trouble solving it. Due to the discrete labels \(\{c_i\}\), this is a combinatorial (discrete) problem and we don’t get the convexity magic.
Our first two clustering methods are based on approximately solving this problem.
The \(K\)-Means Model
Returning to our problem above, while it is hard to solve the problem as written, we can break it into two relatively simple problems:
- If we know the centroids, it’s easy to determine the closest centroid for each point
- If we know the cluster labels, we can estimate the corresponding centroids by averaging all the points in the cluster.
Putting these together, we get a pretty great clustering algorithm. We start with a guess of the labels (or, equivalently, the centroids) and then iteratively update the centroids and labels, using our answer from the previous label update to get better centroids and then using our latest centroids to get better labels. We run this process until the labels stop changing (which implies the centroids stop changing) and we have a pretty good guess as to our clustering.
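In pseudocode-ish R, the whole procedure is only a few lines. This is a bare-bones sketch of the idea (assuming, for simplicity, that no cluster ever becomes empty), not a production implementation:
# Sketch of the alternating algorithm: repeat label and centroid updates until
# the labels stop changing. X: n x p numeric matrix; K: number of clusters.
lloyd <- function(X, K, max_iter = 100) {
  n <- nrow(X)
  labels <- sample(seq_len(K), n, replace = TRUE)        # random initial labels
  for (iter in seq_len(max_iter)) {
    # Centroid update: mean of the points currently assigned to each cluster
    centroids <- t(sapply(seq_len(K),
                          function(k) colMeans(X[labels == k, , drop = FALSE])))
    # Label update: assign each point to its nearest centroid (squared Euclidean)
    d2 <- sapply(seq_len(K),
                 function(k) colSums((t(X) - centroids[k, ])^2))
    new_labels <- max.col(-d2, ties.method = "first")
    if (all(new_labels == labels)) break                 # labels stopped changing
    labels <- new_labels
  }
  list(labels = labels, centroids = centroids)
}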
Let’s see this in action. I’m going to generate data with three clusters. These will be visually apparent so you can see the algorithm in action.
library(dplyr)
library(ggplot2)

D <- data.frame(c = rep(c("A", "B", "C"), each=40)) |>
  mutate(r = 1,
         # cluster means sit at three equally spaced angles on the unit circle
         theta = case_match(c, "A" ~ 0, "B" ~ 2*pi/3, "C" ~ 4*pi/3),
         z_mean = r * exp((0 + 1i) * theta),
         x_mean = Re(z_mean),
         y_mean = Im(z_mean),
         x = rnorm(n(), mean=x_mean, sd=0.75),
         y = rnorm(n(), mean=y_mean, sd=0.75),
         id = row_number())
ggplot(D, aes(x=x, y=y)) +
  geom_point() +
  theme_bw()
We see three pretty clear clusters, with a few possibly ambiguous points. To start our algorithm, let’s randomly assign each point to one of three clusters, indicated by colors:
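For concreteness, a minimal version of this random assignment (with an arbitrary seed) might look like:
# Randomly assign each point to one of three initial clusters
# (a minimal reconstruction; the seed is arbitrary)
set.seed(100)
D <- D |> mutate(c_hat = sample(c("A", "B", "C"), n(), replace = TRUE))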

Obviously, by doing this randomly, we’re not going to do well. But with these clusters, we can now estimate cluster centroids by taking the average of all points of that color:
D <- D |>
  group_by(c_hat) |>
  mutate(x_centroid = mean(x),
         y_centroid = mean(y)) |>
  ungroup()
ggplot(D) +
  geom_point(aes(x=x, y=y, color=c_hat), alpha=0.5) +
  geom_point(aes(x=x_centroid, y=y_centroid, color=c_hat), pch=12, size=3) +
  theme_bw() +
  scale_color_brewer(type="qual",
                     palette=2,
                     name="Initial Guessed Clusters")
With these centroids in hand, we can update our cluster assignments:
D1 <- D |> select(-c_hat, -x_centroid, -y_centroid)
D2 <- D |> select(c_hat, x_centroid, y_centroid) |> distinct()
D <- cross_join(D1, D2) |>
  mutate(dist_to_centroid = (x - x_centroid)^2 + (y - y_centroid)^2) |>
  group_by(id) |>
  slice_min(dist_to_centroid) |>
  ungroup() |>
  select(-ends_with("centroid"))
ggplot(D) +
  geom_point(aes(x=x, y=y, color=c_hat)) +
  theme_bw() +
  scale_color_brewer(type="qual",
                     palette=2,
                     name="Guessed Clusters - 1 Round")
Not perfect, but much improved! From here, we again get new estimated centroids and plot them:
D <- D |>
  group_by(c_hat) |>
  mutate(x_centroid = mean(x),
         y_centroid = mean(y)) |>
  ungroup()
ggplot(D) +
  geom_point(aes(x=x, y=y, color=c_hat), alpha=0.5) +
  geom_point(aes(x=x_centroid, y=y_centroid, color=c_hat), pch=12, size=3) +
  theme_bw() +
  scale_color_brewer(type="qual",
                     palette=2,
                     name="Guessed Clusters - 1 Round")
We can repeat this process a few times and see that the algorithm converges quickly:
D1 <- D |> select(-c_hat, -x_centroid, -y_centroid)
D2 <- D |> select(c_hat, x_centroid, y_centroid) |> distinct()
D <- cross_join(D1, D2) |>
  mutate(dist_to_centroid = (x - x_centroid)^2 + (y - y_centroid)^2) |>
  group_by(id) |>
  slice_min(dist_to_centroid) |>
  ungroup() |>
  select(-ends_with("centroid")) |>
  group_by(c_hat) |>
  mutate(x_centroid = mean(x),
         y_centroid = mean(y)) |>
  ungroup()
ggplot(D) +
  geom_point(aes(x=x, y=y, color=c_hat), alpha=0.5) +
  geom_point(aes(x=x_centroid, y=y_centroid, color=c_hat), pch=12, size=3) +
  theme_bw() +
  scale_color_brewer(type="qual",
                     palette=2,
                     name="Guessed Clusters - 2 Rounds")
D1 <- D |> select(-c_hat, -x_centroid, -y_centroid)
D2 <- D |> select(c_hat, x_centroid, y_centroid) |> distinct()
D <- cross_join(D1, D2) |>
  mutate(dist_to_centroid = (x - x_centroid)^2 + (y - y_centroid)^2) |>
  group_by(id) |>
  slice_min(dist_to_centroid) |>
  ungroup() |>
  select(-ends_with("centroid")) |>
  group_by(c_hat) |>
  mutate(x_centroid = mean(x),
         y_centroid = mean(y)) |>
  ungroup()
ggplot(D) +
  geom_point(aes(x=x, y=y, color=c_hat), alpha=0.5) +
  geom_point(aes(x=x_centroid, y=y_centroid, color=c_hat), pch=12, size=3) +
  theme_bw() +
  scale_color_brewer(type="qual",
                     palette=2,
                     name="Guessed Clusters - 3 Rounds")
D1 <- D |> select(-c_hat, -x_centroid, -y_centroid)
D2 <- D |> select(c_hat, x_centroid, y_centroid) |> distinct()
D <- cross_join(D1, D2) |>
  mutate(dist_to_centroid = (x - x_centroid)^2 + (y - y_centroid)^2) |>
  group_by(id) |>
  slice_min(dist_to_centroid) |>
  ungroup() |>
  select(-ends_with("centroid")) |>
  group_by(c_hat) |>
  mutate(x_centroid = mean(x),
         y_centroid = mean(y)) |>
  ungroup()
ggplot(D) +
  geom_point(aes(x=x, y=y, color=c_hat), alpha=0.5) +
  geom_point(aes(x=x_centroid, y=y_centroid, color=c_hat), pch=12, size=3) +
  theme_bw() +
  scale_color_brewer(type="qual",
                     palette=2,
                     name="Guessed Clusters - 4 Rounds")
At this point, the algorithm terminates (since repeating the updates won’t change anything) and we actually have a pretty good result. You could quibble with some points on the boundary, but this is actually pretty impressive if you think back to our starting point just a few iterations earlier.
This algorithm is known as Lloyd’s Algorithm and it (approximately) solves the so-called “\(K\)-Means” problem (of estimating \(K\) cluster means). Despite its simplicity, it usually converges quite quickly to a reasonable solution. In R, we implement this via the kmeans() function, which we visualize as before:
D1 <- D |> select(-c_hat, -x_centroid, -y_centroid)
D |>
  select(x, y) |>
  as.matrix() |>
  kmeans(centers=3) |>
  useful::fortify.kmeans() |>
  cross_join(D1) |>
  mutate(dist_to_centroid = (x - .x)^2 + (y - .y)^2) |>
  group_by(id) |>
  slice_min(dist_to_centroid) |>
  ungroup() |>
  select(-ends_with("centroid")) |>
  ggplot(aes(color=.Cluster)) +
  geom_point(aes(x=x, y=y), alpha=0.5) +
  geom_point(aes(x=.x, y=.y), size=3, pch=10) +
  theme_bw() +
  scale_color_brewer(type="qual",
                     palette=2,
                     name="Guessed Clusters - kmeans()")
Which is quite nice!
But this actually doesn’t quite match our results from above. What happened?
Lloyd’s Algorithm is sensitive to initialization - with different starting values, we get different final results. (Recall that this can’t happen with convex problems!) To get around this, we typically run \(K\)-means several times with different initial values and pick the best result (defined as the one with the minimum within-cluster variance).
D1 <- D |> select(-c_hat, -x_centroid, -y_centroid)
D |>
  select(x, y) |>
  as.matrix() |>
  kmeans(centers=3, nstart = 100) |>
  useful::fortify.kmeans() |>
  cross_join(D1) |>
  mutate(dist_to_centroid = (x - .x)^2 + (y - .y)^2) |>
  group_by(id) |>
  slice_min(dist_to_centroid) |>
  ungroup() |>
  select(-ends_with("centroid")) |>
  ggplot(aes(color=.Cluster)) +
  geom_point(aes(x=x, y=y), alpha=0.5) +
  geom_point(aes(x=.x, y=.y), size=3, pch=10) +
  theme_bw() +
  scale_color_brewer(type="qual",
                     palette=2,
                     name="Guessed Clusters - kmeans(nstart=100)")
Here, the data is small enough and the algorithm is fast enough that we could run it many times and get an improved result.
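We can also see this run-to-run variability directly. As a quick sketch (reusing the x and y columns of D from above), each kmeans() fit reports its total within-cluster sum of squares - the \(K\)-means objective - so a few single-start runs can be compared side by side:
# Several single-start runs can land in different local optima;
# tot.withinss (the K-means objective) tells us which run did best (smaller is better).
X <- as.matrix(D[, c("x", "y")])
fits <- replicate(5, kmeans(X, centers = 3, nstart = 1), simplify = FALSE)
sapply(fits, function(f) f$tot.withinss)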
Note that there are two ways in which the results can differ:
- A “label switching” problem in which cluster “A” is called cluster “B”, but the actual groupings don’t change
- Different groupings, where two points are in the same group in one solution but different groups in a second solution
The former makes for some annoying book-keeping, but is ultimately harmless. The latter is what we worry about when using \(K\)-means.
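One quick way to tell these apart - again just a sketch, reusing the X matrix defined above - is to cross-tabulate the labels from two separate fits: pure label switching yields a table with exactly one non-zero cell in each row and column, while genuinely different groupings spread counts across cells.
# Compare two single-start fits: a permutation-like table means only the labels
# switched; other patterns mean points actually changed groups.
fit1 <- kmeans(X, centers = 3, nstart = 1)
fit2 <- kmeans(X, centers = 3, nstart = 1)
table(fit1$cluster, fit2$cluster)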
- If we initialize our centers more carefully - for example, by spreading the starting centroids out across the data rather than picking them uniformly at random - we can reduce the chance of landing in a poor local solution.
An EM Algorithm
Convex Clustering
Spectral Clustering
Hierarchical Clustering
Density Based Clustering
Validation of Cluster Results
Footnotes
You might ask whether it is fair to swap “distance from point A to point B” for “distance from A to the midpoint of A, B”. It turns out that, for Euclidean distance, this is basically fine. You might remember an alternate definition of variance we sometimes use for random variables: \(\E[(X - X')^2] = 2\sigma^2\) where \(X, X'\) are independent samples from a distribution with variance \(\sigma^2\). We’re applying essentially the same argument here.↩︎
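Concretely, for a cluster containing points \(\bx_1, \dots, \bx_m\) with centroid \(\bar{\bx} = \frac{1}{m}\sum_{i=1}^m \bx_i\), the within-cluster sum of squared distances to the centroid matches the pairwise squared distances up to a constant factor:
\[\sum_{i=1}^m \|\bx_i - \bar{\bx}\|_2^2 = \frac{1}{2m} \sum_{i=1}^m \sum_{j=1}^m \|\bx_i - \bx_j\|_2^2.\]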