Support for incremental decoding

I would like to be able to decode a sequence of token ids incrementally, in a decoder-agnostic manner. I haven't found a straightforward way to do this with the current API: the first token is treated differently by some decoders, which means that, in general,

decode([1,2,3]) != decode([1]) + decode([2]) + decode([3])
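
To make this concrete, here is a minimal sketch of the mismatch, assuming a locally available tokenizer.json for a model whose decoder is Metaspace/WordPiece/BPE-based (the ids are illustrative, and the exact decode signature varies slightly between crate versions):

use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    let ids = vec![1, 2, 3];
    let whole = tokenizer.decode(ids.clone(), false)?;

    // Decoding one id at a time treats every token as a "first" token,
    // so word-boundary spaces/prefixes are handled differently.
    let piecewise = ids
        .iter()
        .map(|&id| tokenizer.decode(vec![id], false))
        .collect::<tokenizers::Result<Vec<_>>>()?
        .concat();

    assert_ne!(whole, piecewise);
    Ok(())
}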

It would be really nice to have some kind of "continuation" flag to indicate that the result is intended to be appended to an already-decoded prefix, so that you could have

decode([1,2,3]) == decode([1]) + decode'([2]) + decode'([3])

It would also be nice to have a variant of this that takes either a single u32 id or string token rather than a vec, for related reasons (the latter could be used with id_to_token).
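
Purely as an illustration, the additions might look something like this (the trait, method names, and signatures are hypothetical; none of this exists in the crate today):

use tokenizers::Result;

/// Hypothetical extension sketching the proposed API surface.
pub trait DecodeContinuation {
    /// Decode `ids` as a continuation of already-decoded text,
    /// i.e. without the special first-token handling.
    fn decode_continuation(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String>;

    /// Single-id variant; a &str-token variant could pair with id_to_token.
    fn decode_token(&self, id: u32, skip_special_tokens: bool) -> Result<String>;
}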

I'd love to know if there is a way to achieve this other than my current ugly workaround :)

Current workaround

use tokenizers::decoders::DecoderWrapper::{BPE, ByteLevel, Metaspace, WordPiece};
use tokenizers::Tokenizer;

pub(crate) struct Decoder {
    pub(crate) tokenizer: Tokenizer,
    prefix_id: u32,
    prefix: String,
}

impl Decoder {
    pub(crate) fn new(tokenizer: Tokenizer) -> Decoder {
        let prefix_id = tokenizer.token_to_id("A").unwrap();
        Decoder {
            prefix_id,
            prefix: tokenizer.decode(vec![prefix_id], false).unwrap(),
            tokenizer,
        }
    }

    /// Decode continuation tokens to be added to some existing text
    pub(crate) fn decode_continuation(&self, mut ids: Vec<u32>) -> tokenizers::Result<String> {
        // How we handle this depends on the specific decoder's behaviour,
        // see each one's implementation of decode_chain in the tokenizers library.
        match self.tokenizer.get_decoder() {
            Some(ByteLevel(_)) => {
                // Lossless - call standard decode function
                self.tokenizer.decode(ids, true)
            },
            Some(Metaspace(_)) | Some(WordPiece(_)) | Some(BPE(_)) => {
                // For these, the first token in the sequence is treated differently,
                // so we add and then strip a placeholder token.
                ids.insert(0, self.prefix_id);
                let result = self.tokenizer.decode(ids, true)?;
                Ok(result.strip_prefix(&self.prefix).ok_or(DecodingError)?.to_string())
            },
            None => {
                // Just prepend a space
                Ok(format!(" {}", self.tokenizer.decode(ids, true)?))
            },
            // DecodingError and UnsupportedTokenizerError are error types
            // defined elsewhere in my crate.
            _ => Err(UnsupportedTokenizerError.into()),
        }
    }
}
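
For reference, this is roughly how I use it when streaming generated tokens (a sketch; stream_decode and the id arguments are illustrative, not part of the workaround above):

/// Sketch: decode a first token normally, then append continuations.
fn stream_decode(decoder: &Decoder, first_id: u32, rest: &[u32]) -> tokenizers::Result<String> {
    let mut text = decoder.tokenizer.decode(vec![first_id], true)?;
    for &id in rest {
        // Each continuation chunk can be appended directly to the prefix.
        text.push_str(&decoder.decode_continuation(vec![id])?);
    }
    Ok(text)
}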