Unsupervised Machine Learning with JavaScript
Older Article
This article was published 8 years ago. Some information may be outdated or no longer applicable.
In a previous article we looked at supervised machine learning for natural language processing on documents. This time we’re tackling an unsupervised technique: k-means clustering.
k-means
k-means (sometimes written K-Means) is a clustering algorithm: it groups data without being told what the groups should look like. (It shouldn’t be confused with k-NN, “k nearest neighbours”, which is a supervised classifier.) In classic supervised machine learning, you feed the system training data so it can learn patterns and predict future values.
k-means takes a different approach. It clusters data in four steps. First, decide how many clusters (the k in the name) you want. Let’s say 2 for now. (We’ll cover how to pick the right number later.)
Assume you’ve got a coordinate system with your data points plotted on it.
For each cluster, drop a random point onto the coordinate system. Measure the distance between every data point and these random points (called centroids). Assign each data point to its nearest centroid.
Now shift each centroid so it sits in the middle of its cluster.
Repeat: reassign each data point to its closest centroid. Keep going until the centroids stop moving and no data points need reassigning.
Once nothing changes, the algorithm’s finished. It’s created two clusters and classified the data without any instructions.
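The four steps above can be sketched in plain JavaScript. This is a naive toy implementation on made-up data, just to make the loop concrete; the real code later in the article uses the clusters npm package instead.

```javascript
// Toy data: two obvious groups.
const points = [
  [1, 1], [1.5, 2], [3, 4],
  [8, 8], [8.5, 9], [9, 11],
];

const distance = (a, b) => Math.hypot(a[0] - b[0], a[1] - b[1]);

function kMeans(data, k) {
  // Steps 1-2: pick k starting centroids (here: the first k points).
  let centroids = data.slice(0, k).map((p) => [...p]);
  let assignments = [];
  while (true) {
    // Step 3: assign each point to its nearest centroid.
    const next = data.map((p) =>
      centroids.reduce(
        (best, c, i) => (distance(p, c) < distance(p, centroids[best]) ? i : best),
        0
      )
    );
    // Stop when no point changes cluster.
    if (JSON.stringify(next) === JSON.stringify(assignments)) break;
    assignments = next;
    // Step 4: move each centroid to the mean of its points.
    centroids = centroids.map((c, i) => {
      const members = data.filter((_, j) => assignments[j] === i);
      if (members.length === 0) return c;
      return [
        members.reduce((s, p) => s + p[0], 0) / members.length,
        members.reduce((s, p) => s + p[1], 0) / members.length,
      ];
    });
  }
  return { centroids, assignments };
}

const { assignments } = kMeans(points, 2);
console.log(assignments); // first three points land in one cluster, last three in the other
```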
Unsupervised machine learning
With unsupervised machine learning, you hand the algorithm raw data and it figures out the classification on its own. Later, you can throw new data points at it and it’ll tell you where they belong based on what it’s already learnt.
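Placing a new data point is cheap once the centroids are known: it’s just a nearest-centroid lookup. The centroid values below are illustrative, not from a real run.

```javascript
// Illustrative centroids from a finished clustering run.
const centroids = [
  [3283, 2767.5],
  [6209.67, 5707.33],
];
const newPoint = [3500, 3000];

// Find the index of the closest centroid.
const nearest = centroids.reduce((best, c, i) => {
  const d = (p, q) => Math.hypot(p[0] - q[0], p[1] - q[1]);
  return d(newPoint, c) < d(newPoint, centroids[best]) ? i : best;
}, 0);
console.log(nearest); // 0 (the new point joins the first cluster)
```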
Here’s a simple example. Imagine you work at a bank as a data scientist. You’ve got access to monthly salaries and monthly credit card spending for a bunch of customers. You could sort these people into categories: those who earn less but spend more (watch out, they might struggle with repayments) versus those who earn a lot and spend a lot (target them with a marketing campaign).
We’ll use k-means clustering to create these labels.
How many clusters?
You can figure out the right number of clusters using the “elbow method”. Getting this right matters because k-means will obediently split data into however many clusters you tell it to, even if that number makes no sense. But there’s a way to find the sweet spot.
The idea: iterate through a range of cluster counts (k = [1, 10], for example) and calculate the sum of squared errors for each value of k. Plot those sums on a line chart and you’ll see something shaped like an arm. Look for the elbow, the point where the rate of change drops off.
Calculating the sum of squared errors is simple. Take some salary values: 1500, 1510, 1700, 1400, 1600, 1455. Calculate the mean: (1500 + 1510 + 1700 + 1400 + 1600 + 1455) / 6 = 1527.5. Then calculate how much each value deviates from the mean: 1500 - 1527.5 = -27.5, 1510 - 1527.5 = -17.5, etc. Square those deviations (squaring makes the sign irrelevant): (-27.5)², (-17.5)², etc. Add them all up. That’s your SSE (sum of squared errors).
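The same calculation takes a few lines of JavaScript:

```javascript
// Sum of squared errors for the example salaries.
const salaries = [1500, 1510, 1700, 1400, 1600, 1455];
const mean = salaries.reduce((sum, v) => sum + v, 0) / salaries.length;
const sse = salaries.reduce((sum, v) => sum + (v - mean) ** 2, 0);
console.log(mean); // 1527.5
console.log(sse); // 57587.5
```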
The chart above shows the elbow method applied to our dataset. The sweet spot lands at 4, meaning we need 4 clusters. In maths terms: k = 4.
The data
Here’s a sample document from our dataset:
(All data is auto-generated. Everyone’s fictitious.)
{
  "name": "Allison Stokes",
  "gender": "female",
  "email": "allisonstokes@zilladyne.com",
  "phone": "+1 (869) 597-2480",
  "address": "936 Berriman Street, Zarephath, Alaska, 2799",
  "salary": 4143,
  "creditCardSpend": 7193
}
We’ll pull the salary and creditCardSpend values and feed them to the clustering algorithm.
The application’s architecture
The JSON documents live in a NoSQL database, so we need a connector to retrieve them. Since we’re using MarkLogic, we’ll use the MarkLogic Node.js Client API to connect and query.
We’ll also run Express as a web server, create an API endpoint that returns the clustered data, and serve an index.html file where we plot the clusters on a chart.
For the chart itself, we’ll use Google’s scatter charts.
Use k-means from Node.js
Several npm packages implement the k-means algorithm. We’ll use clusters. It’s easy to work with and produces clear output.
The code
First, query the database and feed the results to the clustering algorithm. Remember, we’ve decided on 4 clusters. The /api endpoint handles the query, runs the algorithm, and returns the cluster data as its response.
// app.js - code snippet
// Setup (the connection details are placeholders - adjust for your environment):
const express = require('express');
const marklogic = require('marklogic');

const app = express();
const db = marklogic.createDatabaseClient({
  host: 'localhost',
  port: 8000,
  user: 'admin',
  password: 'admin',
});
const qb = marklogic.queryBuilder;

const clusterMaker = require('clusters');
clusterMaker.k(4); // the elbow method gave us 4 clusters
clusterMaker.iterations(1000);

app.get('/api', (req, res) => {
  db.documents
    .query(
      qb.where(qb.directory('/client/')).slice(0, 500) // take 500 documents as samples
    )
    .result()
    .then((documents) => {
      // Keep only the two dimensions we want to cluster on.
      const response = documents.map((document) => {
        return [document.content.salary, document.content.creditCardSpend];
      });
      clusterMaker.data(response);
      const clusters = clusterMaker.clusters();
      res.json(clusters);
    })
    .catch((error) => res.status(500).send(error.message));
});
The clusters variable’s structure is easy to read:
[
  {
    centroid: [3283, 2767.5],
    points: [
      [3107, 2563],
      [3154, 2453],
      [3043, 2179],
      [3828, 3875],
    ],
  },
  {
    centroid: [4765.5, 2444],
    points: [
      [4651, 2471],
      [4880, 2417],
    ],
  },
  { centroid: [6579, 2079], points: [[6579, 2079]] },
  {
    centroid: [6209.666666666667, 5707.333333333333],
    points: [
      [6644, 6402],
      [6083, 5238],
      [5902, 5482],
    ],
  },
];
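As an aside, this structure is also all you need to produce the elbow chart from earlier: summing each point’s squared distance to its centroid gives the SSE for one value of k. A small sketch, assuming the { centroid, points } shape shown above:

```javascript
// Total within-cluster sum of squared errors for an array of
// { centroid, points } objects like the one above.
function totalSSE(clusters) {
  return clusters.reduce((total, cluster) => {
    const clusterError = cluster.points.reduce((sum, point) => {
      const dx = point[0] - cluster.centroid[0];
      const dy = point[1] - cluster.centroid[1];
      return sum + dx * dx + dy * dy;
    }, 0);
    return total + clusterError;
  }, 0);
}

// A single-point cluster sits exactly on its centroid, so it contributes 0.
console.log(totalSSE([{ centroid: [6579, 2079], points: [[6579, 2079]] }])); // 0
```

Rerun the clustering for each k in your range, call totalSSE on each result, and plot the values to find the elbow.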
Four clusters, four centroids, and a list of points belonging to each cluster. Now we take this data and plot it on a scatter chart:
$.get('/api', (apiData) => {
  google.charts.load('current', { packages: ['corechart'] });
  google.charts.setOnLoadCallback(drawChart);

  function drawChart() {
    const chartData = [];
    chartData.push([
      'Salary (£ pcm)',
      'Credit Card Spend (£ pcm)',
      { type: 'string', role: 'style' },
    ]);

    apiData.forEach((elements, index) => {
      let colour;
      if (index === 0) {
        colour = 'red';
      } else if (index === 1) {
        colour = 'green';
      } else if (index === 2) {
        colour = 'orange';
      } else {
        colour = 'blue';
      }

      // Draw the centroid as a larger dot in the cluster's colour.
      elements.centroid.push(
        `point { size: 8; shape-type: circle; fill-color: ${colour} }`
      );
      chartData.push(elements.centroid);

      // Draw the cluster's points in the same colour.
      elements.points.forEach((point) => {
        point.push(`point { fill-color: ${colour} }`);
        chartData.push(point);
      });
    });

    const data = google.visualization.arrayToDataTable(chartData);
    const options = {
      title: 'Monthly Salary vs Monthly Credit Card Spend',
      hAxis: {
        title: 'Salary (£ pcm)',
        minValue: 3000,
        viewWindow: { min: 2800 },
      },
      vAxis: { title: 'Credit Card Spend (£ pcm)' },
      legend: 'none',
      pointSize: 2,
    };
    const chart = new google.visualization.ScatterChart(
      document.getElementById('chart_div')
    );
    chart.draw(data, options);
  }
});
And here’s the finished product:
Four clusters on the scatter chart, with the larger dots marking the centroids.
From here, you’d act on the clustered data. People in the blue cluster could be targeted with direct marketing campaigns to nudge a purchase. People in the green cluster might need a gentle warning about overspending.