Jay Taylor's notes

Hacking Go Maps for Fun and Profit

Original source (lukechampine.com)
Tags: types hacking golang go lukechampine.com
Clipped on: 2017-05-29

Last week I spent an irresponsible amount of time slacking off from my job and instead dove into the deep magic of Go’s map implementation. I was driven by a simple desire: to select a random map value as easily as selecting a random slice value. In pursuit of that goal, I tumbled down the rabbit hole that is the runtime package, and battled the demons therein. Eventually I returned victorious, and with a newfound appreciation for the humble map.

This is a super long post detailing that journey. It assumes some familiarity with the unsafe package, but nothing you can’t pick up by spending 5 minutes reading its godoc page. If you don’t care about all the neat stuff going on under the hood, you can get the finished product here:

randmap - Truly random map access and iteration for Go

Part 1: The Problem

As you may know, selecting a random value from a slice is easily accomplished via Intn (from math/rand), which returns a random integer in the range [0,n):

func randSliceValue(xs []string) string {
	return xs[rand.Intn(len(xs))]
}

This is nice because it runs in constant time and space. But there is no obvious equivalent for maps! That’s because we only have two ways of accessing map data: lookups (e.g. m["foo"]) and range. So, given only these accessors and a random integer, how can we select a random key? (Note that if we can select a random map key, we can trivially select a random map value.)

One approach would be to flatten the map; then we can reuse the slice approach:

func randMapKey(m map[string]int) string {
	mapKeys := make([]string, 0, len(m)) // pre-allocate exact size
	for key := range m {
		mapKeys = append(mapKeys, key)
	}
	return mapKeys[rand.Intn(len(mapKeys))]
}

It is easy to verify that this code will indeed produce random keys. But the simplicity comes at a performance cost: O(n) time and O(n) space.

A slightly better approach relies on the semantics of range. range will visit every key/value in the map exactly once, in unspecified order. Knowing this, we can use our random integer as a counter that decrements after each iteration:

func randMapKey(m map[string]int) string {
	r := rand.Intn(len(m))
	for k := range m {
		if r == 0 {
			return k
		}
		r--
	}
	panic("unreachable")
}

This runs in O(n) time and O(1) space, which is probably acceptable in most scenarios. But now we have a new problem: we can’t write a generic version of this function! For the “flattening” approach, we can use reflection to make it generic, thanks to the MapKeys method:

func randMapKey(m interface{}) interface{} {
	mapKeys := reflect.ValueOf(m).MapKeys()
	return mapKeys[rand.Intn(len(mapKeys))].Interface()
}

But reflect does not provide us with a generic range, so we can’t implement the other “iteration” approach in a generic fashion.
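
For readers on newer Go: reflect.Value.MapRange, added in Go 1.12 (after this post), does give reflection a generic range. Here is a minimal sketch of the counter approach built on it, assuming Go 1.12 or newer; randMapKeyReflect is a hypothetical name. It is still O(n) time, so it doesn't get us any closer to slice-like constant-time selection:

```go
package main

import (
	"fmt"
	"math/rand"
	"reflect"
)

// randMapKeyReflect is the "counter" approach made generic with reflection.
// It relies on reflect.Value.MapRange (Go 1.12+) and still runs in O(n) time
// with O(1) extra space.
func randMapKeyReflect(m interface{}) interface{} {
	v := reflect.ValueOf(m)
	if v.Kind() != reflect.Map || v.Len() == 0 {
		panic("randMapKeyReflect: argument must be a non-empty map")
	}
	r := rand.Intn(v.Len())
	iter := v.MapRange()
	for iter.Next() {
		if r == 0 {
			return iter.Key().Interface()
		}
		r--
	}
	panic("unreachable")
}

func main() {
	m := map[string]int{"foo": 1, "bar": 2, "baz": 3}
	fmt.Println(randMapKeyReflect(m)) // one of foo, bar, baz
}
```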

Unless…

Unless we harness the power of unsafe. ◕‿◕

Part 2: Down the Rabbit Hole

Among other things, unsafe allows us to subvert Go’s type system. That is, it allows us to treat a value of type X as though it actually has type Y. And the beautiful thing is that this works for Go’s builtin types (like strings, slices, and maps) just as well as user-defined types. Note that this goes both ways: we can convert, say, a []byte to some other type and fiddle with it, but we can also tell Go that an arbitrary object is a []byte, and Go will believe us!
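
Here is a small sketch of that second direction, telling Go that an arbitrary object is a []byte. It views a uint32's memory as a byte slice without copying. It relies on unsafe.Slice, which only arrived in Go 1.17 (long after this post), and bytesOf is a hypothetical helper name:

```go
package main

import (
	"fmt"
	"unsafe"
)

// bytesOf reinterprets the memory of *x as a []byte of the same size.
// Nothing is copied: writes through the slice modify x itself.
func bytesOf(x *uint32) []byte {
	return unsafe.Slice((*byte)(unsafe.Pointer(x)), unsafe.Sizeof(*x))
}

func main() {
	var x uint32
	b := bytesOf(&x)
	for i := range b {
		b[i] = 0xFF
	}
	fmt.Printf("%d bytes, x = %#x\n", len(b), x) // 4 bytes, x = 0xffffffff
}
```

Because the slice aliases x rather than copying it, setting every byte to 0xFF turns x into 0xffffffff regardless of the machine's byte order.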

We can use this trick to access an object's underlying memory, which I do in a few of my packages. But I digress. Today, we are only casting in the other direction. Specifically, we will cast a map (of any type) to a local copy of its runtime definition, the hmap struct in the runtime's hashmap.go. (Keep that file open; it’ll be a useful reference.) Then, we can access its data directly and hopefully devise an efficient way of selecting a random key.

Examining hashmap.go, we see a type called hiter and the functions mapiterinit and mapiternext. This is the code that runs when you range over a map. hiter is an iterator; mapiterinit initializes it, and mapiternext seeks to the next iterator position. So now we have a plan:

  1. Cast the map to an hmap
  2. Create an iterator with new(hiter) and initialize it with mapiterinit
  3. Generate a random number n in the range [0,len(m))
  4. Run mapiternext n times
  5. Return the key at the iterator’s current position

The nice thing about this approach is that we don’t need to understand the gory details of mapiterinit and mapiternext; we can just copy all the relevant code from hashmap.go and it should work fine. (Actually, the compiler will complain about some missing runtime magic, like atomic.Or8, so we’ll have to strip that stuff out in order to placate it.) After that, we just need a bit more code to convert to and from interface{} values. Specifically, we want to extract the map type and value from the interface{} passed to our randMapKey function, and we want to pack the final iterator position into an interface{} that the caller can work with. These helper functions are pretty easy to implement if you are familiar with interface internals:

// runtime representation of an interface{}
type emptyInterface struct {
	typ unsafe.Pointer
	val unsafe.Pointer
}

func mapTypeAndValue(m interface{}) (*maptype, *hmap) {
	ei := (*emptyInterface)(unsafe.Pointer(&m))
	return (*maptype)(ei.typ), (*hmap)(ei.val)
}

func iterKey(it *hiter) interface{} {
	ei := emptyInterface{
		typ: unsafe.Pointer(it.t.key), // it.t is the maptype
		val: it.key,
	}
	return *(*interface{})(unsafe.Pointer(&ei))
}

Now we are finally ready to implement a generic randMapKey that runs in O(n) time and O(1) space:

func randMapKey(m interface{}) interface{} {
	// get map internals
	t, h := mapTypeAndValue(m)
	// initialize iterator
	it := new(hiter)
	mapiterinit(t, h, it)
	// advance iterator a random number of times
	r := rand.Intn(h.count) // h.count == len(m)
	for i := 0; i < r; i++ {
		mapiternext(it)
	}
	// return current iterator key as an interface{}
	return iterKey(it)
}

It works! But of course, this all comes at a cost. It’s called unsafe for a reason – actually, many reasons, but the reason here is that we’re assuming the underlying structure and operation of Go’s map type. If Go 1.8 adds a field to hmap or maptype or hiter, it could totally break our code. And in fact, Go 1.8 is adding a field to hmap, so our code will have to be updated as soon as Go 1.8 is released. This has pretty clear ramifications when it comes to publishing our randMapKey package, since it’s only fair to assume that some of its users will be running Go 1.7, and some will be running Go 1.8. So it’s crucial that we use build tags to ensure that our package doesn’t blow up after a user upgrades their Go version.

We could call it a day at this point, but something is nagging at me… can we do better than O(n) time?

Part 3: Yes

In fact, we can do it in expected constant time, just like slices (although the constant is quite a bit larger). But in order to get there, we need to get a better understanding of how Go maps work.

An hmap is an array of buckets. The number of buckets is always a power of two (specifically, it’s 1 << h.B). Since the number of buckets changes as the map grows, the bucket array is represented as an unsafe.Pointer in hmap. Each bucket holds 8 “cells” (key/value pairs), which may or may not contain valid data, and a tophash value for each cell which identifies its contents. The struct definition of a bucket is a little confusing; it looks like this:

const bucketCnt = 8 // number of cells per bucket

type bmap struct {
	tophash [bucketCnt]uint8
	// Followed by bucketCnt keys and then bucketCnt values.
	// NOTE: packing all the keys together and then all the values together makes the
	// code a bit more complicated than alternating key/value/key/value/... but it allows
	// us to eliminate padding which would be needed for, e.g., map[int64]int8.
	// Followed by an overflow pointer.
}

As you can see, the actual key and value data are not represented in the struct! This is because the types of the keys and values are not known, so the compiler would not know what size bmap should be. If the map were a map[string]int, then you can imagine bmap like this:

type bmap struct {
	tophash  [bucketCnt]uint8
	keys     [bucketCnt]string
	values   [bucketCnt]int
	overflow *bmap
}

To insert a key/value pair, we first calculate the hash of the key, which is represented as a uintptr. (Each type in the runtime has an associated hash function.) Then we need to decide which bucket to use. Since the number of buckets is always a power of two, we can use a simple bitmask:

// h is an hmap, t is a maptype
hash := t.key.alg.hash(key, uintptr(h.hash0))
bucketIndex := hash & (uintptr(1) << h.B - 1)

After selecting a bucket, we need to find an open cell. We iterate through each cell, checking the values of tophash. tophash contains the top 8 bits of the hash of the key stored in the cell, with the special case that tophash == 0 indicates an empty cell. Once we find an empty cell, we store our key and value there, and set the tophash accordingly. Here is the (vastly simplified) algorithm:

// calculate tophash
top := uint8(hash >> (unsafe.Sizeof(hash)*8 - 8))

// seek to offset of bucketIndex in h.buckets
b := (*bmap)(unsafe.Pointer(uintptr(h.buckets) + bucketIndex*uintptr(t.bucketsize)))

// iterate through the cells of b. If a tophash matches top, we may have
// already inserted a value with this key. tophash is only 8 bits, though, so
// the real runtime also compares the full key before overwriting the value.
// Otherwise, store the key/value in the first empty cell.
for i := 0; i < bucketCnt; i++ {
	if b.tophash[i] == top {
		// overwrite the existing value (once the full key is confirmed equal)
		// [ code omitted ]
		return
	} else if b.tophash[i] == 0 {
		// insert the new key/value
		// [ code omitted ]
		b.tophash[i] = top
		h.count++
		return
	}
}

As you might expect, to retrieve a map value, we just repeat this process, but we return the data in the cell instead of overwriting/inserting it.

Ok, now we can return to improving our randMapKey function. Recall that selecting a random slice element is easy; just pick a random index. Well, if you squint, h.buckets is basically a slice; it’s a contiguous array of key/value cells. The biggest difference is that some of the cells are empty. So we want to pick a random index, but somehow avoid the empty cells.

One solution would be to simply iterate forward if we land on an empty cell, seeking until we find a non-empty cell. This is actually what mapiterinit does! But this approach has a serious flaw. Consider the case of a bucket containing only two valid cells:

[foo] [bar] [---] [---] [---] [---] [---] [---]

If we pick a random index [0,8) and skip over empty cells, what will happen? Well, if we pick 0, we’ll get foo, and if we pick 1, we’ll get bar. But if we pick any other index, we will seek forward, wrap around to index 0, and land on foo! So even though we used a random index, we’re 7x more likely to land on foo than bar. This is absolutely unacceptable for our purposes.

Fortunately, there is an alternative solution: we pick a random index, and if the cell is empty, we simply try again with a new random index. On average, we will pick a valid cell in k/n tries, where k is the number of cells and n is the number of non-empty cells. And of course, this algorithm is guaranteed to produce a uniform random distribution of cells, since each index is equally likely to be selected.

Let’s use this algorithm for our new randMapKey. We just need one more helper function for extracting the key data from the cell. I’ll also define an add function that makes pointer math a little more readable:

func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
	return unsafe.Pointer(uintptr(p) + x)
}

func cellKey(t *maptype, b *bmap, i int) interface{} {
	// dataOffset is where the cell data begins in a bmap.
	const dataOffset = unsafe.Offsetof(struct {
		tophash [bucketCnt]uint8
		cells   int64
	}{}.cells)

	k := add(unsafe.Pointer(b), dataOffset+uintptr(i)*uintptr(t.keysize))
	if t.indirectkey {
		// if the map's key type is too big, a pointer will be stored in
		// the map instead of the actual data. In that case, we need to
		// dereference the pointer.
		k = *(*unsafe.Pointer)(k)
	}

	ei := emptyInterface{
		typ: unsafe.Pointer(t.key),
		val: k,
	}
	return *(*interface{})(unsafe.Pointer(&ei))
}

func randMapKey(m interface{}) interface{} {
	// get map internals
	t, h := mapTypeAndValue(m)
	numBuckets := 1 << h.B

	// loop until we hit a valid cell
	for {
		// pick random indices
		bucketIndex := rand.Intn(numBuckets)
		cellIndex := rand.Intn(bucketCnt)

		// lookup cell
		b := (*bmap)(add(h.buckets, uintptr(bucketIndex)*uintptr(t.bucketsize)))
		if b.tophash[cellIndex] == 0 {
			// cell is empty; try again
			continue
		}
		return cellKey(t, b, cellIndex)
	}
}

Great, that totally works! Now we’re done, right?

Part 4: No

I cheated when testing the code above: the example map contains only the integers [0,8), and since a bucket has room for 8 cells, those keys can never overflow a bucket. So what happens when more than eight keys hash to the same bucket? Answer: overflow buckets. If you revisit the definition of bmap, you’ll see that it has an overflow pointer to another bmap at the end. When a bucket is full, a new bucket will be allocated and chained onto the old one. Then, during lookup, we check whether the key in each candidate cell matches the key we’re looking up. If none of them do, we move to the next overflow bucket and try again.

How does this affect our randMapKey function? Not too dramatically, in fact. We just need to add a dimension. Before, we were selecting a random cell from a random bucket. Now, we need to select a random cell from a random bucket from a random overflow chain. Visually:

           bucket0   bucket1   bucket2   bucket3
overflow0 [|||||||] [|||||||] [|||||||] [|||||||]
overflow1 [|||||||]           [|||||||]
overflow2                     [|||||||]

Before, we were only selecting from the top row – 4 buckets. Now, we need to select from the entire grid – 12 buckets – even though some of those buckets don’t exist. It’s the same as before with the empty cells; if an overflow bucket doesn’t exist, we just start over with a new random number.
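
To convince ourselves that rejection sampling over the grid is uniform, we can model the diagram directly: four chains of lengths 2, 1, 3 and 1, each bucket holding 8 cells. A uniform random draw of (bucket, overflow, cell) hits every grid coordinate with the same probability, so it is enough to enumerate the coordinates once and check that every key is hit exactly once. The chains data and helper names are hypothetical:

```go
package main

import "fmt"

const bucketCnt = 8

// chains[b][o] is the o-th bucket in bucket b's overflow chain (o == 0 is the
// base bucket). "" marks an empty cell. Chain lengths are 2, 1, 3 and 1.
var chains = [][][bucketCnt]string{
	{{"a", "b"}, {"c"}},
	{{"d", "e", "f"}},
	{{"g"}, {"h", "i"}, {"j"}},
	{{"k"}},
}

// numOver is the longest chain length, i.e. maxOverflow + 1.
func numOver() int {
	n := 0
	for _, c := range chains {
		if len(c) > n {
			n = len(c)
		}
	}
	return n
}

// pick returns the key at one grid coordinate, or ok == false when the
// overflow bucket or the cell is missing (randMapKey would retry).
func pick(bucket, over, cell int) (key string, ok bool) {
	chain := chains[bucket]
	if over >= len(chain) {
		return "", false
	}
	key = chain[over][cell]
	return key, key != ""
}

// hitCounts visits every (bucket, overflow, cell) coordinate exactly once.
func hitCounts() map[string]int {
	hits := map[string]int{}
	for b := range chains {
		for o := 0; o < numOver(); o++ {
			for c := 0; c < bucketCnt; c++ {
				if k, ok := pick(b, o, c); ok {
					hits[k]++
				}
			}
		}
	}
	return hits
}

func main() {
	fmt.Println(hitCounts()) // map[a:1 b:1 c:1 d:1 e:1 f:1 g:1 h:1 i:1 j:1 k:1]
}
```

Of the 4 × 3 × 8 = 96 coordinates, only 11 hold a key, so a real sampler on this map would need about 96/11 ≈ 9 draws per success on average. That retry cost is the price of uniformity.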

There is one annoyance, though, which is that we don’t know how many overflow buckets there are in advance; hmap doesn’t contain a maxOverflow integer that we can multiply by. Instead, we need to calculate it ourselves, which we can do by iterating through every bucket and following their overflow pointers until we hit nil. This adds some startup cost, but there’s no way around it. The code looks like this:

func (b *bmap) overflow(t *maptype) *bmap {
	offset := uintptr(t.bucketsize)-unsafe.Sizeof(uintptr(0))
	return *(**bmap)(add(unsafe.Pointer(b), offset))
}

func maxOverflow(t *maptype, h *hmap) int {
	numBuckets := uintptr(1 << h.B)
	max := 0
	for i := uintptr(0); i < numBuckets; i++ {
		over := 0
		b := (*bmap)(add(h.buckets, i*uintptr(t.bucketsize)))
		for b = b.overflow(t); b != nil; b = b.overflow(t) {
			over++
		}
		if over > max {
			max = over
		}
	}
	return max
}

Now, randMapKey looks like this:

func randMapKey(m interface{}) interface{} {
	// get map internals
	t, h := mapTypeAndValue(m)
	numBuckets := 1 << h.B
	numOver := maxOverflow(t, h) + 1 // add 1 to account for "base" bucket

	// loop until we hit a valid cell
loop:
	for {
		// pick random indices
		bucketIndex := rand.Intn(numBuckets)
		overIndex := rand.Intn(numOver)
		cellIndex := rand.Intn(bucketCnt)

		// seek to index in h.buckets
		b := (*bmap)(add(h.buckets, uintptr(bucketIndex)*uintptr(t.bucketsize)))

		// seek to index in overflow chain
		for i := 0; i < overIndex; i++ {
			b = b.overflow(t)
			if b == nil {
				// invalid bucket; try again
				continue loop
			}
		}

		// lookup cell
		if b.tophash[cellIndex] == 0 {
			// cell is empty; try again
			continue loop
		}
		return cellKey(t, b, cellIndex)
	}
}

We can confirm that this works for maps that contain overflow buckets. Phew! I was worried that we would need a Part 5…

Part 5: Of Course There’s a Part 5

(This is the last part, I promise.)

Go maps are highly optimized, and one of the optimizations is something called “incremental copying.” Basically, when a map gets too full and you try to insert a new element, the Go runtime will immediately allocate a new bucket array (twice as large as the old one) and store the new key/value there. But it won’t copy the cells over from the old buckets all at once; instead, each time you insert or delete an element, its bucket (and any chained overflow buckets) will first be copied (“evacuated”) into the new memory. Once all of the buckets have been evacuated, h.oldbuckets will be nil.

I’m sure you see the problem: up until now, we’ve only been selecting cells from h.buckets. In order to cover all possible values, we also need to check h.oldbuckets. There are three changes we need to make:

  1. When we select a bucket, we need to check if its corresponding oldbucket has been copied over. If it hasn’t, we need to select from the oldbucket.
  2. If we select from an oldbucket, we need to check where the cell will eventually be evacuated to. If it is destined for the bucket that we originally selected, we can return it. Otherwise, try again. (This is necessary to prevent oversampling, since each oldbucket expands to two new buckets.)
  3. maxOverflow needs to return the maximum of h.buckets and h.oldbuckets.

Fortunately, these aren’t too hard to implement. First, we’ll modify maxOverflow:

func maxOverflow(t *maptype, h *hmap) int {
	numBuckets := uintptr(1 << h.B)
	max := 0
	for i := uintptr(0); i < numBuckets; i++ {
		over := 0
		b := (*bmap)(add(h.buckets, i*uintptr(t.bucketsize)))
		for b = b.overflow(t); b != nil; b = b.overflow(t) {
			over++
		}
		if over > max {
			max = over
		}
	}

	// check oldbuckets too, if it exists
	if h.oldbuckets != nil {
		for i := uintptr(0); i < numBuckets/2; i++ {
			var over int
			b := (*bmap)(add(h.oldbuckets, i*uintptr(t.bucketsize)))
			if evacuated(b) {
				// we already counted this bucket in the first loop
				continue
			}
			for b = b.overflow(t); b != nil; b = b.overflow(t) {
				over++
			}
			if over > max {
				max = over
			}
		}
	}
	return max
}
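
maxOverflow (and the final randMapKey below) call an evacuated helper that isn't shown here; it comes from the runtime's hashmap.go. Here is a minimal sketch of it, assuming the Go 1.7/1.8 tophash sentinel values (later releases renumbered them, so check the runtime source for whichever version you target). The stripped-down bmap holds only the tophash array:

```go
package main

import "fmt"

const bucketCnt = 8

// tophash sentinels as defined in the Go 1.7/1.8 runtime.
const (
	empty          = 0 // cell is empty
	evacuatedEmpty = 1 // cell is empty, bucket is evacuated
	evacuatedX     = 2 // entry moved to the first half of the larger table
	evacuatedY     = 3 // entry moved to the second half of the larger table
	minTopHash     = 4 // minimum tophash for a normal filled cell
)

type bmap struct {
	tophash [bucketCnt]uint8
}

// evacuated reports whether b has already been copied to the new bucket
// array. Evacuation stamps every cell with one of the sentinels above, so
// checking the first cell is enough.
func evacuated(b *bmap) bool {
	h := b.tophash[0]
	return h > empty && h < minTopHash
}

func main() {
	var b bmap
	fmt.Println(evacuated(&b)) // false: an untouched empty bucket
	b.tophash[0] = evacuatedX
	fmt.Println(evacuated(&b)) // true
	b.tophash[0] = 0x9c // an ordinary tophash value (>= minTopHash)
	fmt.Println(evacuated(&b)) // false
}
```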

And at last, we can write the final version of randMapKey. When we check for the unevacuated oldbucket, we’ll set a flag that tells us to check the future destination of the cell:

func randMapKey(m interface{}) interface{} {
	// get map internals
	t, h := mapTypeAndValue(m)
	numBuckets := uintptr(1 << h.B)
	numOver := maxOverflow(t, h) + 1 // add 1 to account for "base" bucket

	// loop until we hit a valid cell
loop:
	for {
		// pick a random index
		bucketIndex := uintptr(rand.Intn(int(numBuckets)))
		overIndex := rand.Intn(numOver)
		cellIndex := rand.Intn(bucketCnt)

		// seek to index in h.buckets
		b := (*bmap)(add(h.buckets, bucketIndex*uintptr(t.bucketsize)))

		// if the oldbucket hasn't been evacuated, then we need to use that
		// pointer instead.
		usingOldBucket := false
		if h.oldbuckets != nil {
			numOldBuckets := numBuckets / 2
			oldBucketIndex := bucketIndex & (numOldBuckets - 1)
			oldB := (*bmap)(add(h.oldbuckets, oldBucketIndex*uintptr(t.bucketsize)))
			if !evacuated(oldB) {
				b = oldB
				usingOldBucket = true
			}
		}

		// seek to index in overflow chain
		for i := 0; i < overIndex; i++ {
			b = b.overflow(t)
			if b == nil {
				// invalid bucket; try again
				continue loop
			}
		}

		// lookup cell
		if b.tophash[cellIndex] == 0 {
			// cell is empty; try again
			continue loop
		}

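		// grab key and dereference if necessary (same as cellKey). dataOffset
		// was local to cellKey, so it has to be declared here as well.
		const dataOffset = unsafe.Offsetof(struct {
			tophash [bucketCnt]uint8
			cells   int64
		}{}.cells)
		k := add(unsafe.Pointer(b), dataOffset+uintptr(cellIndex)*uintptr(t.keysize))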
		if t.indirectkey {
			k = *(*unsafe.Pointer)(k)
		}

		// if this is an old bucket, we need to check whether this key is destined
		// for the new bucket. Otherwise, we will have a 2x bias towards oldbucket
		// values, since two different bucket selections can result in the same
		// oldbucket.
		if usingOldBucket {
			hash := t.key.alg.hash(k, uintptr(h.hash0))
			if hash&(numBuckets-1) != bucketIndex {
				// this key is destined for a different bucket
				continue loop
			}
		}

		// pack key into interface{} (same as cellKey)
		ei := emptyInterface{
			typ: unsafe.Pointer(t.key),
			val: k,
		}
		return *(*interface{})(unsafe.Pointer(&ei))
	}
}

Not too bad, all things considered! If you missed the link at the top, the full code is available in randmap.

I hope you enjoyed this journey through the map type. In a future post, we’ll cover the other cool aspect of randmap: efficient random iteration. In the meantime, please let me know what could be improved about this post!