K-Means Clustering
Table Of Contents
- Imports
- Create Fake-Data Generator
- Generate Fake Data
- Use K-Means Clustering to discover centroids
- Visualize The Data

# K-Means Clustering Example
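- unsupervised learning: no labels are needed; the algorithm finds the groups on its own
- splits the data into "K" groups (clusters), where K is chosen in advance
- each group is centered on a "centroid", the mean of the points assigned to it; every point belongs to the group whose centroid is nearest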
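To make the bullets above concrete, here is a minimal sketch of the classic k-means loop (Lloyd's algorithm), written with plain NumPy. It alternates an assignment step (each point joins its nearest centroid) and an update step (each centroid moves to the mean of its points) until the centroids stop moving, which reduces the within-cluster sum of squared distances. The function names `kmeans` and `nearest_centroid` and the `n_iter`/`seed` parameters are illustrative, not part of any library; scikit-learn's `KMeans`, used below, adds k-means++ initialisation and other refinements this sketch omits.

```python
import numpy as np

def nearest_centroid(X, centroids):
    """Index of the closest centroid (Euclidean distance) for every row of X."""
    # Broadcast to an (n_points, k) matrix of distances
    dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    return dists.argmin(axis=1)

def kmeans(X, k, n_iter=100, seed=0):
    """Minimal Lloyd's-algorithm k-means; returns (centroids, labels)."""
    rng = np.random.default_rng(seed)
    # Initialise with k distinct data points chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assignment step: each point joins its nearest centroid
        labels = nearest_centroid(X, centroids)
        # Update step: move each centroid to the mean of its points
        # (an empty cluster keeps its previous centroid)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break  # converged: the centroids stopped moving
        centroids = new_centroids
    # Final labels are computed against the final centroids
    return centroids, nearest_centroid(X, centroids)
```

Cluster numbering is arbitrary, so the labels from this sketch may be a permutation of the ones scikit-learn assigns to the same data.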
## Imports

```python
%matplotlib inline
from numpy import random, array, single, version  # single = NumPy's 32-bit float type
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from sklearn.preprocessing import scale
print(f'numpy version {version.version}')
```
## Create Fake-Data Generator

```python
# Create fake income/age clusters for N people in k clusters
smallestIncome = 20000
largestIncome = 200000
youngest = 20
oldest = 70

def createClusteredData(N, k):
    random.seed(10)  # fixed seed so every run produces the same data
    pointsPerCluster = single(N) / k
    X = []
    for i in range(k):
        # Pick a random income/age centre for this cluster
        incomeCentroid = random.uniform(smallestIncome, largestIncome)
        ageCentroid = random.uniform(youngest, oldest)
        # Scatter points around that centre with Gaussian noise
        for j in range(int(pointsPerCluster)):
            X.append([random.normal(incomeCentroid, 10000.0),
                      random.normal(ageCentroid, 2.0)])
    X = array(X)
    return X
```
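## Generate Fake Data

```python
# 100 people spread across 5 clusters (20 people per cluster)
HOW_MANY_POINTS = 100
HOW_MANY_CLUSTERS = 5
data = createClusteredData(HOW_MANY_POINTS, HOW_MANY_CLUSTERS)
```

## Use K-Means Clustering to discover centroids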
In practice the number of clusters is usually unknown, so we try several guesses for K (5, 4, 3 and 2) and compare the results.
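Judging each guess by eye works for two-dimensional data, but one common numeric check is the silhouette score, which rates how close each point is to its own cluster compared with the nearest other cluster (higher is better, maximum 1). The sketch below, assuming scikit-learn's `silhouette_score` and a hypothetical helper name `best_k_by_silhouette`, fits `KMeans` on standardized data for a range of K and returns the K with the highest score. `n_init=10` and `random_state=0` are choices made here for repeatability, not values from the original notebook.

```python
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import scale

def best_k_by_silhouette(data, candidate_ks=range(2, 8)):
    """Fit KMeans for each candidate K; return (best_k, {k: score})."""
    scaled = scale(data)  # standardize columns, as in the notebook
    scores = {}
    for k in candidate_ks:
        labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(scaled)
        # Mean silhouette in [-1, 1]; higher = tighter, better-separated clusters
        scores[k] = silhouette_score(scaled, labels)
    best_k = max(scores, key=scores.get)
    return best_k, scores
```

For the fake data above this would be called as `best_k_by_silhouette(data)`. The silhouette score needs at least two clusters, which is why the candidate range starts at 2.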
```python
model5k = KMeans(n_clusters=HOW_MANY_CLUSTERS)
model4k = KMeans(n_clusters=4)
model3k = KMeans(n_clusters=3)
model2k = KMeans(n_clusters=2)

# Standardize each column (zero mean, unit variance) so that income,
# measured in tens of thousands, does not dominate age in the distance metric
scaledData = scale(data)
fittedModel5k = model5k.fit(scaledData)
fittedModel4k = model4k.fit(scaledData)
fittedModel3k = model3k.fit(scaledData)
fittedModel2k = model2k.fit(scaledData)

# We can look at the cluster each data point was assigned to
print(f'labels: {fittedModel5k.labels_}')
```
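Because the models were fitted on standardized data, their `cluster_centers_` are in standardized units rather than dollars and years, and `scale()` returns only the transformed array, discarding the mean and standard deviation needed to undo it. A minimal sketch, assuming scikit-learn's `StandardScaler` (which applies the same standardization as `scale()` but remembers its parameters) and a hypothetical helper name `centroids_in_original_units`, maps the discovered centroids back to income/age:

```python
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

def centroids_in_original_units(data, n_clusters):
    """Fit KMeans on standardized data; return centroids as (income, age) rows."""
    scaler = StandardScaler()            # same transform as scale(), but keeps mean/std
    scaled = scaler.fit_transform(data)
    model = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(scaled)
    # cluster_centers_ live in the scaled space; undo the standardization
    return scaler.inverse_transform(model.cluster_centers_)
```

With the notebook's variables this would be `centroids_in_original_units(data, HOW_MANY_CLUSTERS)`. Since standardization is an affine transform, each returned row equals the mean income and age of the points in that cluster.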
## Visualize The Data

```python
modelLabels5k = fittedModel5k.labels_.astype(single)
modelLabels4k = fittedModel4k.labels_.astype(single)
modelLabels3k = fittedModel3k.labels_.astype(single)
# Create a figure with 1 row and 3 columns of subplots
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
# Plot the original (unscaled) income vs. age, colored by assigned cluster
for ax, labels, k in [(ax1, modelLabels5k, 5), (ax2, modelLabels4k, 4), (ax3, modelLabels3k, 3)]:
    ax.scatter(data[:, 0], data[:, 1], c=labels)
    ax.set_title(f'K = {k}')
    ax.set_xlabel('income')
ax1.set_ylabel('age')
plt.show()
```

Page Tags:
python
data-science
jupyter
learning
numpy