Text classification

Zero-shot text classification

Overview

A powerful feature of LLMs is that they can classify text without being re-trained. The only requirement is that the labels are descriptive.

For example, consider the task of classifying a text into one of the following categories: [positive, negative, neutral]. We will use the ZeroShotOllamaClassifier class and the regular scikit-learn API to perform the classification:

from skollama.models.ollama.classification.zero_shot import ZeroShotOllamaClassifier
from skllm.datasets import get_classification_dataset

# demo sentiment analysis dataset
# labels: positive, negative, neutral
X, y = get_classification_dataset()

clf = ZeroShotOllamaClassifier(model="llama3")
clf.fit(X, y)
labels = clf.predict(X)

However, in the zero-shot setting, the training data is not required as it is only used for the extraction of the candidate labels. Therefore, it is sufficient to manually provide a list of candidate labels:

from skollama.models.ollama.classification.zero_shot import ZeroShotOllamaClassifier
from skllm.datasets import get_classification_dataset

X, _ = get_classification_dataset()

clf = ZeroShotOllamaClassifier()
clf.fit(None, ["positive", "negative", "neutral"])
labels = clf.predict(X)
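
Since the training data is only used to obtain the candidate labels, fitting on y and fitting on an explicit label list amount to the same thing. A minimal sketch of that idea in plain Python, assuming the labels are simply deduplicated in order of first appearance (extract_candidate_labels is a hypothetical helper, not part of the library, and the library's internal logic may differ):

```python
from typing import Iterable, List


def extract_candidate_labels(y: Iterable[str]) -> List[str]:
    """Return the unique labels in y, in order of first appearance.

    Hypothetical helper for illustration only.
    """
    # dict preserves insertion order, so this deduplicates without reordering
    return list(dict.fromkeys(y))


y = ["positive", "negative", "positive", "neutral", "negative"]
candidate_labels = extract_candidate_labels(y)
print(candidate_labels)
```

A list obtained this way could be passed as the second argument to fit in place of the full training targets.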

Additionally, it is possible to perform the classification in a multi-label setting, where multiple labels can be assigned to a single text at the same time:

from skollama.models.ollama.classification.zero_shot import MultiLabelZeroShotOllamaClassifier
from skllm.datasets import get_multilabel_classification_dataset

X, _ = get_multilabel_classification_dataset()
candidate_labels = [
    "Quality",
    "Price",
    "Delivery",
    "Service",
    "Product Variety",
    "Customer Support",
    "Packaging",
    "User Experience",
    "Return Policy",
    "Product Information",
]
clf = MultiLabelZeroShotOllamaClassifier(max_labels=3)
clf.fit(None, [candidate_labels])
labels = clf.predict(X)
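
In the multi-label setting each sample's target is a list of labels, which is why the candidate labels are wrapped in an outer list when passed to fit. The sketch below, in plain Python, shows how a label set could be derived from such nested targets and how a cap like max_labels could be applied to the labels of one sample. Both helpers (unique_labels and cap_labels) are hypothetical illustrations and not the library's implementation.

```python
from typing import List, Sequence


def unique_labels(y: Sequence[Sequence[str]]) -> List[str]:
    """Collect unique labels from a list of per-sample label lists."""
    seen = {}
    for sample_labels in y:
        for label in sample_labels:
            seen.setdefault(label, None)  # keeps first-appearance order
    return list(seen)


def cap_labels(predicted: Sequence[str], candidates: Sequence[str], max_labels: int) -> List[str]:
    """Keep only known labels, drop duplicates, and truncate to max_labels."""
    allowed = set(candidates)
    kept: List[str] = []
    for label in predicted:
        if label in allowed and label not in kept:
            kept.append(label)
    return kept[:max_labels]


y = [["Quality", "Price"], ["Delivery"], ["Price", "Service"]]
candidates = unique_labels(y)
print(candidates)

raw = ["Price", "Unknown", "Quality", "Price", "Service", "Delivery"]
print(cap_labels(raw, candidates, max_labels=3))
```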

Note

Unlike in a typical supervised setting, the performance of a zero-shot classifier depends heavily on how the labels themselves are phrased. Each label has to be expressed in natural language and be descriptive and self-explanatory. For example, in the previous sentiment classification task, it could be beneficial to transform a label from "<<SENTIMENT>>" to "the sentiment of the provided text is <<SENTIMENT>>".
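
One way to follow this advice without changing downstream code is to fit on descriptive labels and map the predictions back to the short ones. A minimal sketch, assuming a sentence template of your own choosing (TEMPLATE and describe_labels are illustrative names, not part of the library):

```python
from typing import Dict, List

TEMPLATE = "the sentiment of the provided text is {label}"  # illustrative wording


def describe_labels(labels: List[str], template: str = TEMPLATE) -> Dict[str, str]:
    """Map each descriptive label back to its original short label."""
    return {template.format(label=label): label for label in labels}


short_labels = ["positive", "negative", "neutral"]
descriptive_to_short = describe_labels(short_labels)

# The descriptive labels are what would be passed as candidate labels,
# e.g. clf.fit(None, descriptive_labels)
descriptive_labels = list(descriptive_to_short)

# Predictions come back as descriptive labels; map them to the short form.
predictions = ["the sentiment of the provided text is negative"]
print([descriptive_to_short[p] for p in predictions])
```

Whether a given rephrasing actually helps depends on the model, so it is worth comparing both label styles on a small labelled sample.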


API Reference

The following API reference only lists the parameters needed for the initialization of the estimator. The remaining methods follow the syntax of a scikit-learn classifier.

ZeroShotOllamaClassifier

from skollama.models.ollama.classification.zero_shot import ZeroShotOllamaClassifier
| Parameter | Type | Description |
| --- | --- | --- |
| model | str | Model to use, by default "llama3". |
| host | str | Ollama host to connect to, by default "http://localhost:11434". |
| options | dict | Additional options to pass to the Ollama API, by default None. |
| default_label | str | Default label for failed predictions; if "Random", a label is selected randomly based on class frequencies. By default "Random". |
| prompt_template | Optional[str] | Custom prompt template to use, by default None. |
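
The "Random" setting of default_label is described as a random choice based on class frequencies. As a rough illustration of what that means (this is not the library's actual code, and frequency_weighted_fallback is a hypothetical helper), a fallback label could be drawn like this:

```python
import random
from collections import Counter
from typing import Optional, Sequence


def frequency_weighted_fallback(y: Sequence[str], rng: Optional[random.Random] = None) -> str:
    """Draw one label at random, weighted by how often it occurs in y."""
    rng = rng or random.Random()
    counts = Counter(y)
    labels = list(counts)
    weights = [counts[label] for label in labels]
    # random.choices returns a list of k items; take the single draw
    return rng.choices(labels, weights=weights, k=1)[0]


y = ["positive"] * 6 + ["negative"] * 3 + ["neutral"]
fallback = frequency_weighted_fallback(y, random.Random(0))
print(fallback)
```

Under this reading, when the candidate labels are provided manually each label occurs once, so the draw would be uniform across them.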

MultiLabelZeroShotOllamaClassifier

from skollama.models.ollama.classification.zero_shot import MultiLabelZeroShotOllamaClassifier
| Parameter | Type | Description |
| --- | --- | --- |
| model | str | Model to use, by default "llama3". |
| host | str | Ollama host to connect to, by default "http://localhost:11434". |
| options | dict | Additional options to pass to the Ollama API, by default None. |
| default_label | str | Default label for failed predictions; if "Random", a label is selected randomly based on class frequencies. By default "Random". |
| max_labels | Optional[int] | Maximum labels per sample, by default 5. |
| prompt_template | Optional[str] | Custom prompt template to use, by default None. |