Levenshtein Distance in Prolog

I recently came across a neat article about calculating the Levenshtein distance between strings in Clojure. I liked the process of making the algorithm more efficient, but I was somewhat surprised they opted to not look into memoization. They said that “memorisation (sic) introduces state and detracts for the functional purity of the solution”.

While memoization does introduce some state, it can often be nicely hidden away and, in my opinion, can enhance the “functional purity” of a solution. Indeed, explicitly threading intermediate results and depth limits through the recursion seems to complicate the solution much more.

On the other hand, Clojure’s built-in memoization options are somewhat crude, so memoizing there might not be as pleasant. Still, I suspected that memoization would be a great fit for this problem, since it has a good dynamic programming solution (the “matrix” approach in the original article), and memoization is made for exactly this sort of problem.
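For comparison, here is a rough sketch of that bottom-up “matrix” approach written in Prolog. It is my own illustration, not code from the original article: it fills the matrix one row at a time and keeps only the previous row, and the predicate names lev_rows/3, next_row/4 and row_cells/6 are made up for this example.

```prolog
% Sketch: bottom-up "matrix" Levenshtein that keeps only one row at a time.
% lev_rows/3, next_row/4 and row_cells/6 are hypothetical names for this example.
:- use_module(library(lists)).   % numlist/3, last/2
:- use_module(library(apply)).   % foldl/4

lev_rows(S1, S2, D) :-
    string_codes(S1, Xs),
    string_codes(S2, Ys),
    length(Ys, N),
    numlist(0, N, Row0),                       % row 0: distance from "" to each prefix of S2
    foldl(next_row(Ys), Xs, Row0-0, Last-_),   % one new row per character of S1
    last(Last, D).                             % bottom-right cell of the matrix

% next_row(+Ys, +X, +Prev-I0, -Row-I): compute row I = I0 + 1 from row I0.
next_row(Ys, X, [P0|Ps]-I0, [I|Cells]-I) :-
    I is I0 + 1,                               % first cell: distance from an I-char prefix to ""
    row_cells(Ys, X, P0, Ps, I, Cells).

% row_cells(+Ys, +X, +Diag, +Ups, +Left, -Cells)
% Diag is the up-left cell, Ups the rest of the previous row, Left the cell to the left.
row_cells([], _, _, [], _, []).
row_cells([Y|Ys], X, Diag, [Up|Ups], Left, [C|Cs]) :-
    ( X == Y -> K = 0 ; K = 1 ),               % substitution cost
    C is min(Diag + K, min(Up + 1, Left + 1)),
    row_cells(Ys, X, Up, Ups, C, Cs).
```

It works, but the bookkeeping (diagonal, up and left cells) is exactly the kind of detail that memoization lets us stop thinking about.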

Naive memoization is a bit awkward in Clojure here, since we only want to keep the minimum result for a given pair of strings, not just the first one we find. In Prolog, however, mode-directed tabling gives us the best of both worlds – a very straightforward, declarative description of the algorithm that is still efficient!

“Tabling” can be thought of as the Prolog term for memoization. In implementations such as SWI-Prolog and XSB, though, tabling has some substantial improvements that allow it to be much more than a simple cache. “Mode-directed” tabling (also called answer subsumption) is one such enhancement: it lets us specify in fine-grained detail exactly what to keep in the table for each call – just the first answer found, the smallest, the sum of answers, or any other relation we desire.
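To make the idea concrete outside of Levenshtein, here is a small sketch of the same `min` mode applied to shortest paths in a weighted graph. The graph and the predicate names edge/3 and shortest/3 are hypothetical, chosen only for illustration; I’m assuming SWI-Prolog’s `table` directive with a moded last argument, the same feature used for ld/3 below.

```prolog
% Sketch: mode-directed tabling outside of Levenshtein.
% edge/3 and shortest/3 are hypothetical; the graph is made up for illustration.
% Keep only the minimum distance found for each (From, To) pair.
:- table shortest(_, _, min).

edge(a, b, 1).
edge(b, c, 2).
edge(a, c, 5).
edge(c, a, 1).          % a cycle; an untabled shortest/3 would not terminate here

shortest(X, Y, D) :-
    edge(X, Y, D).
shortest(X, Y, D) :-
    shortest(X, Z, D0),     % left recursion is fine under tabling
    edge(Z, Y, D1),
    D is D0 + D1.
```

Querying `?- shortest(a, c, D).` should give D = 3 (going through b), not the direct edge of 5 or any longer path around the cycle: the clauses just describe all paths, and the table keeps the smallest.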

:- module(levenshtein, [levenshtein_dist/3]).

% Wrapper predicate to convert the strings into lists of character codes
levenshtein_dist(S1, S2, D) :-
    string_codes(S1, Cs1),
    string_codes(S2, Cs2),
    % ld/3, defined below, does the actual work
    ld(Cs1, Cs2, D).

% cache the code list arguments normally, but just keep the minimum value found for distance
:- table ld(_, _, min).

ld([], Cs, D) :- length(Cs, D). % base case #1
ld(Cs, [], D) :- length(Cs, D). % base case #2
% drop first character of left string
ld([_|Cs1], Cs2, D) :-
    ld(Cs1, Cs2, D0),
    D is D0 + 1.
% drop first character of right string
ld(Cs1, [_|Cs2], D) :-
    ld(Cs1, Cs2, D0),
    D is D0 + 1.
% "edit" step
ld([X|Cs1], [Y|Cs2], D) :-
    ( X == Y -> K = 0 ; K = 1),
    ld(Cs1, Cs2, D0),
    D is D0 + K.

Note that instead of having to explicitly call our predicate three times and then take the minimum value, we can just describe the three recursive cases directly and let the tabling take care of minimizing!
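For contrast, here is a sketch of the version this avoids: plain tabling with no mode, where a single clause for two non-empty lists makes all three recursive calls and combines them with min/2 by hand. The name ld_explicit/3 is hypothetical; its base cases are written to be mutually exclusive so that every subproblem has exactly one answer, which is what makes ordinary tabling sufficient.

```prolog
% Sketch for contrast: plain tabling, explicit minimum over the three cases.
% ld_explicit/3 is a hypothetical name; it works on code lists like ld/3.
:- table ld_explicit/3.

ld_explicit([], Cs, D) :-
    length(Cs, D).
ld_explicit([C|Cs], [], D) :-
    length([C|Cs], D).
ld_explicit([X|Xs], [Y|Ys], D) :-
    ld_explicit(Xs, [Y|Ys], D1),        % drop first character of left list
    ld_explicit([X|Xs], Ys, D2),        % drop first character of right list
    ld_explicit(Xs, Ys, D3),            % "edit" step
    ( X == Y -> K = 0 ; K = 1 ),
    D is min(D1 + 1, min(D2 + 1, D3 + K)).
```

It is not much longer, but the minimization is now our job, and the clauses read less like a description of the edit operations.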

Let’s test this predicate out and see how well it scales.

test(S1, S2) :-
    levenshtein_dist(S1, S2, D),
    format("~s -> ~s = ~d~n", [S1, S2, D]).

?- test("lawn", "flaw").
?- test("abcdefghi", "123456789").
?- test("a23456789", "123456789").
?- test("frog", "fog").
?- test("hypotenuse", "hypertension").
lawn -> flaw = 2
abcdefghi -> 123456789 = 9
a23456789 -> 123456789 = 1
frog -> fog = 1
hypotenuse -> hypertension = 6
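If you want to keep checks like these around, they can be turned into unit tests with SWI-Prolog’s plunit. This sketch assumes the module above is saved as levenshtein.pl in the same directory as the test file; the expected values are some of the ones printed above, and the test names are arbitrary.

```prolog
% Sketch: some of the checks above as plunit tests.
% Assumes the levenshtein module is saved as levenshtein.pl in the same directory.
:- use_module(library(plunit)).
:- use_module(levenshtein).

:- begin_tests(levenshtein).

test(lawn_flaw, [true(D =:= 2)]) :-
    levenshtein_dist("lawn", "flaw", D).
test(disjoint, [true(D =:= 9)]) :-
    levenshtein_dist("abcdefghi", "123456789", D).
test(one_substitution, [true(D =:= 1)]) :-
    levenshtein_dist("a23456789", "123456789", D).
test(frog_fog, [true(D =:= 1)]) :-
    levenshtein_dist("frog", "fog", D).

:- end_tests(levenshtein).
```

After loading the file, `?- run_tests.` runs the whole unit and reports any test whose result differs from the expected value.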

Results seem correct, but how’s the timing?

long_test(L, D) :-
    length(A, L),
    length(B, L),
    maplist(=(0'a), A),
    maplist(=(0'b), B),
    call_time(ld(A, B, D), Time),
    format("length ~d in ~3fs wall, ~3fs CPU (~d inferences)~n",
           [L, Time.wall, Time.cpu, Time.inferences]).

?- long_test(50, _), long_test(100, _),
   long_test(200, _), long_test(400, _),
   long_test(800, _).
length 50 in 0.063s wall, 0.063s CPU (160530 inferences)
length 100 in 0.361s wall, 0.360s CPU (633493 inferences)
length 200 in 1.995s wall, 1.994s CPU (2516943 inferences)
length 400 in 14.659s wall, 14.392s CPU (10033843 inferences)
length 800 in 123.244s wall, 122.127s CPU (40067643 inferences)

The inference counts roughly quadruple each time the length doubles, so this seems roughly quadratic… let’s try throwing Gnuplot at the timings and see how it looks.

$timings << EOD
# size  time (s)
50 0.063
100 0.361
200 1.995
400 14.659
800 123.244
EOD
approx(x) = c + b*x + a*x**2
fit approx(x) $timings using 1:2 via c,b,a
plot $timings using 1:2 title "measured", approx(x) title "quadratic fit"

Quadratic complexity seems reasonable, since the work is O(mn), where m and n are the lengths of the two strings, and here both m and n increase together.
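As a quick numeric check of that reasoning: each row of the timing output doubles the length of both strings, so O(mn) work should grow by roughly a factor of 4 from one row to the next. This small helper computes those ratios from the numbers printed by long_test/2; growth_ratios/2 and print_ratios/2 are hypothetical names, not part of any library.

```prolog
% Sketch: ratio between successive measurements, e.g. from long_test/2.
% growth_ratios/2 and print_ratios/2 are hypothetical helper names.
:- use_module(library(lists)).   % member/2

% growth_ratios(+Xs, -Rs): Rs holds X(i+1)/X(i) for consecutive elements of Xs.
growth_ratios([_], []).
growth_ratios([A, B|Xs], [R|Rs]) :-
    R is B / A,
    growth_ratios([B|Xs], Rs).

% print_ratios(+Label, +Xs): print the ratios to two decimal places.
print_ratios(Label, Xs) :-
    growth_ratios(Xs, Rs),
    format("~w:", [Label]),
    forall(member(R, Rs), format(" ~2f", [R])),
    nl.
```

For example, `?- print_ratios(inferences, [160530, 633493, 2516943, 10033843, 40067643]).` gives ratios of about 3.95 to 3.99, right in line with O(mn); the wall-clock times fed through the same helper give larger ratios, roughly 5.5 to 8.4.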

I was pretty happy to be able to express the algorithm very directly in Prolog – I’d argue even closer to the mathematical description than the idealized Clojure version – and still get decent asymptotic performance. Mode-directed tabling is an extremely powerful feature of Prolog; like many things in Prolog, it’s something I’m still learning to use effectively, but when it clicks it can be mind-blowing.