"""

Author: Awni Hannun

This is an example CTC decoder written in Python. The code is

intended to be a simple example and is not designed to be

especially efficient.

The algorithm is a prefix beam search for a model trained

with the CTC loss function.

For more details checkout either of these references:

https://distill.pub/2017/ctc/#inference

https://arxiv.org/abs/1408.2873

"""

import collections
import math

import numpy as np

# Log-space zero: log(0) = -inf.
NEG_INF = -float("inf")

def make_new_beam():
    """Return a dict that maps each prefix to (p_blank, p_no_blank) scores."""
    return collections.defaultdict(lambda: (NEG_INF, NEG_INF))
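# (Illustration, not in the original: looking up an unseen prefix in this
# defaultdict yields (NEG_INF, NEG_INF), i.e. log(0) for both path groups,
# so new candidates can be accumulated with logsumexp without any
# existence checks.)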

def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp
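# (A quick illustration of the stability trick above, not part of the
# original decoder: subtracting the max keeps the exponents in range, so
# inputs that would overflow a direct computation are still handled. For
# example, math.exp(1000.0) raises OverflowError, while
#
#     logsumexp(1000.0, 1000.0)  # == 1000.0 + math.log(2.0)
#
# returns a finite value.)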

def decode(probs, beam_size=100, blank=0):
    """
    Performs inference for the given output probabilities.

    Arguments:
        probs: The output probabilities (e.g. post-softmax) for each
            time step. Should be an array of shape (time x output dim).
        beam_size (int): Size of the beam to use during inference.
        blank (int): Index of the CTC blank label.

    Returns the output label sequence and the corresponding negative
    log-likelihood estimated by the decoder.
    """
    T, S = probs.shape
    probs = np.log(probs)

    # Elements in the beam are (prefix, (p_blank, p_no_blank)).
    # Initialize the beam with the empty sequence, a probability of
    # 1 for ending in blank and zero for ending in non-blank
    # (in log space).
    beam = [(tuple(), (0.0, NEG_INF))]
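    # (Illustration, not in the original: the two entries are kept separate
    # because paths ending in a blank and paths ending in the prefix's last
    # character behave differently when extended. For prefix (a,), p_blank
    # covers alignments like "...a_" and p_no_blank covers "...a"; the
    # extension and merging rules below rely on this split.)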

    for t in range(T):  # Loop over time
        # A default dictionary to store the next step candidates.
        next_beam = make_new_beam()

        for s in range(S):  # Loop over vocab
            p = probs[t, s]

            # The variables p_b and p_nb are respectively the
            # probabilities for the prefix given that it ends in a
            # blank and does not end in a blank at this time step.
            for prefix, (p_b, p_nb) in beam:  # Loop over beam

                # If we propose a blank the prefix doesn't change.
                # Only the probability of ending in blank gets updated.
                if s == blank:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)
                    continue
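                # (Illustration, not in the original: for prefix (a,), a
                # blank at time t turns both path groups, "...a" and
                # "...a_", into paths ending in blank, which is why both
                # p_b + p and p_nb + p feed the new p_b above.)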

                # Extend the prefix by the new character s and add it to
                # the beam. Only the probability of not ending in blank
                # gets updated.
                end_t = prefix[-1] if prefix else None
                n_prefix = prefix + (s,)
                n_p_b, n_p_nb = next_beam[n_prefix]
                if s != end_t:
                    n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
                else:
                    # We don't include the previous probability of not ending
                    # in blank (p_nb) if s is repeated at the end. The CTC
                    # algorithm merges characters not separated by a blank.
                    n_p_nb = logsumexp(n_p_nb, p_b + p)

                # *NB* this would be a good place to include an LM score.
                next_beam[n_prefix] = (n_p_b, n_p_nb)
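                # (A hedged sketch of the LM hook mentioned above, in the
                # spirit of https://arxiv.org/abs/1408.2873; `lm_log_prob`,
                # `alpha`, and `beta` are hypothetical names, not part of
                # this example:
                #
                #     n_p_nb += alpha * lm_log_prob(s, prefix) + beta
                #
                # i.e. a weighted language-model score plus an insertion
                # bonus, added whenever the prefix is extended.)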

                # If s is repeated at the end we also update the unchanged
                # prefix. This is the merging case.
                if s == end_t:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_nb = logsumexp(n_p_nb, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)
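                # (Illustration, not in the original: with prefix (a,) and
                # s == a, a path ending "...a" collapses back onto (a,), so
                # its mass p_nb + p stays with the unchanged prefix; only
                # paths ending "...a_" grew into (a, a) above.)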

        # Sort and trim the beam before moving on to the
        # next time-step.
        beam = sorted(next_beam.items(),
                      key=lambda x: logsumexp(*x[1]),
                      reverse=True)
        beam = beam[:beam_size]

    best = beam[0]
    return best[0], -logsumexp(*best[1])
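# (A hedged usage sketch, not part of the original: with an
# index-to-character table, the returned label indices map to text;
# `alphabet` below is a hypothetical name.
#
#     alphabet = ["_", "a", "b", "c"]  # index 0 is the CTC blank
#     labels, score = decode(probs, blank=0)
#     text = "".join(alphabet[i] for i in labels)
# )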

if __name__ == "__main__":
    np.random.seed(3)

    time = 50
    output_dim = 20

    # Random "probabilities" normalized over the output dimension so that
    # each time step is a valid distribution.
    probs = np.random.rand(time, output_dim)
    probs = probs / np.sum(probs, axis=1, keepdims=True)

    labels, score = decode(probs)
    print("Score {:.3f}".format(score))