Quantcast
Channel: Andrew Kelley
Viewing all articles
Browse latest Browse all 61

String Matching based on Compile Time Perfect Hashing in Zig

$
0
0

String Matching based on Compile Time Perfect Hashing in Zig

Inspired by cttrie - Compile time TRIE based string matching, I decided to see what this solution would look like in Zig.

Here's the API I came up with:

const ph = perfectHash([][]const u8{
    "a",
    "ab",
    "abc",
});
switch (ph.hash(target)) {
    ph.case("a") => std.debug.warn("handle the a case"),
    ph.case("ab") => std.debug.warn("handle the ab case"),
    ph.case("abc") => std.debug.warn("handle the abc case"),
    else => unreachable,
}

It notices at compile-time if you forget to declare one of the cases. For example, if I comment out the last item:

const ph = perfectHash([][]const u8{
    "a",
    "ab",
    //"abc",
});

When compiled this gives:

perfect-hashing.zig:147:13: error: case value 'abc' not declared
            @compileError("case value '" ++ s ++ "' not declared");
            ^
perfect-hashing.zig:18:16: note: called from here
        ph.case("abc") => std.debug.warn("handle the abc case\n"),
               ^

It also has runtime safety if you pass in a string that was not prepared for:

const std = @import("std");
const assert = std.debug.assert;

test "perfect hashing" {
    basedOnLength("zzz");
}

fn basedOnLength(target: []const u8) void {
    const ph = perfectHash([][]const u8{
        "a",
        "ab",
        "abc",
    });
    switch (ph.hash(target)) {
        ph.case("a") => @panic("wrong one a"),
        ph.case("ab") => {}, // test pass
        ph.case("abc") => @panic("wrong one abc"),
        else => unreachable,
    }
}

When run this gives:

Test 1/1 perfect hashing...attempt to perfect hash zzz which was not declared
perfect-hashing.zig:156:36: 0x205a21 in ??? (test)
                    std.debug.panic("attempt to perfect hash {} which was not declared", s);
                                   ^
perfect-hashing.zig:15:20: 0x2051bd in ??? (test)
    switch (ph.hash(target)) {
                   ^
perfect-hashing.zig:5:18: 0x205050 in ??? (test)
    basedOnLength("zzz");
                 ^
/home/andy/dev/zig/build/lib/zig/std/special/test_runner.zig:13:25: 0x2238ea in ??? (test)
        if (test_fn.func()) |_| {
                        ^
/home/andy/dev/zig/build/lib/zig/std/special/bootstrap.zig:96:22: 0x22369b in ??? (test)
            root.main() catch |err| {
                     ^
/home/andy/dev/zig/build/lib/zig/std/special/bootstrap.zig:70:20: 0x223615 in ??? (test)
    return callMain();
                   ^
/home/andy/dev/zig/build/lib/zig/std/special/bootstrap.zig:64:39: 0x223478 in ??? (test)
    std.os.posix.exit(callMainWithArgs(argc, argv, envp));
                                      ^
/home/andy/dev/zig/build/lib/zig/std/special/bootstrap.zig:37:5: 0x223330 in ??? (test)
    @noInlineCall(posixCallMainAndExit);
    ^

Tests failed. Use the following command to reproduce the failure:
/home/andy/dev/zig/build/zig-cache/test

So there's the API. How does it work?

Here's the implementation of the perfectHash function:

fn perfectHash(comptime strs: []const []const u8) type {
    const Op = union(enum) {
        /// add the length of the string
        Length,

        /// add the byte at index % len
        Index: usize,

        /// right shift then xor with constant
        XorShiftMultiply: u32,
    };
    const S = struct {
        fn hash(comptime plan: []Op, s: []const u8) u32 {
            var h: u32 = 0;
            inline for (plan) |op| {
                switch (op) {
                    Op.Length => {
                        h +%= @truncate(u32, s.len);
                    },
                    Op.Index => |index| {
                        h +%= s[index % s.len];
                    },
                    Op.XorShiftMultiply => |x| {
                        h ^= x >> 16;
                    },
                }
            }
            return h;
        }

        fn testPlan(comptime plan: []Op) bool {
            var hit = [1]bool{false} ** strs.len;
            for (strs) |s| {
                const h = hash(plan, s);
                const i = h % hit.len;
                if (hit[i]) {
                    // hit this index twice
                    return false;
                }
                hit[i] = true;
            }
            return true;
        }
    };

    var ops_buf: [10]Op = undefined;

    const plan = have_a_plan: {
        var seed: u32 = 0x45d9f3b;
        var index_i: usize = 0;
        const try_seed_count = 50;
        const try_index_count = 50;

        while (index_i < try_index_count) : (index_i += 1) {
            const bool_values = if (index_i == 0) []bool{true} else []bool{ false, true };
            for (bool_values) |try_length| {
                var seed_i: usize = 0;
                while (seed_i < try_seed_count) : (seed_i += 1) {
                    comptime var rand_state = std.rand.Xoroshiro128.init(seed + seed_i);
                    const rng = &rand_state.random;

                    var ops_index = 0;

                    if (try_length) {
                        ops_buf[ops_index] = Op.Length;
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];
                    }

                    ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                    ops_index += 1;

                    if (S.testPlan(ops_buf[0..ops_index]))
                        break :have_a_plan ops_buf[0..ops_index];

                    const before_bytes_it_index = ops_index;

                    var byte_index = 0;
                    while (byte_index < index_i) : (byte_index += 1) {
                        ops_index = before_bytes_it_index;

                        ops_buf[ops_index] = Op{ .Index = rng.scalar(u32) % try_index_count };
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];

                        ops_buf[ops_index] = Op{ .XorShiftMultiply = rng.scalar(u32) };
                        ops_index += 1;

                        if (S.testPlan(ops_buf[0..ops_index]))
                            break :have_a_plan ops_buf[0..ops_index];
                    }
                }
            }
        }

        @compileError("unable to come up with perfect hash");
    };

    return struct {
        fn case(comptime s: []const u8) usize {
            inline for (strs) |str| {
                if (std.mem.eql(u8, str, s))
                    return hash(s);
            }
            @compileError("case value '" ++ s ++ "' not declared");
        }
        fn hash(s: []const u8) usize {
            if (std.debug.runtime_safety) {
                const ok = for (strs) |str| {
                    if (std.mem.eql(u8, str, s))
                        break true;
                } else false;
                if (!ok) {
                    std.debug.panic("attempt to perfect hash {} which was not declared", s);
                }
            }
            return S.hash(plan, s) % strs.len;
        }
    };
}

Here's what this is doing:

  • Create the concept of a hashing operation. The operations that can happen are:
    • Length - use wraparound addition to add the length of the string to the hash value.
    • Index - use wraparound addition to add the byte from the string at the specified index. Use modulus arithmetic in case the index is out of bounds.
    • XorShiftMultiply - perform a right shift by 16 bits of the hash value and then xor with the specified value (which will come from a random number generator executed at compile time).
  • Next, we're going to try a series of "plans" which is a sequence of operations strung together. We have this function testPlan which performs the hash operations on all the strings and sees if there are any collisions. If we ever find a plan that results in no collisions, we have found a perfect hashing strategy.
  • First we test a plan that is simply the Length operation. If this works, then all the hash function has to do is take the length of the string mod the number of strings. Easy. You may notice this is true for the above example. Don't worry, I have a more complicated example below.
  • Next we iterate over the count of how many different bytes we are willing to look at in the hash function. We start with 0. If we can xor the length with a random number and it fixes the collisions, then we're done. We try 50 seeds before giving up and deciding to inspect a random byte from the string in the hash function. If the addition of inspecting a random byte from the string to hash doesn't solve the problem, we try 50 seeds in order to choose different random byte indexes. If that still doesn't work, we look at 2 random bytes, again trying 50 different seeds with these 2 bytes. And so on until every combination of 50 seeds x 50 bytes inspected, at which point we give up and emit a compile error "unable to come up with perfect hash".

We can use @compileLog to see the plan that the function came up with:

for (plan) |op| {
    @compileLog(@TagType(Op)(op));
}

This outputs:

| @TagType(Op).Length

So this means that the hash function only has to look at the length for this example.

The nice thing about this is that the plan is all known at compile time. Indeed, you can see that the hash function uses inline for to iterate over the operations in the plan. This means that LLVM is able to fully optimize the hash function.

Here's what it gets compiled to in release mode for x86_64:

0000000000000000 <basedOnLength>:
   0:	89 f0                	mov    %esi,%eax
   2:	b9 ab aa aa aa       	mov    $0xaaaaaaab,%ecx
   7:	48 0f af c8          	imul   %rax,%rcx
   b:	48 c1 e9 21          	shr    $0x21,%rcx
   f:	8d 04 49             	lea    (%rcx,%rcx,2),%eax
  12:	29 c6                	sub    %eax,%esi
  14:	48 89 f0             	mov    %rsi,%rax
  17:	c3                   	retq   
  18:	0f 1f 84 00 00 00 00 	nopl   0x0(%rax,%rax,1)
  1f:	00 

You can see there is not even a jump instruction in there. And it will output numbers in sequential order from 0, so that the switch statement can be a jump table. Here is the LLVM IR of the switch:

 %2 = call fastcc i64 @"perfectHash((struct []const []const u8 constant))_hash"(%"[]u8"* %0), !dbg !736
  store i64 %2, i64* %1, align 8, !dbg !736
  %3 = load i64, i64* %1, align 8, !dbg !740
  switch i64 %3, label %SwitchElse [
    i64 1, label %SwitchProng
    i64 2, label %SwitchProng1
    i64 0, label %SwitchProng2
  ], !dbg !740

How about a harder example?

@setEvalBranchQuota(100000);
const ph = perfectHash([][]const u8{
    "one",
    "two",
    "three",
    "four",
    "five",
});
switch (ph.hash(target)) {
    ph.case("one") => std.debug.warn("handle the one case"),
    ph.case("two") => std.debug.warn("handle the two case"),
    ph.case("three") => std.debug.warn("handle the three case"),
    ph.case("four") => std.debug.warn("handle the four case"),
    ph.case("five") => std.debug.warn("handle the five case"),
    else => unreachable,
}

This example is interesting because there are 2 pairs of length collisions (one/two, four/five) and 2 pairs of byte collisions (two/three, four/five).

Here we have to use @setEvalBranchQuota because it takes a bit of computation to come up with the answer.

Again, the hash function comes up with a mapping to 0, 1, 2, 3, 4 (but not necessarily in the same order as specified):

 %2 = call fastcc i64 @"perfectHash((struct []const []const u8 constant))_hash.11"(%"[]u8"* %0), !dbg !749
  store i64 %2, i64* %1, align 8, !dbg !749
  %3 = load i64, i64* %1, align 8, !dbg !753
  switch i64 %3, label %SwitchElse [
    i64 4, label %SwitchProng
    i64 2, label %SwitchProng1
    i64 3, label %SwitchProng2
    i64 0, label %SwitchProng3
    i64 1, label %SwitchProng4
  ], !dbg !753

And the optimized assembly:

0000000000000020 <basedOnOtherStuff>:
  20:	48 83 fe 22          	cmp    $0x22,%rsi
  24:	77 0b                	ja     31 <basedOnOtherStuff+0x11>
  26:	b8 22 00 00 00       	mov    $0x22,%eax
  2b:	31 d2                	xor    %edx,%edx
  2d:	f7 f6                	div    %esi
  2f:	eb 05                	jmp    36 <basedOnOtherStuff+0x16>
  31:	ba 22 00 00 00       	mov    $0x22,%edx
  36:	0f b6 04 17          	movzbl (%rdi,%rdx,1),%eax
  3a:	05 ef 3d 00 00       	add    $0x3def,%eax
  3f:	b9 cd cc cc cc       	mov    $0xcccccccd,%ecx
  44:	48 0f af c8          	imul   %rax,%rcx
  48:	48 c1 e9 22          	shr    $0x22,%rcx
  4c:	8d 0c 89             	lea    (%rcx,%rcx,4),%ecx
  4f:	29 c8                	sub    %ecx,%eax
  51:	c3                   	retq   

We have a couple of jumps here, but no loop. This code looks at only 1 byte from the target string to determine the corresponding case index. We can see that more clearly with the @compileLog snippet from earlier:

| @TagType(Op).XorShiftMultiply
| @TagType(Op).Index

So it initializes the hash with a constant value of predetermined random bits, then adds the byte from a randomly chosen index of the string. Using @compileLog(plan[1].Index) I determined that it is choosing the value 34, which means that:

  • For one, 34 % 3 == 1, it looks at the 'n'
  • For two, 34 % 3 == 1, it looks at the 'w'
  • For three, 34 % 5 == 4, it looks at the 'e'
  • For four, 34 % 4 == 2, it looks at the 'u'
  • For five, 34 % 4 == 2, it looks at the 'i'

So the perfect hashing exploited that these bytes are all different, and when combined with the random constant that it found, the final modulus spreads out the value across the full range of 0, 1, 2, 3, 4.

Can we do better?

There are lots of ways this can be improved - this is just a proof of concept. For example, we could start by looking for a perfect hash in the subset of bytes that are within the smallest string length. For example, if we used index 1 in the above example, we could avoid the jumps and remainder division instructions in the hash function.

Another way this can be improved, to reduce compile times, is to have the perfectHash function accept the RNG seed as a parameter. If the seed worked first try, great. Otherwise the function would find the seed that does work, and emit a compile error instructing the programmer to switch the seed argument to the good one, saving time in future compilations.

Conclusion

If you like this, you should check out Zig. Consider supporting my work on Patreon.


Viewing all articles
Browse latest Browse all 61

Trending Articles