
Finding fuzzy floats

Sunday 9 July 2017

For a 2D geometry project I needed to find things based on 2D points. Conceptually, I wanted to have a dict that used pairs of floats as the keys. This would work, except that floats are inexact, and so have difficulty with equality checking. The "same" point might be computed in two different ways, giving slightly different values, but I want them to match each other in this dictionary.
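
As a tiny illustration of that mismatch (not from my project, just the classic textbook case), here is one number reached by two routes. The names `a`, `b`, and `d` are only for this sketch:

```python
import math

# Two routes to what is "meant" to be the same number.
a = 0.1 + 0.2
b = 0.3

d = {b: "hello"}

exact_match = (a == b)            # False: a is 0.30000000000000004
dict_hit = d.get(a)               # None: the dict only matches the exact value
fuzzy_match = math.isclose(a, b)  # True with isclose's default tolerance
```

The dict lookup fails even though a closeness test says the two values are the same point.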

I found a solution, though I'm surprised I didn't find it described elsewhere.

The challenges have nothing to do with the two-dimensional aspect, and everything to do with using floats, so for the purposes of explaining, I'll simplify the problem to one-dimensional points. I'll have a class with a single float in it to represent my points.

First let's look at what happens with a simple dict:

>>> from collections import namedtuple
>>> Pt = namedtuple("Pt", "x")
>>> d = {}
>>> d[Pt(1.0)] = "hello"
>>> d[Pt(1.0)]
'hello'
>>> d[Pt(1.000000000001)]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
KeyError: Pt(x=1.000000000001)

As long as our floats are precisely equal, the dict works fine. But as soon as we get two floats that are meant to represent the same value, but are actually slightly unequal, then the dict is useless to us. So a simple dict is no good.

To get the program running at all, I used a dead-simple structure that I knew would work: a list of key/value pairs. By defining __eq__ on my Pt class to compare with some fuzz, I could find matches that were slightly unequal:

>>> import math
>>> class Pt(namedtuple("Pt", "x")):
...     def __eq__(self, other):
...         return math.isclose(self.x, other.x)
>>> def get_match(pairs, pt):
...     for pt2, val in pairs:
...         if pt2 == pt:
...             return val
...     return None
>>> pairs = [
...     (Pt(1.0), "hello"),
...     (Pt(2.0), "goodbye"),
... ]
>>> get_match(pairs, Pt(2.0))
'goodbye'
>>> get_match(pairs, Pt(2.000000000001))
'goodbye'

This works, because now we are using an inexact closeness test to find the match. But we have an O(n) algorithm, which isn't great. Also, there's no way to define __hash__ to match that __eq__ test, so our points are no longer hashable.
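
To see why no `__hash__` can match a closeness-based `__eq__`, it helps to notice that closeness isn't transitive. This is a small sketch using `math.isclose` with its default relative tolerance of 1e-9; the three values are picked by hand for the demonstration:

```python
import math

a = 1.0
b = a + 0.9e-9   # within tolerance of a
c = a + 1.8e-9   # within tolerance of b, but not of a

a_eq_b = math.isclose(a, b)   # True
b_eq_c = math.isclose(b, c)   # True
a_eq_c = math.isclose(a, c)   # False
```

Equal objects must have equal hashes, so `a` and `b` would need the same hash, and so would `b` and `c`. Following that chain along the number line, every float would need the same hash, which defeats the point of hashing.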

Trying to make things near each other be equal naturally brings rounding to mind. Maybe that could work? Let's define __eq__ and __hash__ based on rounding the value:

>>> class Pt(namedtuple("Pt", "x")):
...     def __eq__(self, other):
...         return round(self.x, 6) == round(other.x, 6)
...     def __hash__(self):
...         return hash(round(self.x, 6))
>>> d = {}
>>> d[Pt(1.0)] = "apple"
>>> d[Pt(1.0)]
'apple'
>>> d[Pt(1.00000001)]
'apple'

Nice! We have matches based on values being close enough, and we can use a dict to get O(1) behavior. But:

>>> d[Pt(1.000000499999)] = "cat"
>>> d[Pt(1.000000500001)]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
KeyError: Pt(x=1.000000500001)

Because rounding has sharp edges (ironically!), there will always be pairs of numbers arbitrarily close together that will round away from each other.

So rounding was close, but not good enough. We can't guarantee that we won't have these pairs of numbers that are in just the wrong place so that rounding will give the wrong answer.
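
To check the "arbitrarily close" part, here is a sketch that bisects toward a rounding boundary until it lands on two neighboring floats that still round apart. It assumes Python 3.9 or later for `math.nextafter`; the starting values and the `DIGITS` name are only for this example:

```python
import math

DIGITS = 6

# Start with two values that clearly round differently...
lo, hi = 1.0000004, 1.0000006

# ...then bisect until no float lies between them.
while math.nextafter(lo, hi) < hi:
    mid = (lo + hi) / 2
    if round(mid, DIGITS) == round(lo, DIGITS):
        lo = mid
    else:
        hi = mid

# lo and hi are now neighboring floats, yet they still round apart.
gap = hi - lo
still_split = round(lo, DIGITS) != round(hi, DIGITS)
```

No tolerance can be smaller than the gap between adjacent floats, so a single rounding will always have pairs like this.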

I thought about other data structures: bisection in a sorted list sounded good (O(log n)), but that requires testing for less-than, and I wasn't sure how the fuzziness would affect it. Reading about other geometric algorithms led to things like k-d trees, which seem exotic and powerful, but I didn't see how to put them to use here.

Then early one morning when I wasn't even trying to solve the problem, I had an idea: round the values two different ways. The problem with rounding was that it failed for values that straddled the rounding boundary. We can round twice, the second time with half the rounding window added, so the two are half out of phase. That way, if the first rounding separates the values, the second rounding won't.
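
Here is a minimal sketch of the two roundings on their own, before any data structure gets involved. The helper `two_roundings` and the constant names are made up for this illustration, and the guarantee only applies to values closer together than half the rounding window:

```python
ROUND_DIGITS = 6
HALF_WINDOW = 0.5 * 10 ** -ROUND_DIGITS

def two_roundings(x):
    """Return x rounded plainly, and rounded after a half-window shift."""
    return (round(x, ROUND_DIGITS), round(x + HALF_WINDOW, ROUND_DIGITS))

# The pair from the KeyError example straddles the plain boundary...
a_plain, a_shifted = two_roundings(1.000000499999)
b_plain, b_shifted = two_roundings(1.000000500001)

# ...and this pair straddles the shifted boundary instead.
c_plain, c_shifted = two_roundings(1.000000999999)
d_plain, d_shifted = two_roundings(1.000001000001)
```

The first pair disagrees on the plain rounding but agrees on the shifted one; the second pair is the other way around. Either way, at least one of the two roundings matches.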

We can't use double-rounding to produce a canonical value for comparisons and hashes, so instead we make a new data structure. We'll use a dict to map points to values, just as we did originally. A second dict will map rounded points to the original points. When we look up a point, if it isn't in the first dict, then round it two ways, and see if either is in the second dict. If it is, then we know what original point to use in the first dict.

That's a lot of words. Here's some code:

from collections import namedtuple

class Pt(namedtuple("Pt", "x")):
    # no need for special __eq__ or __hash__
    pass

class PointMap:
    def __init__(self):
        self._items = {}
        self._rounded = {}

    # Same precision as the round(..., 6) examples above.
    ROUND_DIGITS = 6
    JITTERS = [0, 0.5 * 10 ** -ROUND_DIGITS]

    def _round(self, pt, jitter):
        return Pt(round(pt.x + jitter, ndigits=self.ROUND_DIGITS))

    def __getitem__(self, pt):
        val = self._items.get(pt)
        if val is not None:
            return val

        for jitter in self.JITTERS:
            pt_rnd = self._round(pt, jitter)
            pt0 = self._rounded.get(pt_rnd)
            if pt0 is not None:
                return self._items[pt0]

        raise KeyError(pt)

    def __setitem__(self, pt, val):
        self._items[pt] = val
        for jitter in self.JITTERS:
            pt_rnd = self._round(pt, jitter)
            self._rounded[pt_rnd] = pt
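
Since the real problem was two-dimensional, here is a sketch of how the same idea might carry over to 2D. It is an assumption about the extension, not the code from my project: `Pt2` and `PointMap2D` are hypothetical names, and each coordinate gets its own choice of jitter, so there are four rounded keys per point instead of two.

```python
from collections import namedtuple
from itertools import product

Pt2 = namedtuple("Pt2", "x y")  # hypothetical 2D point for this sketch

class PointMap2D:
    """Two-dict lookup, with one jitter choice per coordinate."""

    ROUND_DIGITS = 6
    JITTERS = [0, 0.5 * 10 ** -ROUND_DIGITS]

    def __init__(self):
        self._items = {}     # original point -> value
        self._rounded = {}   # rounded point -> original point

    def _rounded_keys(self, pt):
        # Four combinations: each coordinate with and without the shift.
        for jx, jy in product(self.JITTERS, repeat=2):
            yield Pt2(
                round(pt.x + jx, self.ROUND_DIGITS),
                round(pt.y + jy, self.ROUND_DIGITS),
            )

    def __getitem__(self, pt):
        if pt in self._items:
            return self._items[pt]
        for key in self._rounded_keys(pt):
            if key in self._rounded:
                return self._items[self._rounded[key]]
        raise KeyError(pt)

    def __setitem__(self, pt, val):
        self._items[pt] = val
        for key in self._rounded_keys(pt):
            self._rounded[key] = pt

pm = PointMap2D()
pm[Pt2(1.0000004999, 2.0)] = "corner"
found = pm[Pt2(1.0000005001, 2.0000000001)]

try:
    pm[Pt2(1.5, 2.0)]
    far_point_missing = False
except KeyError:
    far_point_missing = True
```

The nearby point finds "corner" even though its x value sits on the other side of a plain rounding boundary, and a point that is genuinely somewhere else still raises KeyError.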

Lookups and insertions are now O(1), wonderful! I was surprised that I couldn't find something like this already explained somewhere. Perhaps needing to find values this way is unusual? Or this is well-known, but under some name I couldn't find? Or is it broken in some way I can't see?
