K-means clustering is a method that partitions n data points into k clusters such that each data point belongs to the cluster with the nearest centroid. It’s a popular unsupervised machine learning algorithm that is widely used in various fields, including data analysis, image processing, and natural language processing.
To recap, the basic step-by-step algorithm for k-means clustering is listed below:
Place k centroids at random locations.
Assign all the data points to the closest centroid.
Compute the new centroids as the mean of all points in the cluster.
Compute the sum of squared errors between the new and old centroids.
Repeat the assignment and update steps until this error becomes zero.
Note: To read more about the k-means algorithm, check out this link.
We can implement this algorithm using Python by following a series of steps that are outlined below.
To start off, we load a basic two-dimensional dataset in the form of a CSV file via a pandas DataFrame. Next, we display the data points as a scatter plot, since that’s the preferred way to visualize two-dimensional data. The code to do this is shown below:
#importing necessary libraries
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import math

#loading the csv dataset with the read_csv function
df = pd.read_csv('/mydataset.csv', header=None)

#the x-dimension of the dataset
x = df[0]

#the y-dimension of the dataset
y = df[1]

#displaying each dimension of points with a scatter plot using matplotlib
plt.scatter(x, y, c='black', s=20)
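Note: If mydataset.csv isn’t available, you can generate a comparable two-dimensional dataset yourself. The sketch below (not part of the original lesson) builds three Gaussian blobs with arbitrary centers and spread and loads them into the same df, x, and y variables:

#hypothetical stand-in for mydataset.csv: three Gaussian blobs in two dimensions
rng = np.random.default_rng(42)
blob_centers = [(20, 20), (60, 70), (85, 30)]
blobs = np.vstack([rng.normal(loc=c, scale=5, size=(50, 2)) for c in blob_centers])
df = pd.DataFrame(blobs)
x = df[0]
y = df[1]
plt.scatter(x, y, c='black', s=20)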
From the scatter plot in the previous step, we can see that the data points form three distinct groups. Therefore, we choose the value of k to be three in this case.
Note: We can set the value of k to whatever best suits our problem.
Next, we generate initial centroids for the three clusters in our dataset via the random.uniform function:
#import the random module to generate random floating-point numbers
import random

#value of k for generating centroid points
k = 3

#take the max values of each dimension to generate centroid values
#within the datapoint range in the scatter plot
max_x = max(x)
max_y = max(y)

# X coordinates of random centroids
C_x = [random.uniform(0, max_x) for i in range(k)]

# Y coordinates of random centroids
C_y = [random.uniform(0, max_y) for i in range(k)]

#print the random centroid points in the form of a 2d array
C = np.array(list(zip(C_x, C_y)), dtype=np.float32)
print("Initial Centroids:")
print(C)

#displaying the centroids onto the scatter plot with blue stars for clarity
plt.scatter(x, y, c='black', s=20)
plt.scatter(C_x, C_y, marker='*', s=200, c='b')
Once the three random points are selected as centroids, we display them on the scatter plot we made in the previous step.
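As a side note, sampling each coordinate uniformly between 0 and its maximum can occasionally place a centroid far from every data point. A common alternative, sketched below under the assumption that x, y, and k are already defined, is to pick k existing data points as the initial centroids:

#alternative initialization (illustrative): pick k existing data points as centroids
import random

indices = random.sample(range(len(x)), k)  # k distinct row positions
C_alt = np.array([[x.iloc[i], y.iloc[i]] for i in indices], dtype=np.float32)
print("Alternative initial centroids:")
print(C_alt)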
Before we dive into the main logic of the solution to the k-means algorithm, we need to implement a few helper functions.
The first function we need is one that implements the Euclidean distance metric, which we use to find the distance between a centroid and a particular data point in the dataset. The Euclidean distance between two points $p = (p_1, p_2)$ and $q = (q_1, q_2)$ is:

$d(p, q) = \sqrt{(p_1 - q_1)^2 + (p_2 - q_2)^2}$
This is implemented as:
#implement the Euclidean distance formula called "euclidean"
def euclidean(x, y):
    a = np.subtract(x, y)
    ans = np.square(a)
    r1 = np.sum(ans)
    return np.sqrt(r1)
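As a quick sanity check (not part of the original lesson), the points (1, 2) and (4, 6) form a 3-4-5 right triangle, so their Euclidean distance should be 5:

#quick check of the euclidean helper on a known 3-4-5 triangle
p = np.array([1, 2])
q = np.array([4, 6])
print(euclidean(p, q))  # prints 5.0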
Next, we implement the function that assigns each data point to its nearest cluster by labeling it with the index of the closest centroid. A small worked example follows the code below.
The code that implements this function is shown below.
#function called "assign_members" that assigns labels to data pointsdef assign_members(points, centroids):c1 = [] # cluster 1 containing all points that belong to itc2 = [] # cluster 2 containing all points that belong to itc3 = [] # cluster 3 containing all points that belong to it#initalize the centroid points with respect to cluster 1, 2, 3X=pointscluster_labels=[]c1pt=centroids[0]c2pt=centroids[1]c3pt=centroids[2]#finde the Euclidean distance of the ith point with three of the cluster centroidsfor i in range(len(X)):dist1=euclidean(X[i],c1pt)dist2=euclidean(X[i],c2pt)dist3=euclidean(X[i],c3pt)#the cluster number which had the smallest distance, found by np.argmin, is appended in cluster_labels and that point is added in c1/c2/c3lab=np.argmin([dist1,dist2,dist3])#indices start from zero in an array so labels start from zero!if lab==0: #label 0 corresponding to cluster 1c1.append(X[i])cluster_labels.append(0)elif lab==1: #label 1 corresponding to cluster 2c2.append(X[i])cluster_labels.append(1)else: #label 2 corresponding to cluster 3c3.append(X[i])cluster_labels.append(2)return c1,c2,c3,cluster_labels
Next, we create an update_centroids function that takes the mean of the data points assigned to each cluster and returns a new centroid point for each cluster.
def update_centroids(cluster1, cluster2, cluster3):
    #take the mean of all of the points to get the new centroid for c1, c2, c3 clusters
    new_c1 = np.mean(cluster1, axis=0)
    new_c2 = np.mean(cluster2, axis=0)
    new_c3 = np.mean(cluster3, axis=0)
    return new_c1, new_c2, new_c3
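For example, with made-up cluster members, a cluster containing the points (1, 2) and (3, 4) gets the new centroid (2, 3), the element-wise mean of its members:

#illustrative example with hypothetical cluster members
cluster1 = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
cluster2 = [np.array([6.0, 6.0])]
cluster3 = [np.array([9.0, 8.0]), np.array([9.0, 10.0])]
print(update_centroids(cluster1, cluster2, cluster3))
# prints (array([2., 3.]), array([6., 6.]), array([9., 9.]))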
Finally comes the part where we compute the error between the new centroids and the old centroids we initialized in the previous step. This is done via the computeError function, which is shown below.
#compute the squared sum error between the old centroids and new centroids made after updating them
#first we take the difference of these points, then square the differences and add them up
def computeError(old_centers, new_centers):
    ans = np.subtract(old_centers, new_centers)
    return np.sum(np.square(ans))
The formula for computing the sum of squares error is:

$\text{SSE} = \sum_{i=1}^{n} (Y_i - Y_i^*)^2$

Where $Y_i$ is the $i$-th actual value and $Y_i^*$ is the corresponding predicted value.
Actual values (Y) | Predicted values (Y*)
120 | 124
16 | 13
200 | 206
In the table of sample data above, we plug every pair of actual values ($Y$) and predicted values ($Y^*$) into the formula: $(120 - 124)^2 + (16 - 13)^2 + (200 - 206)^2 = 16 + 9 + 36 = 61$.
We can deduce that the sum of squares error for the tabulated data is 61.
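We can verify this with the computeError function defined above by treating the actual values as the old centers and the predicted values as the new centers:

#verifying the tabulated sum of squares error with computeError
actual = np.array([120, 16, 200])
predicted = np.array([124, 13, 206])
print(computeError(actual, predicted))  # prints 61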
Lastly, we run multiple iterations of the algorithm in a loop, performing the steps explained above. We stop executing the algorithm when the sum of squares error (error) becomes zero. Run the code below to see the algorithm in action!
import numpy as np

def computeError(old_centers, new_centers):
    ans = np.subtract(old_centers, new_centers)
    return np.sum(np.square(ans))

def assign_members(points, centroids):
    c1 = []
    c2 = []
    c3 = []
    X = points
    cluster_labels = []
    c1pt = centroids[0]
    c2pt = centroids[1]
    c3pt = centroids[2]
    for i in range(len(X)):
        dist1 = euclidean(X[i], c1pt)
        dist2 = euclidean(X[i], c2pt)
        dist3 = euclidean(X[i], c3pt)
        lab = np.argmin([dist1, dist2, dist3])
        if lab == 0:
            c1.append(X[i])
            cluster_labels.append(0)
        elif lab == 1:
            c2.append(X[i])
            cluster_labels.append(1)
        else:
            c3.append(X[i])
            cluster_labels.append(2)
    return c1, c2, c3, cluster_labels

def update_centroids(cluster1, cluster2, cluster3):
    new_c1 = np.mean(cluster1, axis=0)
    new_c2 = np.mean(cluster2, axis=0)
    new_c3 = np.mean(cluster3, axis=0)
    return new_c1, new_c2, new_c3

#implementing the euclidean distance formula called "euclidean"
def euclidean(x, y):
    a = np.subtract(x, y)
    ans = np.square(a)
    r1 = np.sum(ans)
    return np.sqrt(r1)
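The snippet above bundles the helper functions, but the driving loop itself isn’t shown. A minimal sketch of that loop, assuming the dataset columns x and y, the value of k, and the initial centroids C from the earlier steps, could look like this (it produces the X, cluster_labels, and newcentroids variables used by the plotting code that follows):

#build a 2d array of data points from the x and y columns loaded earlier
X = np.array(list(zip(x, y)), dtype=np.float32)

#start with a nonzero error so the loop runs at least once
error = 1
while error != 0:
    #assign every point to its nearest centroid
    c1, c2, c3, cluster_labels = assign_members(X, C)

    #recompute each centroid as the mean of its cluster
    new_c1, new_c2, new_c3 = update_centroids(c1, c2, c3)
    newcentroids = np.array([new_c1, new_c2, new_c3], dtype=np.float32)

    #sum of squares error between the old and new centroids
    error = computeError(C, newcentroids)

    #the new centroids become the old centroids for the next iteration
    C = newcentroids

(If a randomly placed centroid ends up with no points, np.mean returns nan and the loop never terminates; re-running with a different initialization, as discussed further below, avoids this.)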
Once the algorithm stops executing, we display the scatter plot with the newly formed clusters and their updated centroids, as shown in the code below.
colors = ['r', 'g', 'b']
fig, ax = plt.subplots()

#plot clusters corresponding to each cluster label (0, 1, and 2) one by one in a loop
for i in range(k):
    #points corresponding to a cluster label are placed in the points array and plotted in each iteration
    points = np.array([X[j] for j in range(len(X)) if cluster_labels[j] == i])
    ax.scatter(points[:, 0], points[:, 1], s=7, c=colors[i])

#the new centroids are shown on the newly formed clusters
ax.scatter(newcentroids[:, 0], newcentroids[:, 1], marker='*', s=200, c='black')
A key takeaway from implementing this algorithm is that it’s sensitive to the initial positions of the centroids. Different initializations may result in different cluster assignments. Thus, it’s common to run the algorithm multiple times with various initial centroids and choose the best result, as sketched below:
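As an illustrative sketch (not part of the original lesson), the code below wraps one full run of the loop above in a hypothetical helper, run_kmeans_once, repeats it a few times, and keeps the run whose points sit closest to their assigned centroids. It assumes X, k, max_x, max_y, and the helper functions defined earlier:

#hypothetical helper: one full k-means run from a fresh random initialization
import random

def run_kmeans_once(X, k, max_x, max_y):
    C = np.array([[random.uniform(0, max_x), random.uniform(0, max_y)] for _ in range(k)],
                 dtype=np.float32)
    error = 1
    while error != 0:
        c1, c2, c3, cluster_labels = assign_members(X, C)
        newcentroids = np.array(update_centroids(c1, c2, c3), dtype=np.float32)
        error = computeError(C, newcentroids)
        C = newcentroids
    #total within-cluster distance: how tightly the points sit around their centroids
    total = sum(euclidean(X[i], C[cluster_labels[i]]) for i in range(len(X)))
    return C, cluster_labels, total

#run the algorithm several times and keep the tightest clustering
best = None
for run in range(5):
    C_run, labels_run, score = run_kmeans_once(X, k, max_x, max_y)
    if best is None or score < best[2]:
        best = (C_run, labels_run, score)
newcentroids, cluster_labels, _ = best

The score used here, the total distance of each point to its assigned centroid, is just one reasonable way to compare runs; libraries such as scikit-learn compare runs using the within-cluster sum of squares instead.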
In addition, k-means can converge to a local minimum (a situation where the algorithm has found a solution that is the best within a small, local region of the solution space but may not necessarily be the globally optimal solution), so it might not always find the best clustering solution. Using alternative methods like hierarchical clustering or DBSCAN can mitigate this issue.