Text classification
Few-shot text classification
Overview
Few-shot text classification is the task of classifying a text into one of several pre-defined classes based on a few examples of each class. For example, given a few examples each of the classes positive, negative, and neutral, the model should be able to assign a new text to one of these classes. This approach is sometimes referred to as in-context learning ([1]). It can yield strong results, but its performance can also be quite unstable ([2]).
Like Scikit-LLM, Scikit-Ollama does not select a subset of the training data; instead, it uses the entire training set to construct the examples. Therefore, if your training set is large, you might want to split it into training and validation sets while keeping the training set small (we recommend no more than 10 examples per class).
Also keep in mind that the order of the examples can influence model performance!
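To follow the recommendation above, you can subsample the training data before fitting. The subsample_per_class helper below is not part of the library; it is only a minimal sketch of one way to cap the number of examples per class:

from collections import defaultdict

def subsample_per_class(X, y, n_per_class=10):
    # Keep at most n_per_class examples per label
    # (illustrative helper, not part of Scikit-Ollama).
    counts = defaultdict(int)
    X_small, y_small = [], []
    for text, label in zip(X, y):
        if counts[label] < n_per_class:
            X_small.append(text)
            y_small.append(label)
            counts[label] += 1
    return X_small, y_small

# usage (illustrative): X_few, y_few = subsample_per_class(X_train, y_train, n_per_class=10)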
Example using llama3:
from skollama.models.ollama.classification.few_shot import (
    FewShotOllamaClassifier,
    MultiLabelFewShotOllamaClassifier,
)
from skllm.datasets import (
    get_classification_dataset,
    get_multilabel_classification_dataset,
)

# single-label classification
X, y = get_classification_dataset()
clf = FewShotOllamaClassifier(model="llama3")
clf.fit(X, y)
labels = clf.predict(X)

# multi-label classification
X, y = get_multilabel_classification_dataset()
clf = MultiLabelFewShotOllamaClassifier(max_labels=2, model="llama3")
clf.fit(X, y)
labels = clf.predict(X)
API Reference
The following API reference only lists the parameters needed for the initialization of the estimator. The remaining methods follow the syntax of a scikit-learn classifier.
FewShotOllamaClassifier
from skollama.models.ollama.classification.few_shot import FewShotOllamaClassifier
| Parameter | Type | Description |
|---|---|---|
| model | str, optional | Model to use, by default "llama3". |
| host | str, optional | Ollama host to connect to, by default "http://localhost:11434". |
| options | dict, optional | Additional options to pass to the Ollama API, by default None. |
| default_label | str, optional | Default label for a failed prediction; if "Random", a label is selected randomly according to class frequencies, by default "Random". |
| prompt_template | Optional[str], optional | Custom prompt template to use, by default None. |
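As a non-authoritative sketch, an estimator initialized with non-default values for these parameters might look as follows (the "temperature" option key is an Ollama API option used here purely as an illustration):

clf = FewShotOllamaClassifier(
    model="llama3",
    host="http://localhost:11434",  # default Ollama host
    options={"temperature": 0},     # forwarded to the Ollama API (illustrative value)
    default_label="neutral",        # fixed fallback label instead of "Random"
)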
MultiLabelFewShotOllamaClassifier
from skollama.models.ollama.classification.few_shot import MultiLabelFewShotOllamaClassifier
| Parameter | Type | Description |
|---|---|---|
| model | str, optional | Model to use, by default "llama3". |
| host | str, optional | Ollama host to connect to, by default "http://localhost:11434". |
| options | dict, optional | Additional options to pass to the Ollama API, by default None. |
| default_label | str, optional | Default label for a failed prediction; if "Random", a label is selected randomly according to class frequencies, by default "Random". |
| max_labels | Optional[int], optional | Maximum number of labels to predict, by default 5. |
| prompt_template | Optional[str], optional | Custom prompt template to use, by default None. |
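Similarly, a minimal sketch of the multi-label estimator with a custom label cap (the values shown are illustrative, not recommendations):

clf = MultiLabelFewShotOllamaClassifier(
    model="llama3",
    max_labels=3,                # predict at most 3 labels per sample
    options={"temperature": 0},  # forwarded to the Ollama API (illustrative value)
)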