initial commit

This commit is contained in:
Gokul
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
"""This is the langchain_ollama package.
Provides infrastructure for interacting with the [Ollama](https://ollama.com/)
service.
!!! note
**Newly added in 0.3.4:** `validate_model_on_init` param on all models.
This parameter allows you to validate the model exists in Ollama locally on
initialization. If set to `True`, it will raise an error if the model does not
exist locally. This is useful for ensuring that the model is available before
attempting to use it, especially in environments where models may not be
pre-downloaded.
"""
from importlib import metadata
from importlib.metadata import PackageNotFoundError
from langchain_ollama.chat_models import ChatOllama
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
def _raise_package_not_found_error() -> None:
    """Raise `PackageNotFoundError` unconditionally.

    Raises:
        PackageNotFoundError: Always.
    """
    raise PackageNotFoundError()
# Look up the installed distribution's version. When the module has no package
# name, or no metadata is installed for it, expose an empty version string.
try:
    if __package__ is None:
        _raise_package_not_found_error()
    __version__ = metadata.version(__package__)
except PackageNotFoundError:
    # No package metadata available.
    __version__ = ""

# Optional: keeps `metadata` out of the results of dir(__package__).
del metadata

# Public API of the package.
__all__ = [
    "ChatOllama",
    "OllamaEmbeddings",
    "OllamaLLM",
    "__version__",
]

View File

@@ -0,0 +1,67 @@
"""Go from v1 content blocks to Ollama SDK format."""
from typing import Any
from langchain_core.messages import content as types
def _convert_from_v1_to_ollama(
    content: list[types.ContentBlock],
    model_provider: str | None,  # noqa: ARG001
) -> list[dict[str, Any]]:
    """Translate v1 content blocks into the shape used for the Ollama SDK.

    Args:
        content: List of v1 `ContentBlock` objects.
        model_provider: The model provider name that generated the v1 content.
            Currently unused.

    Returns:
        A new list with one entry per supported block, in input order:

        - `text` blocks are reduced to `{"type": "text", "text": ...}` (all
          other fields/extras are dropped).
        - `image` blocks are passed through as a shallow `dict` copy.
        - `non_standard` blocks are rendered as text using `str()` of their
          `value` (empty string when `value` is absent).

        Items that are not dicts, lack a `type` key, or have any other type
        (reasoning, tool calls, audio, files, ...) are skipped.
    """
    converted: list = []
    for item in content:
        if not (isinstance(item, dict) and "type" in item):
            continue

        as_dict = dict(item)  # (For typing)
        kind = as_dict["type"]

        if kind == "text":
            # TextContentBlock: only the text itself is carried over.
            converted.append({"type": "text", "text": as_dict["text"]})
        elif kind == "image":
            # ImageContentBlock: already handled in
            # _get_image_from_data_content_block
            converted.append(as_dict)
        elif kind == "non_standard":
            # NonStandardContentBlock: attempt to preserve content in text form.
            converted.append({"type": "text", "text": str(as_dict.get("value", ""))})
        # ReasoningContentBlock: Ollama doesn't take reasoning back in.
        # TODO: AudioContentBlock / FileContentBlock once models support them.
        # TODO: ToolCall / ToolCallChunk have no mapping yet.

    return converted

View File

@@ -0,0 +1,114 @@
"""Utility function to validate Ollama models."""
from __future__ import annotations
import base64
from urllib.parse import unquote, urlparse
from httpx import ConnectError
from ollama import Client, ResponseError
def validate_model(client: Client, model_name: str) -> None:
    """Check that `model_name` is among the models listed by the Ollama client.

    A model matches when a listed name equals `model_name` exactly or starts
    with `model_name` followed by a colon (i.e. any tag of that model).

    Args:
        client: The Ollama client.
        model_name: The name of the model to validate.

    Raises:
        ValueError: If the model is not found, if connecting to Ollama fails,
            or if the Ollama API responds with an error.
    """
    try:
        listing = client.list()
        available: list[str] = [entry["model"] for entry in listing["models"]]

        tagged_prefix = f"{model_name}:"
        for candidate in available:
            if model_name == candidate or candidate.startswith(tagged_prefix):
                return

        msg = (
            f"Model `{model_name}` not found in Ollama. Please pull the "
            f"model (using `ollama pull {model_name}`) or specify a valid "
            f"model name. Available local models: {', '.join(available)}"
        )
        raise ValueError(msg)
    except ConnectError as e:
        msg = (
            "Failed to connect to Ollama. Please check that Ollama is downloaded, "
            "running and accessible. https://ollama.com/download"
        )
        raise ValueError(msg) from e
    except ResponseError as e:
        msg = (
            "Received an error from the Ollama API. "
            "Please check your Ollama server logs."
        )
        raise ValueError(msg) from e
def parse_url_with_auth(
    url: str | None,
) -> tuple[str | None, dict[str, str] | None]:
    """Parse URL and extract `userinfo` credentials for headers.

    Handles URLs of the form: `https://user:password@host:port/path`

    Args:
        url: The URL to parse.

    Returns:
        A tuple of `(cleaned_url, headers_dict)` where:

        - `cleaned_url` is the URL without authentication credentials if any were
          found. Otherwise, returns the original URL. It is `None` when `url` is
          empty or has no scheme/host.
        - `headers_dict` contains a Basic `Authorization` header if credentials
          were found, otherwise `None`.
    """
    if not url:
        return None, None

    parsed = urlparse(url)
    if not parsed.scheme or not parsed.netloc or not parsed.hostname:
        return None, None
    if not parsed.username:
        return url, None

    # Percent-decode both parts before building the header; a missing password
    # (None or empty) is sent as an empty string.
    username = unquote(parsed.username)
    password = unquote(parsed.password or "")
    encoded_credentials = base64.b64encode(
        f"{username}:{password}".encode()
    ).decode()
    headers = {"Authorization": f"Basic {encoded_credentials}"}

    # Rebuild the netloc without the userinfo part. `hostname` drops the square
    # brackets of IPv6 literals, so they must be restored or the resulting URL
    # (e.g. `http://::1:11434`) would be invalid.
    host = parsed.hostname
    if ":" in host:
        host = f"[{host}]"
    # Compare against None so an explicit port of 0 is not silently dropped.
    cleaned_netloc = host if parsed.port is None else f"{host}:{parsed.port}"

    cleaned_url = f"{parsed.scheme}://{cleaned_netloc}{parsed.path}"
    if parsed.query:
        cleaned_url += f"?{parsed.query}"
    if parsed.fragment:
        cleaned_url += f"#{parsed.fragment}"
    return cleaned_url, headers
def merge_auth_headers(
    client_kwargs: dict,
    auth_headers: dict[str, str] | None,
) -> None:
    """Merge authentication headers into client kwargs in-place.

    `client_kwargs["headers"]` is replaced by a new dict holding the existing
    headers overlaid with `auth_headers`. The previous headers object is never
    mutated, so a headers dict supplied by the caller does not silently gain an
    `Authorization` entry. A missing or `None` `headers` entry is treated as
    empty. Nothing happens when `auth_headers` is empty or `None`.

    Args:
        client_kwargs: The client kwargs dict to update.
        auth_headers: Headers to merge (typically from `parse_url_with_auth`).
    """
    if not auth_headers:
        return

    merged = dict(client_kwargs.get("headers") or {})
    merged.update(auth_headers)
    client_kwargs["headers"] = merged

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,328 @@
"""Ollama embeddings models."""
from __future__ import annotations
from typing import Any
from langchain_core.embeddings import Embeddings
from ollama import AsyncClient, Client
from pydantic import BaseModel, ConfigDict, PrivateAttr, model_validator
from typing_extensions import Self
from ._utils import merge_auth_headers, parse_url_with_auth, validate_model
class OllamaEmbeddings(BaseModel, Embeddings):
    """Ollama embedding model integration.

    Set up a local Ollama instance:
        [Install the Ollama package](https://github.com/ollama/ollama) and set up a
        local Ollama instance.

        You will need to choose a model to serve.

        You can view a list of available models via [the model library](https://ollama.com/library).

        To fetch a model from the Ollama model library use `ollama pull <name-of-model>`.

        For example, to pull the llama3 model:

        ```bash
        ollama pull llama3
        ```

        This will download the default tagged version of the model.
        Typically, the default points to the latest, smallest sized-parameter model.

        * On Mac, the models will be downloaded to `~/.ollama/models`
        * On Linux (or WSL), the models will be stored at `/usr/share/ollama/.ollama/models`

        You can specify the exact version of the model of interest
        as such `ollama pull vicuna:13b-v1.5-16k-q4_0`.

        To view pulled models:

        ```bash
        ollama list
        ```

        To start serving:

        ```bash
        ollama serve
        ```

        View the Ollama documentation for more commands.

        ```bash
        ollama help
        ```

    Install the `langchain-ollama` integration package:
        ```bash
        pip install -U langchain_ollama
        ```

    Key init args — completion params:
        model: str
            Name of Ollama model to use.
        base_url: str | None
            Base url the model is hosted under.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        ```python
        from langchain_ollama import OllamaEmbeddings

        embed = OllamaEmbeddings(model="llama3")
        ```

    Embed single text:
        ```python
        input_text = "The meaning of life is 42"
        vector = embed.embed_query(input_text)
        print(vector[:3])
        ```

        ```python
        [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
        ```

    Embed multiple texts:
        ```python
        input_texts = ["Document 1...", "Document 2..."]
        vectors = embed.embed_documents(input_texts)
        print(len(vectors))
        # The first 3 coordinates for the first vector
        print(vectors[0][:3])
        ```

        ```python
        2
        [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
        ```

    Async:
        ```python
        vector = await embed.aembed_query(input_text)
        print(vector[:3])

        # multiple:
        # await embed.aembed_documents(input_texts)
        ```

        ```python
        [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
        ```
    """  # noqa: E501

    model: str
    """Model name to use."""

    validate_model_on_init: bool = False
    """Whether to validate the model exists in ollama locally on initialization.

    !!! version-added "Added in `langchain-ollama` 0.3.4"
    """

    base_url: str | None = None
    """Base url the model is hosted under.

    If none, defaults to the Ollama client default.

    Supports `userinfo` auth in the format `http://username:password@localhost:11434`.
    Useful if your Ollama server is behind a proxy.

    !!! warning
        `userinfo` is not secure and should only be used for local testing or
        in secure environments. Avoid using it in production or over unsecured
        networks.

    !!! note
        If using `userinfo`, ensure that the Ollama server is configured to
        accept and validate these credentials.

    !!! note
        `userinfo` headers are passed to both sync and async clients.
    """

    client_kwargs: dict | None = {}
    """Additional kwargs to pass to the httpx clients. Pass headers in here.

    These arguments are passed to both synchronous and async clients.

    Use `sync_client_kwargs` and `async_client_kwargs` to pass different arguments
    to synchronous and asynchronous clients.
    """

    async_client_kwargs: dict | None = {}
    """Additional kwargs to merge with `client_kwargs` before passing to httpx client.

    These are kwargs unique to the async client; for shared args use `client_kwargs`.

    For a full list of the params, see the [httpx documentation](https://www.python-httpx.org/api/#asyncclient).
    """

    sync_client_kwargs: dict | None = {}
    """Additional kwargs to merge with `client_kwargs` before passing to httpx client.

    These are kwargs unique to the sync client; for shared args use `client_kwargs`.

    For a full list of the params, see the [httpx documentation](https://www.python-httpx.org/api/#client).
    """

    _client: Client | None = PrivateAttr(default=None)
    """The client to use for making requests."""

    _async_client: AsyncClient | None = PrivateAttr(default=None)
    """The async client to use for making requests."""

    mirostat: int | None = None
    """Enable Mirostat sampling for controlling perplexity.
    (default: `0`, `0` = disabled, `1` = Mirostat, `2` = Mirostat 2.0)"""

    mirostat_eta: float | None = None
    """Influences how quickly the algorithm responds to feedback
    from the generated text. A lower learning rate will result in
    slower adjustments, while a higher learning rate will make
    the algorithm more responsive. (Default: `0.1`)"""

    mirostat_tau: float | None = None
    """Controls the balance between coherence and diversity
    of the output. A lower value will result in more focused and
    coherent text. (Default: `5.0`)"""

    num_ctx: int | None = None
    """Sets the size of the context window used to generate the
    next token. (Default: `2048`)"""

    num_gpu: int | None = None
    """The number of GPUs to use. On macOS it defaults to `1` to
    enable metal support, `0` to disable."""

    keep_alive: int | str | None = None
    """Controls how long the model will stay loaded into memory
    following the request (default: `5m`).

    Accepts a number or a duration string such as `'5m'`; the value is
    forwarded unchanged to the Ollama client.
    """

    num_thread: int | None = None
    """Sets the number of threads to use during computation.
    By default, Ollama will detect this for optimal performance.
    It is recommended to set this value to the number of physical
    CPU cores your system has (as opposed to the logical number of cores)."""

    repeat_last_n: int | None = None
    """Sets how far back for the model to look back to prevent
    repetition. (Default: `64`, `0` = disabled, `-1` = `num_ctx`)"""

    repeat_penalty: float | None = None
    """Sets how strongly to penalize repetitions. A higher value (e.g., `1.5`)
    will penalize repetitions more strongly, while a lower value (e.g., `0.9`)
    will be more lenient. (Default: `1.1`)"""

    temperature: float | None = None
    """The temperature of the model. Increasing the temperature will
    make the model answer more creatively. (Default: `0.8`)"""

    stop: list[str] | None = None
    """Sets the stop tokens to use."""

    tfs_z: float | None = None
    """Tail free sampling is used to reduce the impact of less probable
    tokens from the output. A higher value (e.g., `2.0`) will reduce the
    impact more, while a value of `1.0` disables this setting. (default: `1`)"""

    top_k: int | None = None
    """Reduces the probability of generating nonsense. A higher value (e.g. `100`)
    will give more diverse answers, while a lower value (e.g. `10`)
    will be more conservative. (Default: `40`)"""

    top_p: float | None = None
    """Works together with top-k. A higher value (e.g., `0.95`) will lead
    to more diverse text, while a lower value (e.g., `0.5`) will
    generate more focused and conservative text. (Default: `0.9`)"""

    model_config = ConfigDict(
        extra="forbid",
    )

    @property
    def _default_params(self) -> dict[str, Any]:
        """Get the default parameters for calling Ollama.

        Returns:
            The option fields of this model as a dict; unset options are
            included with the value `None`.
        """
        return {
            "mirostat": self.mirostat,
            "mirostat_eta": self.mirostat_eta,
            "mirostat_tau": self.mirostat_tau,
            "num_ctx": self.num_ctx,
            "num_gpu": self.num_gpu,
            "num_thread": self.num_thread,
            "repeat_last_n": self.repeat_last_n,
            "repeat_penalty": self.repeat_penalty,
            "temperature": self.temperature,
            "stop": self.stop,
            "tfs_z": self.tfs_z,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }

    @model_validator(mode="after")
    def _set_clients(self) -> Self:
        """Set clients to use for Ollama.

        Builds the sync and async clients from `client_kwargs` overlaid with
        `sync_client_kwargs` / `async_client_kwargs`, adds the `Authorization`
        header derived from `base_url` credentials (if any), and optionally
        validates that the model exists.

        Raises:
            ValueError: If `validate_model_on_init` is set and validation fails.
        """
        # Work on copies: the auth header parsed from `base_url` must not be
        # written back into the public `client_kwargs` field or into a headers
        # dict owned by the caller.
        shared_kwargs = dict(self.client_kwargs or {})
        if shared_kwargs.get("headers"):
            shared_kwargs["headers"] = dict(shared_kwargs["headers"])

        cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
        merge_auth_headers(shared_kwargs, auth_headers)

        sync_client_kwargs = {**shared_kwargs, **(self.sync_client_kwargs or {})}
        async_client_kwargs = {**shared_kwargs, **(self.async_client_kwargs or {})}

        self._client = Client(host=cleaned_url, **sync_client_kwargs)
        self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
        if self.validate_model_on_init:
            validate_model(self._client, self.model)
        return self

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs.

        Args:
            texts: The texts to embed.

        Returns:
            The `embeddings` entry of the Ollama response.

        Raises:
            ValueError: If the sync client has not been initialized.
        """
        if not self._client:
            msg = (
                "Ollama client is not initialized. "
                "Please ensure Ollama is running and the model is loaded."
            )
            raise ValueError(msg)
        return self._client.embed(
            self.model, texts, options=self._default_params, keep_alive=self.keep_alive
        )["embeddings"]

    def embed_query(self, text: str) -> list[float]:
        """Embed query text.

        Args:
            text: The text to embed.

        Returns:
            The embedding of `text`.
        """
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs asynchronously.

        Args:
            texts: The texts to embed.

        Returns:
            The `embeddings` entry of the Ollama response.

        Raises:
            ValueError: If the async client has not been initialized.
        """
        if not self._async_client:
            msg = (
                "Ollama client is not initialized. "
                "Please ensure Ollama is running and the model is loaded."
            )
            raise ValueError(msg)
        return (
            await self._async_client.embed(
                self.model,
                texts,
                options=self._default_params,
                keep_alive=self.keep_alive,
            )
        )["embeddings"]

    async def aembed_query(self, text: str) -> list[float]:
        """Embed query text asynchronously.

        Args:
            text: The text to embed.

        Returns:
            The embedding of `text`.
        """
        return (await self.aembed_documents([text]))[0]

View File

@@ -0,0 +1,545 @@
"""Ollama large language models."""
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator, Mapping
from typing import Any, Literal
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM, LangSmithParams
from langchain_core.outputs import GenerationChunk, LLMResult
from ollama import AsyncClient, Client, Options
from pydantic import PrivateAttr, model_validator
from typing_extensions import Self
from ._utils import merge_auth_headers, parse_url_with_auth, validate_model
class OllamaLLM(BaseLLM):
    """Ollama large language models.

    Setup:
        Install `langchain-ollama` and install/run the Ollama server locally:

        ```bash
        pip install -U langchain-ollama
        # Visit https://ollama.com/download to download and install Ollama
        # (Linux users): start the server with `ollama serve`
        ```

        Download a model to use:

        ```bash
        ollama pull llama3.1
        ```

    Key init args — generation params:
        model: str
            Name of the Ollama model to use (e.g. `'llama4'`).
        temperature: float | None
            Sampling temperature. Higher values make output more creative.
        num_predict: int | None
            Maximum number of tokens to predict.
        top_k: int | None
            Limits the next token selection to the K most probable tokens.
        top_p: float | None
            Nucleus sampling parameter. Higher values lead to more diverse text.
        mirostat: int | None
            Enable Mirostat sampling for controlling perplexity.
        seed: int | None
            Random number seed for generation reproducibility.

    Key init args — client params:
        base_url:
            Base URL where Ollama server is hosted.
        keep_alive:
            How long the model stays loaded into memory.
        format:
            Specify the format of the output.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        ```python
        from langchain_ollama import OllamaLLM

        model = OllamaLLM(
            model="llama3.1",
            temperature=0.7,
            num_predict=256,
            # base_url="http://localhost:11434",
            # other params...
        )
        ```

    Invoke:
        ```python
        input_text = "The meaning of life is "
        response = model.invoke(input_text)
        print(response)
        ```

        ```txt
        "a philosophical question that has been contemplated by humans for
        centuries..."
        ```

    Stream:
        ```python
        for chunk in model.stream(input_text):
            print(chunk, end="")
        ```

        ```txt
        a philosophical question that has been contemplated by humans for
        centuries...
        ```

    Async:
        ```python
        response = await model.ainvoke(input_text)

        # stream:
        # async for chunk in model.astream(input_text):
        #     print(chunk, end="")
        ```
    """

    model: str
    """Model name to use."""

    reasoning: bool | None = None
    """Controls the reasoning/thinking mode for
    [supported models](https://ollama.com/search?c=thinking).

    - `True`: Enables reasoning mode. The model's reasoning process will be
        captured and returned separately in the `generation_info` of the
        result: under `reasoning_content` on streamed chunks and under
        `thinking` on aggregated (non-streaming) generations. The main
        response content will not include the reasoning tags.
    - `False`: Disables reasoning mode. The model will not perform any reasoning,
        and the response will not include any reasoning content.
    - `None` (Default): The model will use its default reasoning behavior. If
        the model performs reasoning, the `<think>` and `</think>` tags will
        be present directly within the main response content."""

    validate_model_on_init: bool = False
    """Whether to validate the model exists in ollama locally on initialization.

    !!! version-added "Added in `langchain-ollama` 0.3.4"
    """

    mirostat: int | None = None
    """Enable Mirostat sampling for controlling perplexity.
    (default: `0`, `0` = disabled, `1` = Mirostat, `2` = Mirostat 2.0)"""

    mirostat_eta: float | None = None
    """Influences how quickly the algorithm responds to feedback
    from the generated text. A lower learning rate will result in
    slower adjustments, while a higher learning rate will make
    the algorithm more responsive. (Default: `0.1`)"""

    mirostat_tau: float | None = None
    """Controls the balance between coherence and diversity
    of the output. A lower value will result in more focused and
    coherent text. (Default: `5.0`)"""

    num_ctx: int | None = None
    """Sets the size of the context window used to generate the
    next token. (Default: `2048`)"""

    num_gpu: int | None = None
    """The number of GPUs to use. On macOS it defaults to `1` to
    enable metal support, `0` to disable."""

    num_thread: int | None = None
    """Sets the number of threads to use during computation.
    By default, Ollama will detect this for optimal performance.
    It is recommended to set this value to the number of physical
    CPU cores your system has (as opposed to the logical number of cores)."""

    num_predict: int | None = None
    """Maximum number of tokens to predict when generating text.
    (Default: `128`, `-1` = infinite generation, `-2` = fill context)"""

    repeat_last_n: int | None = None
    """Sets how far back for the model to look back to prevent
    repetition. (Default: `64`, `0` = disabled, `-1` = `num_ctx`)"""

    repeat_penalty: float | None = None
    """Sets how strongly to penalize repetitions. A higher value (e.g., `1.5`)
    will penalize repetitions more strongly, while a lower value (e.g., `0.9`)
    will be more lenient. (Default: `1.1`)"""

    temperature: float | None = None
    """The temperature of the model. Increasing the temperature will
    make the model answer more creatively. (Default: `0.8`)"""

    seed: int | None = None
    """Sets the random number seed to use for generation. Setting this
    to a specific number will make the model generate the same text for
    the same prompt."""

    stop: list[str] | None = None
    """Sets the stop tokens to use."""

    tfs_z: float | None = None
    """Tail free sampling is used to reduce the impact of less probable
    tokens from the output. A higher value (e.g., `2.0`) will reduce the
    impact more, while a value of 1.0 disables this setting. (default: `1`)"""

    top_k: int | None = None
    """Reduces the probability of generating nonsense. A higher value (e.g. `100`)
    will give more diverse answers, while a lower value (e.g. `10`)
    will be more conservative. (Default: `40`)"""

    top_p: float | None = None
    """Works together with top-k. A higher value (e.g., `0.95`) will lead
    to more diverse text, while a lower value (e.g., `0.5`) will
    generate more focused and conservative text. (Default: `0.9`)"""

    format: Literal["", "json"] = ""
    """Specify the format of the output (options: `'json'`)"""

    keep_alive: int | str | None = None
    """How long the model will stay loaded into memory."""

    base_url: str | None = None
    """Base url the model is hosted under.

    If none, defaults to the Ollama client default.

    Supports `userinfo` auth in the format `http://username:password@localhost:11434`.
    Useful if your Ollama server is behind a proxy.

    !!! warning
        `userinfo` is not secure and should only be used for local testing or
        in secure environments. Avoid using it in production or over unsecured
        networks.

    !!! note
        If using `userinfo`, ensure that the Ollama server is configured to
        accept and validate these credentials.

    !!! note
        `userinfo` headers are passed to both sync and async clients.
    """

    client_kwargs: dict | None = {}
    """Additional kwargs to pass to the httpx clients. Pass headers in here.

    These arguments are passed to both synchronous and async clients.

    Use `sync_client_kwargs` and `async_client_kwargs` to pass different arguments
    to synchronous and asynchronous clients.
    """

    async_client_kwargs: dict | None = {}
    """Additional kwargs to merge with `client_kwargs` before passing to httpx client.

    These are kwargs unique to the async client; for shared args use `client_kwargs`.

    For a full list of the params, see the [httpx documentation](https://www.python-httpx.org/api/#asyncclient).
    """

    sync_client_kwargs: dict | None = {}
    """Additional kwargs to merge with `client_kwargs` before passing to httpx client.

    These are kwargs unique to the sync client; for shared args use `client_kwargs`.

    For a full list of the params, see the [httpx documentation](https://www.python-httpx.org/api/#client).
    """

    _client: Client | None = PrivateAttr(default=None)
    """The client to use for making requests."""

    _async_client: AsyncClient | None = PrivateAttr(default=None)
    """The async client to use for making requests."""

    def _generate_params(
        self,
        prompt: str,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments for the Ollama `generate` call.

        Per-call `kwargs` override the instance defaults for `stream`, `model`,
        `reasoning` (sent as `think`), `format`, `keep_alive` and `options`.
        When `options` is supplied it replaces the instance-level options
        entirely, including `stop`. Remaining `kwargs` are passed through.

        Args:
            prompt: The prompt to generate from.
            stop: Stop tokens for this call.
            **kwargs: Per-call overrides and extra arguments.

        Returns:
            The keyword arguments for `generate`.

        Raises:
            ValueError: If `stop` is set both on the instance and in the call.
        """
        if self.stop is not None and stop is not None:
            msg = "`stop` found in both the input and default params."
            raise ValueError(msg)
        if self.stop is not None:
            stop = self.stop

        options_dict = kwargs.pop(
            "options",
            {
                "mirostat": self.mirostat,
                "mirostat_eta": self.mirostat_eta,
                "mirostat_tau": self.mirostat_tau,
                "num_ctx": self.num_ctx,
                "num_gpu": self.num_gpu,
                "num_thread": self.num_thread,
                "num_predict": self.num_predict,
                "repeat_last_n": self.repeat_last_n,
                "repeat_penalty": self.repeat_penalty,
                "temperature": self.temperature,
                "seed": self.seed,
                # `stop` was resolved above (instance value or call value).
                "stop": stop,
                "tfs_z": self.tfs_z,
                "top_k": self.top_k,
                "top_p": self.top_p,
            },
        )

        return {
            "prompt": prompt,
            "stream": kwargs.pop("stream", True),
            "model": kwargs.pop("model", self.model),
            "think": kwargs.pop("reasoning", self.reasoning),
            "format": kwargs.pop("format", self.format),
            "options": Options(**options_dict),
            "keep_alive": kwargs.pop("keep_alive", self.keep_alive),
            **kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "ollama-llm"

    def _get_ls_params(
        self, stop: list[str] | None = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing.

        Adds `ls_max_tokens` from `num_predict` (call value first, then the
        instance value) when it is truthy.
        """
        params = super()._get_ls_params(stop=stop, **kwargs)
        if max_tokens := kwargs.get("num_predict", self.num_predict):
            params["ls_max_tokens"] = max_tokens
        return params

    @model_validator(mode="after")
    def _set_clients(self) -> Self:
        """Set clients to use for ollama.

        Builds the sync and async clients from `client_kwargs` overlaid with
        `sync_client_kwargs` / `async_client_kwargs`, adds the `Authorization`
        header derived from `base_url` credentials (if any), and optionally
        validates that the model exists.

        Raises:
            ValueError: If `validate_model_on_init` is set and validation fails.
        """
        # Work on copies: the auth header parsed from `base_url` must not be
        # written back into the public `client_kwargs` field or into a headers
        # dict owned by the caller.
        shared_kwargs = dict(self.client_kwargs or {})
        if shared_kwargs.get("headers"):
            shared_kwargs["headers"] = dict(shared_kwargs["headers"])

        cleaned_url, auth_headers = parse_url_with_auth(self.base_url)
        merge_auth_headers(shared_kwargs, auth_headers)

        sync_client_kwargs = {**shared_kwargs, **(self.sync_client_kwargs or {})}
        async_client_kwargs = {**shared_kwargs, **(self.async_client_kwargs or {})}

        self._client = Client(host=cleaned_url, **sync_client_kwargs)
        self._async_client = AsyncClient(host=cleaned_url, **async_client_kwargs)
        if self.validate_model_on_init:
            validate_model(self._client, self.model)
        return self

    async def _acreate_generate_stream(
        self,
        prompt: str,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[Mapping[str, Any] | str]:
        """Yield the parts of an async `generate` call.

        Yields nothing when the async client is not set.
        """
        if self._async_client:
            async for part in await self._async_client.generate(
                **self._generate_params(prompt, stop=stop, **kwargs)
            ):
                yield part

    def _create_generate_stream(
        self,
        prompt: str,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> Iterator[Mapping[str, Any] | str]:
        """Yield the parts of a sync `generate` call.

        Yields nothing when the sync client is not set.
        """
        if self._client:
            yield from self._client.generate(
                **self._generate_params(prompt, stop=stop, **kwargs)
            )

    async def _astream_with_aggregation(
        self,
        prompt: str,
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        verbose: bool = False,  # noqa: FBT002
        **kwargs: Any,
    ) -> GenerationChunk:
        """Consume the async stream and merge it into one `GenerationChunk`.

        Accumulated `thinking` text is stored in `generation_info["thinking"]`.

        Raises:
            ValueError: If the stream produced no (non-string) parts.
        """
        final_chunk = None
        thinking_content = ""
        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
            if not isinstance(stream_resp, str):
                if stream_resp.get("thinking"):
                    thinking_content += stream_resp["thinking"]
                chunk = GenerationChunk(
                    text=stream_resp.get("response", ""),
                    # Only the final ("done") part carries generation info.
                    generation_info=(
                        dict(stream_resp) if stream_resp.get("done") is True else None
                    ),
                )
                if final_chunk is None:
                    final_chunk = chunk
                else:
                    final_chunk += chunk
                if run_manager:
                    await run_manager.on_llm_new_token(
                        chunk.text,
                        chunk=chunk,
                        verbose=verbose,
                    )
        if final_chunk is None:
            msg = "No data received from Ollama stream."
            raise ValueError(msg)

        if thinking_content:
            if final_chunk.generation_info:
                final_chunk.generation_info["thinking"] = thinking_content
            else:
                final_chunk.generation_info = {"thinking": thinking_content}

        return final_chunk

    def _stream_with_aggregation(
        self,
        prompt: str,
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        verbose: bool = False,  # noqa: FBT002
        **kwargs: Any,
    ) -> GenerationChunk:
        """Consume the sync stream and merge it into one `GenerationChunk`.

        Accumulated `thinking` text is stored in `generation_info["thinking"]`.

        Raises:
            ValueError: If the stream produced no (non-string) parts.
        """
        final_chunk = None
        thinking_content = ""
        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
            if not isinstance(stream_resp, str):
                if stream_resp.get("thinking"):
                    thinking_content += stream_resp["thinking"]
                chunk = GenerationChunk(
                    text=stream_resp.get("response", ""),
                    # Only the final ("done") part carries generation info.
                    generation_info=(
                        dict(stream_resp) if stream_resp.get("done") is True else None
                    ),
                )
                if final_chunk is None:
                    final_chunk = chunk
                else:
                    final_chunk += chunk
                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        chunk=chunk,
                        verbose=verbose,
                    )
        if final_chunk is None:
            msg = "No data received from Ollama stream."
            raise ValueError(msg)

        if thinking_content:
            if final_chunk.generation_info:
                final_chunk.generation_info["thinking"] = thinking_content
            else:
                final_chunk.generation_info = {"thinking": thinking_content}

        return final_chunk

    def _generate(
        self,
        prompts: list[str],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate one aggregated completion per prompt, in order."""
        generations = []
        for prompt in prompts:
            final_chunk = self._stream_with_aggregation(
                prompt,
                stop=stop,
                run_manager=run_manager,
                verbose=self.verbose,
                **kwargs,
            )
            generations.append([final_chunk])
        return LLMResult(generations=generations)  # type: ignore[arg-type]

    async def _agenerate(
        self,
        prompts: list[str],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Asynchronously generate one aggregated completion per prompt, in order."""
        generations = []
        for prompt in prompts:
            final_chunk = await self._astream_with_aggregation(
                prompt,
                stop=stop,
                run_manager=run_manager,
                verbose=self.verbose,
                **kwargs,
            )
            generations.append([final_chunk])
        return LLMResult(generations=generations)  # type: ignore[arg-type]

    def _stream(
        self,
        prompt: str,
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream generation chunks for a single prompt.

        When reasoning is enabled (call kwarg or instance setting), `thinking`
        text of a part is exposed as `generation_info["reasoning_content"]`.
        """
        reasoning = kwargs.get("reasoning", self.reasoning)
        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
            if not isinstance(stream_resp, str):
                additional_kwargs = {}
                if reasoning and (thinking_content := stream_resp.get("thinking")):
                    additional_kwargs["reasoning_content"] = thinking_content

                chunk = GenerationChunk(
                    text=(stream_resp.get("response", "")),
                    generation_info={
                        # NOTE(review): this reports the configured stop tokens
                        # as `finish_reason`; kept as-is, confirm it is intended.
                        "finish_reason": self.stop,
                        **additional_kwargs,
                        **(
                            dict(stream_resp) if stream_resp.get("done") is True else {}
                        ),
                    },
                )
                if run_manager:
                    # Pass the chunk along, as the aggregation paths do.
                    run_manager.on_llm_new_token(
                        chunk.text,
                        chunk=chunk,
                        verbose=self.verbose,
                    )
                yield chunk

    async def _astream(
        self,
        prompt: str,
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Asynchronously stream generation chunks for a single prompt.

        When reasoning is enabled (call kwarg or instance setting), `thinking`
        text of a part is exposed as `generation_info["reasoning_content"]`.
        """
        reasoning = kwargs.get("reasoning", self.reasoning)
        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
            if not isinstance(stream_resp, str):
                additional_kwargs = {}
                if reasoning and (thinking_content := stream_resp.get("thinking")):
                    additional_kwargs["reasoning_content"] = thinking_content

                chunk = GenerationChunk(
                    text=(stream_resp.get("response", "")),
                    generation_info={
                        # NOTE(review): this reports the configured stop tokens
                        # as `finish_reason`; kept as-is, confirm it is intended.
                        "finish_reason": self.stop,
                        **additional_kwargs,
                        **(
                            dict(stream_resp) if stream_resp.get("done") is True else {}
                        ),
                    },
                )
                if run_manager:
                    # Pass the chunk along, as the aggregation paths do.
                    await run_manager.on_llm_new_token(
                        chunk.text,
                        chunk=chunk,
                        verbose=self.verbose,
                    )
                yield chunk