
Few-shot text classification

Overview

Few-shot text classification is the task of assigning a text to one of several predefined classes, given only a few labelled examples of each class. For example, given a few examples of the classes positive, negative, and neutral, the model should be able to classify a new text into one of these classes. Because the examples are placed in the prompt rather than used to update the model's weights, this concept is sometimes called in-context learning ([1]). The approach can work well with very little labelled data, but its results may be very unstable ([2]).
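To make in-context learning concrete, the sketch below renders a handful of labelled examples, followed by the unlabelled text, into a single prompt. The function name build_few_shot_prompt and the template wording are hypothetical illustrations; this is not the prompt Scikit-Ollama actually sends (the library's own template can be replaced through the prompt_template parameter).

```python
# Hypothetical sketch: the template below is an illustration, not the prompt
# that Scikit-Ollama actually sends to the model.
def build_few_shot_prompt(examples, labels, query):
    """Render labelled (text, label) examples followed by the text to classify."""
    lines = [f"Classify the text into one of: {', '.join(labels)}.", ""]
    for text, label in examples:
        lines.append(f"Text: {text}")
        lines.append(f"Label: {label}")
        lines.append("")
    # The unlabelled query comes last; the model is expected to complete the label.
    lines.append(f"Text: {query}")
    lines.append("Label:")
    return "\n".join(lines)


examples = [
    ("I loved this film!", "positive"),
    ("Terrible service, never again.", "negative"),
    ("The package arrived on Tuesday.", "neutral"),
]
prompt = build_few_shot_prompt(
    examples, ["positive", "negative", "neutral"], "Great value for money."
)
print(prompt)
```

No model weights change here: every new prediction simply sends a prompt like this one, which is why the choice of examples matters so much.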

Like Scikit-LLM, Scikit-Ollama does not select a subset of the training data; instead, it uses the entire training set to construct the examples. Therefore, if your training set is large, consider splitting it into training and validation sets while keeping the training set small (we recommend no more than 10 examples per class).
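One way to follow this recommendation is a stratified cap: keep at most k randomly chosen examples per class for fit() and hold out the rest for validation. The helper split_few_shot below is hypothetical (it is not part of Scikit-Ollama or Scikit-LLM) and assumes X and y are indexable sequences such as Python lists.

```python
import random
from collections import defaultdict


def split_few_shot(X, y, k=10, seed=0):
    """Hypothetical helper: keep at most k examples per class for fit().

    Returns (X_train, X_val, y_train, y_val); everything not selected for the
    few-shot set becomes validation data.
    """
    indices_by_class = defaultdict(list)
    for i, label in enumerate(y):
        indices_by_class[label].append(i)

    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train_idx = set()
    for idx in indices_by_class.values():
        train_idx.update(rng.sample(idx, min(k, len(idx))))

    ordered = sorted(train_idx)
    X_train = [X[i] for i in ordered]
    y_train = [y[i] for i in ordered]
    X_val = [X[i] for i in range(len(X)) if i not in train_idx]
    y_val = [y[i] for i in range(len(y)) if i not in train_idx]
    return X_train, X_val, y_train, y_val
```

Usage would look like `X_train, X_val, y_train, y_val = split_few_shot(X, y, k=10)`, then fitting the classifier on X_train and y_train and scoring its predictions on X_val.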

Also keep in mind that the order of the examples may affect model performance!
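A simple way to gauge how sensitive your setup is to ordering is to fit one classifier per random permutation of the same few-shot set and compare validation scores. The generator shuffled_orderings below is a hypothetical sketch; it assumes that the examples are presented to the model in the order they are passed to fit().

```python
import random


def shuffled_orderings(X, y, n=3, seed=0):
    """Hypothetical helper: yield n random permutations of the same (X, y) set.

    Text/label pairs stay aligned; only their order changes between copies.
    """
    rng = random.Random(seed)
    pairs = list(zip(X, y))
    if not pairs:
        return
    for _ in range(n):
        perm = pairs[:]
        rng.shuffle(perm)
        X_perm, y_perm = zip(*perm)
        yield list(X_perm), list(y_perm)
```

If validation accuracy varies noticeably across orderings, the result from any single ordering should be treated with caution.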

Example using llama3:

from skollama.models.ollama.classification.few_shot import (
    FewShotOllamaClassifier,
    MultiLabelFewShotOllamaClassifier,
)
from skllm.datasets import (
    get_classification_dataset,
    get_multilabel_classification_dataset,
)

# single label
X, y = get_classification_dataset()
clf = FewShotOllamaClassifier(model="llama3")
clf.fit(X, y)
labels = clf.predict(X)

# multi-label
X, y = get_multilabel_classification_dataset()
clf = MultiLabelFewShotOllamaClassifier(max_labels=2, model="llama3")
clf.fit(X, y)
labels = clf.predict(X)

API Reference

The following API reference lists only the parameters needed to initialize each estimator. The remaining methods (such as fit and predict) follow the scikit-learn classifier interface.

FewShotOllamaClassifier

from skollama.models.ollama.classification.few_shot import FewShotOllamaClassifier
| Parameter | Type | Description |
|---|---|---|
| model | str, optional | Model to use, by default "llama3". |
| host | str, optional | Ollama host to connect to, by default "http://localhost:11434". |
| options | dict, optional | Additional options to pass to the Ollama API, by default None. |
| default_label | str, optional | Label to return when a prediction fails; if "Random", a label is sampled according to the class frequencies. By default "Random". |
| prompt_template | Optional[str], optional | Custom prompt template to use, by default None. |
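The default_label behaviour described in the table can be illustrated with a short sketch: if the model's answer is not one of the known classes, return the fixed default label, or, when it is "Random", sample a class with probability proportional to its frequency in the training labels. The functions fallback_label and resolve_prediction are hypothetical re-creations of the documented behaviour, not Scikit-Ollama's internal code.

```python
import random
from collections import Counter


def fallback_label(y_train, default_label="Random", rng=None):
    """Hypothetical: pick the label returned when a prediction fails."""
    if default_label != "Random":
        return default_label
    if rng is None:
        rng = random.Random()
    counts = Counter(y_train)
    labels = list(counts)
    # Weighted draw: frequent classes are proportionally more likely.
    return rng.choices(labels, weights=[counts[label] for label in labels], k=1)[0]


def resolve_prediction(raw, y_train, default_label="Random", rng=None):
    """Hypothetical: keep the model output if it is a known class, else fall back."""
    answer = raw.strip()
    if answer in set(y_train):
        return answer
    return fallback_label(y_train, default_label, rng)
```

With training labels that are 75% "positive", a failed prediction would come back as "positive" roughly three times out of four.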

MultiLabelFewShotOllamaClassifier

from skollama.models.ollama.classification.few_shot import MultiLabelFewShotOllamaClassifier
| Parameter | Type | Description |
|---|---|---|
| model | str, optional | Model to use, by default "llama3". |
| host | str, optional | Ollama host to connect to, by default "http://localhost:11434". |
| options | dict, optional | Additional options to pass to the Ollama API, by default None. |
| default_label | str, optional | Label to return when a prediction fails; if "Random", a label is sampled according to the class frequencies. By default "Random". |
| max_labels | Optional[int], optional | Maximum number of labels to predict, by default 5. |
| prompt_template | Optional[str], optional | Custom prompt template to use, by default None. |
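As an illustration of what max_labels limits, the sketch below turns a comma-separated model answer into a list of at most max_labels known labels, discarding unknown and duplicate entries. The function parse_multilabel, the comma-separated answer format, and the example label names are hypothetical assumptions for illustration; Scikit-Ollama's actual output parsing is not described here.

```python
def parse_multilabel(raw, classes, max_labels=5):
    """Hypothetical post-processing of a multi-label answer.

    Unknown labels and duplicates are dropped, and the result is capped at
    max_labels entries (earlier labels in the answer take precedence).
    """
    allowed = set(classes)
    result = []
    for part in raw.split(","):
        if len(result) >= max_labels:
            break
        label = part.strip()
        if label in allowed and label not in result:
            result.append(label)
    return result


classes = ["food", "service", "price", "ambience"]
print(parse_multilabel("food, price, parking, food, service", classes, max_labels=2))
```

Here "parking" is ignored as an unknown class, and "service" is cut off because the cap of two labels has already been reached.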