Text classification
Zero-shot text classification
Overview
One of the powerful features of LLMs is the ability to perform text classification without being retrained. The only requirement is that the labels are descriptive.
For example, consider the task of classifying a text into one of the following categories: [positive, negative, neutral]. We will use the ZeroShotOllamaClassifier class and the regular scikit-learn API to perform the classification:
from skollama.models.ollama.classification.zero_shot import ZeroShotOllamaClassifier
from skllm.datasets import get_classification_dataset
# demo sentiment analysis dataset
# labels: positive, negative, neutral
X, y = get_classification_dataset()
clf = ZeroShotOllamaClassifier(model="llama3")
clf.fit(X, y)
labels = clf.predict(X)
However, in the zero-shot setting, the training data is not required as it is only used for the extraction of the candidate labels. Therefore, it is sufficient to manually provide a list of candidate labels:
from skollama.models.ollama.classification.zero_shot import ZeroShotOllamaClassifier
from skllm.datasets import get_classification_dataset
X, _ = get_classification_dataset()
clf = ZeroShotOllamaClassifier()
clf.fit(None, ["positive", "negative", "neutral"])
labels = clf.predict(X)
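To illustrate why a plain list of labels is enough, the sketch below shows one way candidate labels could be derived from y. The helper extract_candidate_labels is hypothetical and not part of the library, whose internal logic may differ; the point is only that a full training target and a hand-written label list reduce to the same candidate set.

```python
def extract_candidate_labels(y):
    """Return the unique labels in y, preserving first-seen order.

    Hypothetical helper for illustration; not a library function.
    """
    seen = {}
    for label in y:
        seen.setdefault(label, None)
    return list(seen)


# A training target with repeated labels ...
y_train = ["positive", "negative", "positive", "neutral", "negative"]
# ... and a manually written list of candidate labels.
manual = ["positive", "negative", "neutral"]

# Both inputs yield the same set of candidate labels.
print(extract_candidate_labels(y_train))
print(extract_candidate_labels(manual))
```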
Additionally, it is possible to perform the classification in a multi-label setting, where multiple labels can be assigned to a single text at the same time:
from skollama.models.ollama.classification.zero_shot import MultiLabelZeroShotOllamaClassifier
from skllm.datasets import get_multilabel_classification_dataset
X, _ = get_multilabel_classification_dataset()
candidate_labels = [
    "Quality",
    "Price",
    "Delivery",
    "Service",
    "Product Variety",
    "Customer Support",
    "Packaging",
    "User Experience",
    "Return Policy",
    "Product Information",
]
clf = MultiLabelZeroShotOllamaClassifier(max_labels=3)
clf.fit(None, [candidate_labels])
labels = clf.predict(X)
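The effect of a cap such as max_labels=3 can be pictured with a small post-processing sketch. This is an assumption about what such a cap means in practice, not the library's actual implementation: the hypothetical cap_labels helper drops labels outside the candidate list, removes duplicates, and keeps at most max_labels entries.

```python
def cap_labels(predicted, candidate_labels, max_labels=3):
    """Keep only known, unique labels and truncate to max_labels.

    Hypothetical helper for illustration; not a library function.
    """
    allowed = set(candidate_labels)
    kept = []
    for label in predicted:
        if label in allowed and label not in kept:
            kept.append(label)
    return kept[:max_labels]


candidates = ["Quality", "Price", "Delivery", "Service"]
# Raw model output may contain unknown labels, repeats, or too many labels.
raw = ["Price", "Shipping", "Quality", "Price", "Service", "Delivery"]

print(cap_labels(raw, candidates, max_labels=3))
# ['Price', 'Quality', 'Service']
```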
Note
Unlike in a typical supervised setting, the performance of a zero-shot classifier greatly depends on how the label itself is structured. It has to be expressed in natural language, be descriptive and self-explanatory. For example, in the previous semantic classification task, it could be beneficial to transform a label from <<SEMANTICS>> to "the semantics of the provided text is <<SEMANTICS>>".
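One way to apply this advice is to generate the descriptive labels from the short ones and keep a mapping for translating predictions back. The following is a minimal sketch; the make_descriptive helper and the TEMPLATE constant are illustrative names, not part of the library.

```python
TEMPLATE = "the semantics of the provided text is {label}"


def make_descriptive(labels, template=TEMPLATE):
    """Map each descriptive label back to its short form (hypothetical helper)."""
    return {template.format(label=label): label for label in labels}


short_labels = ["positive", "negative", "neutral"]
descriptive_to_short = make_descriptive(short_labels)

# These descriptive strings would be passed as candidate labels,
# e.g. clf.fit(None, descriptive_labels).
descriptive_labels = list(descriptive_to_short)

# Predictions come back in descriptive form and are mapped to short labels.
predicted = ["the semantics of the provided text is negative"]
print([descriptive_to_short[p] for p in predicted])
# ['negative']
```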
API Reference
The following API reference only lists the parameters needed for the initialization of the estimator. The remaining methods follow the syntax of a scikit-learn classifier.
ZeroShotOllamaClassifier
from skollama.models.ollama.classification.zero_shot import ZeroShotOllamaClassifier
Parameter | Type | Description |
---|---|---|
model | str | Model to use, by default "llama3". |
host | str | Ollama host to connect to, by default "http://localhost:11434". |
options | dict | Additional options to pass to the Ollama API, by default None. |
default_label | str | Default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random". |
prompt_template | Optional[str] | Custom prompt template to use, by default None. |
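
The "Random" behaviour of default_label can be sketched as frequency-weighted sampling. The code below is an assumption-based illustration of that idea, not the library's code; pick_default_label and its rng argument are hypothetical.

```python
import random
from collections import Counter


def pick_default_label(y, default_label="Random", rng=None):
    """Choose a fallback label for a failed prediction (hypothetical helper).

    A fixed default_label is returned as-is; "Random" samples a label
    with probability proportional to its frequency in y.
    """
    if default_label != "Random":
        return default_label
    rng = rng or random.Random()
    counts = Counter(y)
    labels = list(counts)
    weights = [counts[label] for label in labels]
    return rng.choices(labels, weights=weights, k=1)[0]


y_train = ["positive", "positive", "positive", "negative", "neutral"]

# A fixed fallback label is always returned unchanged.
print(pick_default_label(y_train, default_label="neutral"))
# "Random" favours frequent classes; here "positive" is drawn most often.
print(pick_default_label(y_train, rng=random.Random(0)))
```

In this sketch, fitting on a bare list of candidate labels gives every label a count of one, so the fallback would be uniform.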
MultiLabelZeroShotOllamaClassifier
from skollama.models.ollama.classification.zero_shot import MultiLabelZeroShotOllamaClassifier
Parameter | Type | Description |
---|---|---|
model | str | Model to use, by default "llama3". |
host | str | Ollama host to connect to, by default "http://localhost:11434". |
options | dict | Additional options to pass to the Ollama API, by default None. |
default_label | str | Default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random". |
max_labels | Optional[int] | Maximum labels per sample, by default 5. |
prompt_template | Optional[str] | Custom prompt template to use, by default None. |
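
A custom prompt_template is a string with placeholders that are filled in for each sample. The sketch below assumes the placeholders are named {x} (the input text) and {labels} (the candidate labels); these names, the CUSTOM_TEMPLATE wording, and the render_prompt helper are assumptions for illustration and should be checked against the library before use.

```python
# Assumed placeholder names: {x} for the text, {labels} for the candidates.
CUSTOM_TEMPLATE = (
    "Classify the review below into exactly one of these categories: {labels}.\n"
    "Review: {x}\n"
    "Answer with the category name only."
)


def render_prompt(template, text, labels):
    """Fill the template for a single sample (hypothetical helper)."""
    return template.format(x=text, labels=", ".join(labels))


prompt = render_prompt(
    CUSTOM_TEMPLATE,
    "Great product, fast delivery.",
    ["positive", "negative", "neutral"],
)
print(prompt)
```

Under these assumptions, the template string would be passed to the estimator as prompt_template=CUSTOM_TEMPLATE.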