Recently, I conducted a session on Python where I walked through implementing a kNN classifier. Close to the end of the session, we got to how succinct Python can be, and I proceeded to reduce our code to the absolute minimum number of lines possible. The impromptu code-golfing exercise led me to an interesting realization - you can write a kNN classifier in one line of Python. A line short enough (126 characters) to fit into a tweet!
Yep, I checked:
Fig 1. A tweet-sized classifier.
In retrospect, this does not seem surprising given certain features of the language. It's just one of those things you don't normally think about; you wouldn't want to implement an algorithm in one, potentially confusing, line. Frankly, you shouldn't - in the interest of the sanity of the poor souls who might need to read your code someday.
But notwithstanding coding etiquette, we are going to take a look at this since a one-liner kNN is, well, pretty awesome!
In the next couple of sections I provide some context to the problem and the solution. I have made these sections collapsible (this doesn't seem to work with the mobile version of the site), since they are not the focus of the post. If you already know what a kNN classifier is, and understand Python maps/lambdas, skip ahead to the section "Initial setup".
What is a kNN classifier?
A classifier looks at a data point provided to it and declares it to belong to a particular class. A common example of this is a spam classifier, which scans through the text of a mail and declares it to belong to one of two classes: spam or no-spam. If you use Gmail, you have already seen a very powerful spam classifier in action - using a classifier is how Gmail knows what goes into your spam folder.
A classifier is initially presented with data points along with their correct labels - so that it can identify associative patterns between the data points and the labels. This dataset is known as training data. For ex, in the case of spam, a classifier might note that words like free, lottery etc. strongly correlate with a mail being spam. Once the training phase is done, the classifier is ready to tag unseen instances, known as test data points. Beyond the training phase, the classifier may not need to keep the training dataset around - instead, it can rely solely on the patterns it has learned.
The kNN classifier, however, is a lazy classifier - it does not bother with learning associative patterns. It keeps all of the training dataset around (at least, the simplest avatar of the classifier does so), and for a test point, finds points similar to it in the training dataset, and outputs the dominant label among the similar points. This is where it derives its name from - its classification is based on the labels of the k nearest neighbours of the test point.
I can almost describe the philosophy of the kNN classifier with this pithy quote:
You are the company you keep.
In machine learning, a data point is typically described as a vector of features (also known as attributes or dimensions). For ex, to describe a mail, the features of interest might be the presence or absence of certain words; these could be my features:
- has the word free
- has the word lottery
- has the word weekend
And a mail described by the vector [1, 1, 0] would have the words free and lottery in it - indicated by the 1s - but would not have the word weekend, as indicated by the 0.
The number of features is termed the dimensionality of the data. In the above toy example, we were dealing with 3-dimensional vectors.
Given that every data point is depicted by an n-dimensional vector, we can plot the point in an n-dimensional coordinate system. Yes, we can't see stuff beyond 3 dimensions, but in theory, we can think of a data point being plotted so. This helps us visualize what a kNN classifier does. In the following figure the coloured points belong to the training dataset - some have the label red, some green. We want to find the label for the test point, shown with a "?". If we assume k = 3 (smaller circle), the dominant label is red (red to green = 2:1), and that is what we declare. If k = 5 (larger circle), green dominates (red to green = 2:3) and that is what we declare as the result.
Fig 2. The kNN classifier
The above figure helps make a significant point - the accuracy of a kNN classifier depends on k. We won't worry about the right value of k in this post - we'll assume k is provided to us.
Note that, in the figure above, we implicitly assumed that the neighbours of a point are defined by the Euclidean distance - the distance one gets between two points when measuring with a ruler. But a different way of measuring distances/similarities may make sense for a problem.
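To make this concrete, the Euclidean distance itself fits in a lambda - a quick sketch for illustration (note that it is a distance, not a similarity: smaller means more alike, so a kNN built on it would keep the k smallest values rather than the k largest):

```python
# Euclidean distance between two equal-length vectors
euclidean = lambda x, y: sum((a - b) ** 2 for a, b in zip(x, y)) ** 0.5

print(euclidean([0, 0], [3, 4]))  # 5.0
```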
So what does the algorithm for kNN classification look like? A naive implementation would have the following steps (see the sketch right after this list):
- Calculate the similarity of the test point with each of the training points
- Get the labels of the k most similar points
- Declare the label of the test point to be the majority vote label
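Before golfing anything down, here is a minimal multi-line sketch of those three steps - the function name knn_classify is mine, and it assumes the train/sim/k/test conventions fixed in the "Initial setup" section below:

```python
from collections import Counter

def knn_classify(test, train, sim, k):
    # Step 1: similarity of the test point with each training point,
    # carrying the training label along
    scored = [(sim(vec, test), label) for vec, label in train]
    # Step 2: labels of the k most similar points
    top_k_labels = [label for _, label in sorted(scored)[-k:]]
    # Step 3: the majority vote label among those k
    return Counter(top_k_labels).most_common(1)[0][0]
```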
The one-liner I am interested in must implement the above steps.
A quick primer on maps and lambdas
map() is one of the key constructs that make the one-liner possible. This is what a simple map() call looks like:
map(function, iterable)
The map() call takes two arguments - a function as the first argument, and an iterable as its second argument. An iterable is anything that can be iterated on, like a list; for our purposes we will exclusively consider lists. As the name might suggest, map() creates a mapping of all items in the provided list using the function passed in as its first argument - the mapped versions of the elements are returned as a new list. (That is Python 2 behaviour; in Python 3, map() returns a lazy iterator instead.)
As an example, consider the list of integers a = [4, 1, 2]. Consider the function call: map(str, a). Here, I am passing the built-in function str - which converts its argument into a string - as the first argument. This returns a new list ['4', '1', '2'], which consists of mapped versions of the elements in the old list - string representations of the integers. Note that we pass in the function name as the first argument - str - and not str(). The latter would pass in the result of the function call (a blank string in this case), which is not callable, and the map() call would fail.
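As a runnable snippet (Python 2 shown; under Python 3, wrap the map() call in list() to see the values):

```python
a = [4, 1, 2]
strings = map(str, a)  # each integer mapped to its string form
print(strings)         # ['4', '1', '2']
```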
While map() is a handy construct in itself, it seems that the first argument to it - the mapping function - must be defined elsewhere. For ex, in the above example, we used str(), which is defined in the standard libraries. Fortunately, it turns out that this is not the only possibility, thanks to lambda expressions.
A simple lambda expression looks like this:
lambda [parameter list]: function body
Think of a lambda as a function that is defined on the fly. The parameter list specifies the parameters that go into the function, and the result of executing the function body (specified on the same line) is returned as the result of the expression. Again, for a detailed explanation, consult the docs.
For ex, the line lambda x: x*x defines a function that takes one argument x, and returns its product with itself. Note that the lambda does not have a function name (unless you assign the lambda expression to a variable). Thus, it's a "use-and-throw" function - it exists where it is defined, and since it has no name, it cannot be called later. The reason why it's so valuable is that it can be used as a valid function, including in calls to map().
Consider, again, the list of integers a = [4, 1, 2]. And consider this map call: map(lambda x: x*x, a). This returns [16, 1, 4] since the lambda expression serves as the mapping function.
There are a bunch of functions in Python, other than map(), that accept functions as arguments, and using lambdas with these can often help you write elegant and concise code. For ex, the function filter() has a function signature similar to map():
filter(function, iterable)
The output of filter() is a list that contains only those values from the original list (its second argument) for which the supplied function (the first argument) returns a truthy value. For ex, filter(lambda x: x%2==0, a) would return [4, 2] - the elements of the original list that are divisible by 2 - since the lambda expression returns True for even integers. We will meet yet another such function, max(), while analysing the one-liner.
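Here are the two lambda-powered calls from this primer side by side (again with Python 2's list semantics; use list(...) around the calls under Python 3):

```python
a = [4, 1, 2]
print(map(lambda x: x * x, a))          # [16, 1, 4] - the lambda is the mapping function
print(filter(lambda x: x % 2 == 0, a))  # [4, 2]     - the lambda is the predicate
```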
Initial setup
I will assume that these variables are defined for us:
- train - is a list with the training data. Every element in this list is a tuple with the first entry as the feature vector and the second entry as the label. So, if I had the following 4-dimensional data-points in my training data:
- [2, 4, 3, 5], label = 0
- [5, 7, 4, 7], label = 1
- [6, 8, 2, 3], label = 1
- [5, 7, 2, 4], label = 0
then train = [([2, 4, 3, 5], 0), ([5, 7, 4, 7], 1), ([6, 8, 2, 3], 1), ([5, 7, 2, 4], 0)].
- sim(x, y) - the similarity function. Takes two data-points as arguments and returns a numeric value for similarity. For our example, we use the dot-product as a measure of similarity. The dot-product of two vectors [a, b, c] and [p, q, r] is a*p + b*q + c*r.
- k - the number of nearest neighbours we need the labels from. For our example, we will assume k = 3.
- test - the test data point, which we are interested in classifying. This is a list of the feature values. We use [1, 2, 3, 4] as an example test point.
- I am also assuming that the Python collections module is imported into the current namespace, i.e. that the statement "from collections import *" has been executed (I know this sounds like cheating, but hold off on judging me till you have read the next section ...)
Note that 1-4 are external to the classifier and have to be provided anyway - I am just fixing the variable names and structures.
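For concreteness, here is one way to set all of this up - a sketch, with sim() spelled out as the dot-product described in point 2:

```python
from collections import *  # brings Counter into the current namespace

# Training data: (feature vector, label) tuples
train = [([2, 4, 3, 5], 0), ([5, 7, 4, 7], 1), ([6, 8, 2, 3], 1), ([5, 7, 2, 4], 0)]

# Similarity: the dot-product of two equal-length vectors
def sim(x, y):
    return sum(a * b for a, b in zip(x, y))

k = 3                # number of nearest neighbours to vote
test = [1, 2, 3, 4]  # the point we want to classify
```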
The One (line ...)
And finally, [drum roll] the magic incantation ...
max(Counter(map(lambda x: x[1], sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:])).items(), key = lambda x:x[1])[0]
If I were to decompose the above line into logical steps, this is how I would go about it:
- Step 1: map(lambda x: (sim(x[0], test), x[1]), train)
- Step 2: sorted( Step 1 )[-k:]
- Step 3: map(lambda x: x[1], Step 2 )
- Step 4: Counter( Step 3 ).items()
- Step 5: max( Step 4 , key = lambda x:x[1])[0]
Explanation
Step 1: Get all similarities, keep the labels around
map(lambda x: (sim(x[0], test), x[1]), train)
In the first step, we call map() on the training data. The lambda expression returns a tuple with the first entry sim(x[0], test) and the second entry x[1]. Consider the fact that train is a list of tuples itself, and thus the argument to the lambda is a tuple. For ex, the first time the lambda function is called, its argument is ([2, 4, 3, 5], 0). Thus, for the lambda, x[0] is [2, 4, 3, 5] and x[1] is 0. Hence, the first entry of the returned tuple is the similarity of the training data point to the test data point, while the second entry is the label of the training data point.
Given the sample train and sim, this first step produces the following list:
[(39, 0), (59, 1), (40, 1), (41, 0)]
As a side-note, it is quite interesting that the dot-product can itself be written as a lambda using the zip() function:
lambda list_1, list_2: sum(map(lambda x: x[0]*x[1], zip(list_1, list_2)))
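To see it at work on the first training point:

```python
dot = lambda list_1, list_2: sum(map(lambda x: x[0] * x[1], zip(list_1, list_2)))
print(dot([2, 4, 3, 5], [1, 2, 3, 4]))  # 39 - matching the first similarity above
```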
Step 2: Find the top k points
sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:]
We now sort the output of the last step using the sorted() function. Since we are sorting tuples, and we have not mentioned a key, the sort happens on the first entry of each tuple - the similarity to the test point. By default, the sort is in increasing order, so the output so far looks like this:
[(39, 0), (40, 1), (41, 0), (59, 1)]
You can sort in decreasing order too - by passing in the argument reverse=True - but that's more keystrokes :)
Now, we extract the tuples with the highest k similarities using [-k:]. The "-" in Python's list indexing notation implies the counting is done from the tail of the list - here, we get the portion of the list starting with the kth element from the tail and running till the last element. This gives us the tuples with the k highest similarities. Thus, this is the output of this step:
[(40, 1), (41, 0), (59, 1)]
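In isolation, the sort-and-slice of this step looks like this:

```python
scored = [(39, 0), (59, 1), (40, 1), (41, 0)]  # output of Step 1
print(sorted(scored)[-3:])  # [(40, 1), (41, 0), (59, 1)] - the k = 3 most similar
```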
Step 3: Get the labels for these top k points
map(lambda x: x[1], sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:])
Since we have the tuples corresponding to the top-k similar points, we really don't need to keep the similarities around any more. Another map() call does the trick - look at the lambda expression here: it only returns the value at index 1 of each tuple. Therefore, as output from this step, we have just the labels:
[1, 0, 1]
Step 4: Get label frequencies
Counter(map(lambda x: x[1], sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:])).items()
We now need to find the label with the highest frequency of occurrence among the k closest points. We use the Counter data structure from the collections module. A Counter is a special kind of dict which, when initialized with a list, uses the unique elements in the list as keys and their frequencies of occurrence as the corresponding values. Thus, Counter([1,1,2,3,2,2,1]) gives me the dict {1: 3, 2: 3, 3: 1}. This is exactly what we need. Initializing a Counter with the output of the last step - [1, 0, 1] - gives us this dict:
{1: 2, 0: 1}
Calling items() on a dict returns a list of tuples, each of whose first entry is a key and whose second entry is the corresponding value. Hence, the output of this step is (the order of the tuples is not guaranteed, but that does not matter for the next step):
[(0, 1), (1, 2)]
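The Step 4 machinery in isolation:

```python
from collections import Counter

labels = [1, 0, 1]        # output of Step 3
counts = Counter(labels)  # Counter({1: 2, 0: 1})
print(counts.items())     # [(0, 1), (1, 2)] in Python 2; a dict view in Python 3
```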
Time to address the elephant in the room, I guess :) You could argue that using something from a module does not make for a strict one-liner. And in a sense that is true. I could have avoided using Counter, and instead used tricks like these. Technically, we would have still ended up with a one-liner, but Counter expresses our objective in a clear manner. Thus, not using Counter does not mean that you cannot have a one-liner, only that the result won't be pretty.
Note that since I am calling Counter directly, my import should be "from collections import *". If you want a cleaner namespace, you might want to do "import collections" and refer to Counter as "collections.Counter".
Really close now ...
Step 5: Get the most common label
max(Counter(map(lambda x: x[1], sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:])).items(), key = lambda x:x[1])[0]
All that is left now is to extract the label with the highest frequency. That's easy with Python's max() function. Note that our tuples (output from the last step) have the label as the first entry and the frequency of the label as the second entry. Hence, in the max() call, we provide the key argument, which makes max() pick the element with the largest second entry. Yes, with a lambda expression - they really do get around. Thus, max() returns me this:
(1, 2)
Of which, I extract the label by specifying the index [0]. This gives me my final answer:
1
And this ... is the label we are looking for [melodramatic bow] :)
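Putting it all together as a runnable script, using the setup fixed earlier (with sim written inline as a lambda this time):

```python
from collections import *

train = [([2, 4, 3, 5], 0), ([5, 7, 4, 7], 1), ([6, 8, 2, 3], 1), ([5, 7, 2, 4], 0)]
sim = lambda x, y: sum(p * q for p, q in zip(x, y))  # dot-product similarity
k, test = 3, [1, 2, 3, 4]

# The one-liner
print(max(Counter(map(lambda x: x[1], sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:])).items(), key = lambda x:x[1])[0])  # 1
```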
Regression and a related hack
While I was writing this post, I realized that the one-liner for kNN-based regression is actually shorter - 92 characters. Also purer - it does not rely on modules that need importing. In a regression problem we are interested in predicting continuous values, as against discrete labels in classification problems. For ex, you may want to determine the price of a piece of real estate based on features like proximity to a hospital, distance from an airport, proximity to industries etc. The training data has prices per data point instead of labels. In kNN-based regression, you would look at the values of the k neighbours and return their average.
Here goes:
sum(map(lambda x: x[1], sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:]))/(1.0*k)
If you understand the classifier, understanding this should be easy. The argument to sum() - extracting the values of the k most similar points - is the same as Steps 1-3 of the classifier; the sum( ... )/(1.0*k) wrapper, which averages those values, is the part specific to regression. Note that I need to multiply the denominator by 1.0 to make sure Python does not perform an integer round-off in the division (this matters in Python 2; in Python 3, / on integers already produces a float).
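With the same setup as before (the 0/1 labels now read as values to average), the regression one-liner evaluates like so:

```python
train = [([2, 4, 3, 5], 0), ([5, 7, 4, 7], 1), ([6, 8, 2, 3], 1), ([5, 7, 2, 4], 0)]
sim = lambda x, y: sum(p * q for p, q in zip(x, y))
k, test = 3, [1, 2, 3, 4]

# Average of the values of the k most similar points
print(sum(map(lambda x: x[1], sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:]))/(1.0*k))  # 0.666... - mean of [1, 0, 1]
```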
Now, here is an interesting relationship between kNN regression and classification - if you have only two labels, 0 and 1, then averaging the labels and rounding off the result is equivalent to majority voting. For example, consider some sample outputs from Step 3 above:
| Top k labels | Avg of labels | Avg rounded | Majority vote |
| --- | --- | --- | --- |
| 1, 0, 1 | 0.67 | 1 | 1 |
| 0, 0, 1, 0 | 0.25 | 0 | 0 |
| 1, 1, 0, 0, 1 | 0.6 | 1 | 1 |
This does not work with more than two labels - with labels like 0, 1 and 2, the average can round off to a label that never occurs among the neighbours.
Thus, as long as the labels in the classification problem are the numbers 0 and 1, we can use the following expression:
int(round(sum(map(lambda x: x[1], sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:]))/(1.0*k)))
The inner portion is the expression for regression we saw earlier; the int(round( ... )) wrapper is what classification needs. round() is a built-in in Python which does not need us to import anything (yay!); however, since it returns a float (in Python 2, at least), we explicitly need to cast to int (bummer ...).
This version of the one-line classifier is 104 characters long.
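And a quick check that the hack agrees with the full classifier on our running example:

```python
train = [([2, 4, 3, 5], 0), ([5, 7, 4, 7], 1), ([6, 8, 2, 3], 1), ([5, 7, 2, 4], 0)]
sim = lambda x, y: sum(p * q for p, q in zip(x, y))
k, test = 3, [1, 2, 3, 4]

# Average the 0/1 labels, round off to the majority label, cast back to int
print(int(round(sum(map(lambda x: x[1], sorted(map(lambda x: (sim(x[0], test), x[1]), train))[-k:]))/(1.0*k))))  # 1
```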
Conclusions
A final tally of character counts for the various one-liners. Additionally, mentioned in parentheses, are the sizes we would see if we weren't worried about readability, i.e. one-letter variable names and no spaces:
- kNN classification - 126 characters (107)
- kNN regression - 92 characters (77)
- kNN classification hack - 104 characters (89)
Pretty impressive numbers - my respect for Python just went up a notch!
I am pretty sure these are not the only possible implementations of the one-liners. For starters, you might favour list comprehensions over map(), or do away with the Counter to create a 'purer' one-liner. If you have something interesting, let me know in a comment!
Although great as an intellectual challenge, I don't think the one-liner is suitable for serious use. Good classification libraries may have optimized implementations of kNN - for ex, using kd-trees or ball trees - which are preferable. For example, this library.