
OpenVINO Tokenizers


OpenVINO Tokenizers adds text processing operations to OpenVINO.

Features

Installation

(Recommended) Create and activate a virtual environment:

python3 -m venv venv
source venv/bin/activate

or

conda create --name openvino_tokenizers
conda activate openvino_tokenizers

Minimal Installation

Use minimal installation when you have a converted OpenVINO tokenizer:

pip install openvino-tokenizers

or

conda install -c conda-forge openvino openvino-tokenizers
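
With only the minimal installation you can load and run an already converted tokenizer. A small sketch, assuming an openvino_tokenizer.xml file converted earlier (see the Usage section below):

import openvino_tokenizers  # registers tokenizer operations in OpenVINO
from openvino import compile_model

tokenizer = compile_model("openvino_tokenizer.xml")
print(tokenizer(["Test string"])["input_ids"])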

Convert Tokenizers Installation

If you want to convert HuggingFace tokenizers into OpenVINO tokenizers:

pip install openvino-tokenizers[transformers]

or

conda install -c conda-forge openvino openvino-tokenizers && pip install transformers[sentencepiece] tiktoken

Install Pre-release Version

Use openvino-tokenizers[transformers] to install tokenizers conversion dependencies.

pip install --pre -U openvino openvino-tokenizers --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly

Build and Install from Source

Using OpenVINO PyPI package

The openvino-tokenizers build depends on the openvino package, which is automatically installed from PyPI during the build process. To install unreleased versions, install the openvino package from the nightly distribution channel using --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly

git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
pip install . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly

This is the equivalent of the minimal installation. Install tokenizers conversion dependencies if needed:

pip install transformers[sentencepiece] tiktoken

⚠️ The latest commit of OpenVINO Tokenizers might rely on features that are not present in the released OpenVINO version. Use a nightly build of OpenVINO or build OpenVINO Tokenizers from a release branch if you have issues with the build process.

Using OpenVINO archive

Install the OpenVINO archive distribution. Use --no-deps to avoid installing OpenVINO from PyPI into your current environment; --extra-index-url is needed only to resolve build dependencies.

source path/to/installed/openvino/setupvars.sh
git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
pip install --no-deps . --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly

This is the equivalent of the minimal installation. Install tokenizers conversion dependencies if needed:

pip install transformers[sentencepiece] tiktoken

⚠️ The latest commit of OpenVINO Tokenizers might rely on features that are not present in the released OpenVINO version. Use a nightly build of OpenVINO or build OpenVINO Tokenizers from a release branch if you have issues with the build process.

Build and install for development

Using OpenVINO PyPI package

git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
pip install -e .[all] --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly

Verify the installation by running the tests:

cd tests/
pytest .

Using OpenVINO archive

Install the OpenVINO archive distribution. Use --no-deps to avoid installing OpenVINO from PyPI into your current environment; --extra-index-url is needed only to resolve build dependencies.

source path/to/installed/openvino/setupvars.sh
git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
pip install -e .[all] --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly

Verify the installation by running the tests:

cd tests/
pytest .

C++ Installation

You can use converted tokenizers in C++ pipelines with prebuilt binaries.

  1. Download the OpenVINO archive distribution for your OS from here and extract the archive.
  2. Download the OpenVINO Tokenizers prebuilt libraries from here. To ensure compatibility, the first three numbers of the OpenVINO Tokenizers version should match the OpenVINO version and OS.
  3. Extract the OpenVINO Tokenizers archive into the OpenVINO installation directory. The OpenVINO Tokenizers archive keeps a structure aligned with the OpenVINO archive:
    • Windows: <openvino_dir>\runtime\bin\intel64\Release\
    • MacOS_x86: <openvino_dir>/runtime/lib/intel64/Release
    • MacOS_arm64: <openvino_dir>/runtime/lib/arm64/Release/
    • Linux_x86: <openvino_dir>/runtime/lib/intel64/
    • Linux_arm64: <openvino_dir>/runtime/lib/aarch64/

After that, you can add the binary extension in your code and read/compile converted (de)tokenizer models.

If you use version 2023.3.0.0, the binary extension file is called (lib)user_ov_extension.(dll/dylib/so).
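
A minimal Python sketch of registering the extension; the library path and the tokenizer file name are illustrative, so adjust them to your OS and install location:

from openvino import Core

core = Core()
# register tokenizer operations from the prebuilt binary; use openvino_tokenizers.dll
# on Windows or libopenvino_tokenizers.dylib on macOS
core.add_extension("path/to/libopenvino_tokenizers.so")

# converted (de)tokenizer models can now be read and compiled
tokenizer = core.compile_model("openvino_tokenizer.xml", "CPU")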

C++ Build

To build OpenVINO Tokenizers binaries locally, use the following commands:

source path/to/installed/openvino/setupvars.sh
git clone https://github.com/openvinotoolkit/openvino_tokenizers.git
cd openvino_tokenizers
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
make

After that, you can transfer all binaries from build/src to <openvino_dir> as described in the C++ installation instruction above.
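
For example, on Linux_x86 (a sketch; the exact library file names under build/src may differ for your build):

# copy the built extension libraries next to the OpenVINO runtime libraries (Linux_x86 layout)
cp build/src/lib*.so <openvino_dir>/runtime/lib/intel64/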

Usage

⚠️ OpenVINO Tokenizers can be inferred on a CPU device only.

Convert HuggingFace tokenizer

OpenVINO Tokenizers ships with a CLI tool that can convert tokenizers from the Huggingface Hub or Huggingface tokenizers saved on disk:

convert_tokenizer codellama/CodeLlama-7b-hf --with-detokenizer -o output_dir

There is also a convert_tokenizer function that can convert a tokenizer Python object.

import numpy as np
from transformers import AutoTokenizer
from openvino import compile_model, save_model
from openvino_tokenizers import convert_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ov_tokenizer = convert_tokenizer(hf_tokenizer)

compiled_tokenizer = compile_model(ov_tokenizer)
text_input = ["Test string"]

hf_output = hf_tokenizer(text_input, return_tensors="np")
ov_output = compiled_tokenizer(text_input)

for output_name in hf_output:
    print(f"OpenVINO {output_name} = {ov_output[output_name]}")
    print(f"HuggingFace {output_name} = {hf_output[output_name]}")

OpenVINO input_ids = [[ 101 3231 5164 102]]

HuggingFace input_ids = [[ 101 3231 5164 102]]

OpenVINO token_type_ids = [[0 0 0 0]]

HuggingFace token_type_ids = [[0 0 0 0]]

OpenVINO attention_mask = [[1 1 1 1]]

HuggingFace attention_mask = [[1 1 1 1]]

# save tokenizer for later use
save_model(ov_tokenizer, "openvino_tokenizer.xml")

loaded_tokenizer = compile_model("openvino_tokenizer.xml")
loaded_ov_output = loaded_tokenizer(text_input)
for output_name in hf_output:
    assert np.all(loaded_ov_output[output_name] == ov_output[output_name])

Connect Tokenizer to a Model

To convert and infer the original model, install torch (or the CPU-only torch build) into the virtual environment, e.g. pip install torch.

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from openvino import compile_model, convert_model
from openvino_tokenizers import convert_tokenizer, connect_models

checkpoint = "mrm8488/bert-tiny-finetuned-sms-spam-detection"
hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
hf_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

text_input = ["Free money!!!"]
hf_input = hf_tokenizer(text_input, return_tensors="pt")
hf_output = hf_model(**hf_input)

ov_tokenizer = convert_tokenizer(hf_tokenizer)
ov_model = convert_model(hf_model, example_input=hf_input.data)
combined_model = connect_models(ov_tokenizer, ov_model)
compiled_combined_model = compile_model(combined_model)

openvino_output = compiled_combined_model(text_input)

print(f"OpenVINO logits: {openvino_output['logits']}")

OpenVINO logits: [[ 1.2007061 -1.4698029]]

print(f"HuggingFace logits {hf_output.logits}")

HuggingFace logits tensor([[ 1.2007, -1.4698]], grad_fn=)

Use Extension With Converted (De)Tokenizer or Model With (De)Tokenizer

Importing openvino_tokenizers registers tokenizer-related operations in OpenVINO, after which you can work with saved tokenizers and detokenizers.

import numpy as np
import openvino_tokenizers
from openvino import Core

core = Core()

# detokenizer from codellama sentencepiece model
compiled_detokenizer = core.compile_model("detokenizer.xml")

token_ids = np.random.randint(100, 1000, size=(3, 5))
openvino_output = compiled_detokenizer(token_ids)

print(openvino_output["string_output"])

['sc�ouition�', 'intvenord hasient', 'g shouldwer M more']

Text Generation Pipeline

import numpy as np
from openvino import compile_model, convert_model
from openvino_tokenizers import add_greedy_decoding, convert_tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

model_checkpoint = "JackFram/llama-68m"
hf_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
hf_model = AutoModelForCausalLM.from_pretrained(model_checkpoint, use_cache=False)

# convert hf tokenizer
text_input = ["Quick brown fox jumped "]
ov_tokenizer, ov_detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
compiled_tokenizer = compile_model(ov_tokenizer)

# transform input text into tokens
ov_input = compiled_tokenizer(text_input)
hf_input = hf_tokenizer(text_input, return_tensors="pt")

# convert Pytorch model to OpenVINO IR and add greedy decoding pipeline to it
ov_model = convert_model(hf_model, example_input=hf_input.data)
ov_model_with_greedy_decoding = add_greedy_decoding(ov_model)
compiled_model = compile_model(ov_model_with_greedy_decoding)

# generate new tokens
new_tokens_size = 10
prompt_size = ov_input["input_ids"].shape[-1]
input_dict = {
    output.any_name: np.hstack([tensor, np.zeros(shape=(1, new_tokens_size), dtype=np.int_)])
    for output, tensor in ov_input.items()
}
for idx in range(prompt_size, prompt_size + new_tokens_size):
    output = compiled_model(input_dict)["token_ids"]
    input_dict["input_ids"][:, idx] = output[:, idx - 1]
    input_dict["attention_mask"][:, idx] = 1
ov_token_ids = input_dict["input_ids"]

hf_token_ids = hf_model.generate(
    **hf_input,
    min_new_tokens=new_tokens_size,
    max_new_tokens=new_tokens_size,
    temperature=0,  # greedy decoding
)

# decode model output
compiled_detokenizer = compile_model(ov_detokenizer)
ov_output = compiled_detokenizer(ov_token_ids)["string_output"]
hf_output = hf_tokenizer.batch_decode(hf_token_ids, skip_special_tokens=True)
print(f"OpenVINO output string: {ov_output}")

OpenVINO output string: ['Quick brown fox was walking through the forest. He was looking for something']

print(f"HuggingFace output string: {hf_output}")

HuggingFace output string: ['Quick brown fox was walking through the forest. He was looking for something']

TensorFlow Text Integration

OpenVINO Tokenizers includes converters for certain TensorFlow Text operations. Currently, only the MUSE model is supported. Here is an example of model conversion and inference:

import numpy as np
import tensorflow_hub as hub
import tensorflow_text  # register tf text ops
from openvino import convert_model, compile_model
import openvino_tokenizers  # register ov tokenizer ops and translators

sentences = ["dog", "I cuccioli sono carini.", "私は犬と一緒にビーチを散歩するのが好きです"]
tf_embed = hub.load(
    "https://www.kaggle.com/models/google/universal-sentence-encoder/frameworks/"
    "TensorFlow2/variations/multilingual/versions/2"
)

# convert model that uses Sentencepiece tokenizer op from TF Text
ov_model = convert_model(tf_embed)
ov_embed = compile_model(ov_model, "CPU")

ov_result = ov_embed(sentences)[ov_embed.output()]
tf_result = tf_embed(sentences)

assert np.all(np.isclose(ov_result, tf_result, atol=1e-4))

RWKV Tokenizer

from urllib.request import urlopen

from openvino import compile_model
from openvino_tokenizers import build_rwkv_tokenizer

rwkv_vocab_url = (
    "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/tokenizer/rwkv_vocab_v20230424.txt"
)

with urlopen(rwkv_vocab_url) as vocab_file:
    vocab = map(bytes.decode, vocab_file)
    tokenizer, detokenizer = build_rwkv_tokenizer(vocab)

tokenizer, detokenizer = compile_model(tokenizer), compile_model(detokenizer)

print(tokenized := tokenizer(["Test string"])["input_ids"])  # [[24235 47429]]
print(detokenizer(tokenized)["string_output"])  # ['Test string']

Tokenizer From GGUF Model

from transformers import AutoTokenizer
import openvino as ov
from openvino_tokenizers import convert_tokenizer

model_id = "unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF"
filename = "DeepSeek-R1-Distill-Qwen-1.5B-Q2_K.gguf"
hf_tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)

ov_tokenizer, ov_detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
ov_tokenizer, ov_detokenizer = ov.compile_model(ov_tokenizer), ov.compile_model(ov_detokenizer)

print(ov_res := ov_tokenizer(["Test string"])["input_ids"])  # [[2271 914]]
print(ov_detokenizer(ov_res)["string_output"])  # ['Test string']

C++ Usage Example

This example shows how to run inference with C++ on a text-classification model from Hugging Face. It expects the path to a model directory as a parameter and prints the logits returned by the model inference.

Export an example model by running the following command after pip install optimum[openvino]:

optimum-cli export openvino microsoft/deberta-base-mnli deberta-base-mnli-ov

#include <openvino/openvino.hpp>
#include <filesystem>
#include <iostream>

int main(int argc, char* argv[]) {
    std::string dirname = argv[1];
    std::filesystem::path dir_path(dirname);
    std::filesystem::path model_xml = dir_path / "openvino_model.xml";
    std::filesystem::path tokenizer_xml = dir_path / "openvino_tokenizer.xml";

    ov::Core core;
    // use "openvino_tokenizers.dll" on Windows, "libopenvino_tokenizers.dylib" on macOS
    core.add_extension("libopenvino_tokenizers.so");

    ov::InferRequest tokenizer_request = core.compile_model(tokenizer_xml, "CPU").create_infer_request();

    std::string prompt = "Hello world!";
    tokenizer_request.set_input_tensor(ov::Tensor{ov::element::string, {1}, &prompt});
    tokenizer_request.infer();
    ov::Tensor input_ids = tokenizer_request.get_tensor("input_ids");
    ov::Tensor attention_mask = tokenizer_request.get_tensor("attention_mask");

    ov::InferRequest infer_request = core.compile_model(model_xml, "CPU").create_infer_request();
    infer_request.set_tensor("input_ids", input_ids);
    infer_request.set_tensor("attention_mask", attention_mask);
    infer_request.infer();

    auto output = infer_request.get_tensor("logits");
    const float* output_buffer = output.data<const float>();

    size_t num_elements = output.get_size();

    for (size_t i = 0; i < num_elements; i++) {
        std::cout << output_buffer[i] << " ";
    }

    std::cout << std::endl;
    return 0;
}

Unicode Support

Supported Tokenizer Types

Huggingface Tokenizer Type Tokenizer Model Type Tokenizer Detokenizer
Fast WordPiece
BPE
Unigram
WordLevel*
Legacy SentencePiece .model
Custom tiktoken
RWKV Trie

Test Results

This report is autogenerated and includes tokenizer and detokenizer tests. The Output Matched, % column shows the percentage of test strings for which the results of OpenVINO and Huggingface Tokenizers are the same. To update the report, run pytest --update_readme tokenizers_test.py in the tests directory.

Output Match by Tokenizer Type

Tokenizer Type Output Matched, % Number of Tests
BPE 99.45 4397
SentencePiece 88.37 5279
Tiktoken 96.64 536
Unigram 95.35 1506
WordLevel 98.99 198
WordPiece 99.09 1319

Output Match by Model

Tokenizer Type Model Output Matched, % Number of Tests
BPE LiquidAI/LFM2-350M 100.00 253
BPE NousResearch/Llama-2-13b-hf 100.00 251
BPE NousResearch/Meta-Llama-3-8B-Instruct 100.00 253
BPE Qwen/Qwen3-Reranker-0.6B 100.00 269
BPE Xenova/gpt-4o 100.00 267
BPE answerdotai/ModernBERT-base 100.00 267
BPE bigscience/bloom 97.61 251
BPE deepseek-ai/DeepSeek-V3-0324 99.26 269
BPE deepseek-ai/deepseek-coder-6.7b-instruct 100.00 269
BPE facebook/galactica-120b 100.00 251
BPE koalajun/Gemma-2-9b-it-Ko-Crypto-Translate 100.00 253
BPE llava-hf/LLaVA-NeXT-Video-7B-hf 100.00 251
BPE microsoft/Phi-3-mini-128k-instruct 100.00 253
BPE microsoft/deberta-base 100.00 251
BPE mlx-community/quantized-gemma-7b-it 97.63 253
BPE roberta-base 100.00 267
BPE tiiuae/Falcon3-7B-Instruct 96.28 269
SentencePiece BAAI/bge-reranker-v2-m3 96.81 251
SentencePiece BAAI/bge-reranker-v2-m3_legacy 96.81 251
SentencePiece NousResearch/Llama-2-13b-hf 96.02 251
SentencePiece NousResearch/Llama-2-13b-hf_legacy 99.20 251
SentencePiece camembert-base 56.18 251
SentencePiece camembert-base_legacy 78.88 251
SentencePiece facebook/musicgen-small 82.07 251
SentencePiece facebook/musicgen-small_legacy 76.10 251
SentencePiece google/flan-t5-xxl 75.70 251
SentencePiece google/flan-t5-xxl_legacy 74.50 251
SentencePiece llava-hf/LLaVA-NeXT-Video-7B-hf 95.22 251
SentencePiece llava-hf/LLaVA-NeXT-Video-7B-hf_legacy 98.41 251
SentencePiece microsoft/Phi-3-mini-128k-instruct 99.21 253
SentencePiece microsoft/Phi-3-mini-128k-instruct_legacy 97.63 253
SentencePiece microsoft/deberta-v3-base 95.22 251
SentencePiece microsoft/deberta-v3-base_legacy 98.41 251
SentencePiece microsoft/speecht5_tts_legacy 71.71 251
SentencePiece mlx-community/quantized-gemma-7b-it 96.84 253
SentencePiece mlx-community/quantized-gemma-7b-it_legacy 97.63 253
SentencePiece rinna/bilingual-gpt-neox-4b 83.27 251
SentencePiece rinna/bilingual-gpt-neox-4b_legacy 89.64 251
Tiktoken Qwen/Qwen-14B-Chat 100.00 267
Tiktoken THUDM/glm-4-9b-chat 93.31 269
Unigram BAAI/bge-reranker-v2-m3 98.41 251
Unigram camembert-base 84.86 251
Unigram facebook/musicgen-small 98.41 251
Unigram google/flan-t5-xxl 92.03 251
Unigram microsoft/deberta-v3-base 98.41 251
Unigram rinna/bilingual-gpt-neox-4b 100.00 251
WordLevel cisco-ai/mini-bart-g2p 98.99 198
WordPiece bert-base-multilingual-cased 100.00 267
WordPiece cointegrated/rubert-tiny2 100.00 267
WordPiece google/mobilebert-uncased 100.00 251
WordPiece rasa/LaBSE 95.51 267
WordPiece sentence-transformers/all-MiniLM-L6-v2 100.00 267

Recreating Tokenizers From Tests

For some tokenizers, you need to select certain settings so that their output matches the Huggingface tokenizers more closely: