GitHub - pmbaumgartner/setfit
A SetFit classifier with a scikit-learn-style API. The model was originally developed by Moshe Wasserblat.
Use
```python
from setfit import SetFitClassifier

docs = ["yay", "boo", "yes", "no", "yeah"]
labels = [1, 0, 1, 0, 1]

# takes a sentence-transformers model
clf = SetFitClassifier("paraphrase-MiniLM-L3-v2")

# fine-tunes embeddings + trains logistic regression head
clf.fit(docs, labels)

clf.predict(["affirmative", "negative"])
# array([1, 0])
```
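To give a rough idea of what "fine-tunes embeddings" involves, here is a minimal, stdlib-only sketch of the contrastive pair-generation step used in the general SetFit recipe. It is not this package's code: it assumes pairs with the same label get a target similarity of 1.0 and pairs with different labels get 0.0, and `make_contrastive_pairs` is a hypothetical helper. It builds every pair exhaustively for clarity, whereas real implementations typically sample a fixed number of pairs.

```python
from itertools import combinations
from typing import List, Tuple


def make_contrastive_pairs(
    texts: List[str], labels: List[int]
) -> List[Tuple[str, str, float]]:
    """Build every unordered pair of training texts (hypothetical helper).

    Pairs sharing a label get target similarity 1.0 (pull embeddings together);
    pairs with different labels get 0.0 (push embeddings apart).
    """
    if len(texts) != len(labels):
        raise ValueError("texts and labels must have the same length")
    pairs = []
    for (t1, y1), (t2, y2) in combinations(zip(texts, labels), 2):
        pairs.append((t1, t2, 1.0 if y1 == y2 else 0.0))
    return pairs


docs = ["yay", "boo", "yes", "no", "yeah"]
labels = [1, 0, 1, 0, 1]
pairs = make_contrastive_pairs(docs, labels)
print(len(pairs))  # 10: every unordered pair of the 5 examples
```

In the SetFit approach, pairs like these are used to fine-tune the sentence-transformers model so that the cosine similarity of each pair's embeddings moves toward its target; a logistic regression head is then trained on the fine-tuned embeddings of the original labeled documents. Even five labeled examples yield ten training pairs, which is part of why the method works with very few labels.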
Installation
pip install git+https://github.com/pmbaumgartner/setfit