initial commit
This commit is contained in:
319
venv/Lib/site-packages/langchain_community/llms/llamafile.py
Normal file
319
venv/Lib/site-packages/langchain_community/llms/llamafile.py
Normal file
@@ -0,0 +1,319 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.outputs import GenerationChunk
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class Llamafile(LLM):
|
||||
"""Llamafile lets you distribute and run large language models with a
|
||||
single file.
|
||||
|
||||
To get started, see: https://github.com/Mozilla-Ocho/llamafile
|
||||
|
||||
To use this class, you will need to first:
|
||||
|
||||
1. Download a llamafile.
|
||||
2. Make the downloaded file executable: `chmod +x path/to/model.llamafile`
|
||||
3. Start the llamafile in server mode:
|
||||
|
||||
`./path/to/model.llamafile --server --nobrowser`
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import Llamafile
|
||||
llm = Llamafile()
|
||||
llm.invoke("Tell me a joke.")
|
||||
"""
|
||||
|
||||
base_url: str = "http://localhost:8080"
|
||||
"""Base url where the llamafile server is listening."""
|
||||
|
||||
request_timeout: Optional[int] = None
|
||||
"""Timeout for server requests"""
|
||||
|
||||
streaming: bool = False
|
||||
"""Allows receiving each predicted token in real-time instead of
|
||||
waiting for the completion to finish. To enable this, set to true."""
|
||||
|
||||
# Generation options
|
||||
|
||||
seed: int = -1
|
||||
"""Random Number Generator (RNG) seed. A random seed is used if this is
|
||||
less than zero. Default: -1"""
|
||||
|
||||
temperature: float = 0.8
|
||||
"""Temperature. Default: 0.8"""
|
||||
|
||||
top_k: int = 40
|
||||
"""Limit the next token selection to the K most probable tokens.
|
||||
Default: 40."""
|
||||
|
||||
top_p: float = 0.95
|
||||
"""Limit the next token selection to a subset of tokens with a cumulative
|
||||
probability above a threshold P. Default: 0.95."""
|
||||
|
||||
min_p: float = 0.05
|
||||
"""The minimum probability for a token to be considered, relative to
|
||||
the probability of the most likely token. Default: 0.05."""
|
||||
|
||||
n_predict: int = -1
|
||||
"""Set the maximum number of tokens to predict when generating text.
|
||||
Note: May exceed the set limit slightly if the last token is a partial
|
||||
multibyte character. When 0, no tokens will be generated but the prompt
|
||||
is evaluated into the cache. Default: -1 = infinity."""
|
||||
|
||||
n_keep: int = 0
|
||||
"""Specify the number of tokens from the prompt to retain when the
|
||||
context size is exceeded and tokens need to be discarded. By default,
|
||||
this value is set to 0 (meaning no tokens are kept). Use -1 to retain all
|
||||
tokens from the prompt."""
|
||||
|
||||
tfs_z: float = 1.0
|
||||
"""Enable tail free sampling with parameter z. Default: 1.0 = disabled."""
|
||||
|
||||
typical_p: float = 1.0
|
||||
"""Enable locally typical sampling with parameter p.
|
||||
Default: 1.0 = disabled."""
|
||||
|
||||
repeat_penalty: float = 1.1
|
||||
"""Control the repetition of token sequences in the generated text.
|
||||
Default: 1.1"""
|
||||
|
||||
repeat_last_n: int = 64
|
||||
"""Last n tokens to consider for penalizing repetition. Default: 64,
|
||||
0 = disabled, -1 = ctx-size."""
|
||||
|
||||
penalize_nl: bool = True
|
||||
"""Penalize newline tokens when applying the repeat penalty.
|
||||
Default: true."""
|
||||
|
||||
presence_penalty: float = 0.0
|
||||
"""Repeat alpha presence penalty. Default: 0.0 = disabled."""
|
||||
|
||||
frequency_penalty: float = 0.0
|
||||
"""Repeat alpha frequency penalty. Default: 0.0 = disabled"""
|
||||
|
||||
mirostat: int = 0
|
||||
"""Enable Mirostat sampling, controlling perplexity during text
|
||||
generation. 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0.
|
||||
Default: disabled."""
|
||||
|
||||
mirostat_tau: float = 5.0
|
||||
"""Set the Mirostat target entropy, parameter tau. Default: 5.0."""
|
||||
|
||||
mirostat_eta: float = 0.1
|
||||
"""Set the Mirostat learning rate, parameter eta. Default: 0.1."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "llamafile"
|
||||
|
||||
@property
|
||||
def _param_fieldnames(self) -> List[str]:
|
||||
# Return the list of fieldnames that will be passed as configurable
|
||||
# generation options to the llamafile server. Exclude 'builtin' fields
|
||||
# from the BaseLLM class like 'metadata' as well as fields that should
|
||||
# not be passed in requests (base_url, request_timeout).
|
||||
ignore_keys = [
|
||||
"base_url",
|
||||
"cache",
|
||||
"callback_manager",
|
||||
"callbacks",
|
||||
"metadata",
|
||||
"name",
|
||||
"request_timeout",
|
||||
"streaming",
|
||||
"tags",
|
||||
"verbose",
|
||||
"custom_get_token_ids",
|
||||
]
|
||||
attrs = [
|
||||
k for k in get_pydantic_field_names(self.__class__) if k not in ignore_keys
|
||||
]
|
||||
return attrs
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
params = {}
|
||||
for fieldname in self._param_fieldnames:
|
||||
params[fieldname] = getattr(self, fieldname)
|
||||
return params
|
||||
|
||||
def _get_parameters(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
params = self._default_params
|
||||
|
||||
# Only update keys that are already present in the default config.
|
||||
# This way, we don't accidentally post unknown/unhandled key/values
|
||||
# in the request to the llamafile server
|
||||
for k, v in kwargs.items():
|
||||
if k in params:
|
||||
params[k] = v
|
||||
|
||||
if stop is not None and len(stop) > 0:
|
||||
params["stop"] = stop
|
||||
|
||||
if self.streaming:
|
||||
params["stream"] = True
|
||||
|
||||
return params
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Request prompt completion from the llamafile server and return the
|
||||
output.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
stop: A list of strings to stop generation when encountered.
|
||||
run_manager:
|
||||
**kwargs: Any additional options to pass as part of the
|
||||
generation request.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
"""
|
||||
|
||||
if self.streaming:
|
||||
with StringIO() as buff:
|
||||
for chunk in self._stream(
|
||||
prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
buff.write(chunk.text)
|
||||
|
||||
text = buff.getvalue()
|
||||
|
||||
return text
|
||||
|
||||
else:
|
||||
params = self._get_parameters(stop=stop, **kwargs)
|
||||
payload = {"prompt": prompt, **params}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url=f"{self.base_url}/completion",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=payload,
|
||||
stream=False,
|
||||
timeout=self.request_timeout,
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise requests.exceptions.ConnectionError(
|
||||
f"Could not connect to Llamafile server. Please make sure "
|
||||
f"that a server is running at {self.base_url}."
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
response.encoding = "utf-8"
|
||||
|
||||
text = response.json()["content"]
|
||||
|
||||
return text
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""Yields results objects as they are generated in real time.
|
||||
|
||||
It also calls the callback manager's on_llm_new_token event with
|
||||
similar parameters to the OpenAI LLM class method of the same name.
|
||||
|
||||
Args:
|
||||
prompt: The prompts to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
run_manager:
|
||||
**kwargs: Any additional options to pass as part of the
|
||||
generation request.
|
||||
|
||||
Returns:
|
||||
A generator representing the stream of tokens being generated.
|
||||
|
||||
Yields:
|
||||
Dictionary-like objects each containing a token
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.llms import Llamafile
|
||||
llm = Llamafile(
|
||||
temperature = 0.0
|
||||
)
|
||||
for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
|
||||
stop=["'","\n"]):
|
||||
result = chunk["choices"][0]
|
||||
print(result["text"], end='', flush=True)
|
||||
|
||||
"""
|
||||
params = self._get_parameters(stop=stop, **kwargs)
|
||||
if "stream" not in params:
|
||||
params["stream"] = True
|
||||
|
||||
payload = {"prompt": prompt, **params}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url=f"{self.base_url}/completion",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=self.request_timeout,
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise requests.exceptions.ConnectionError(
|
||||
f"Could not connect to Llamafile server. Please make sure "
|
||||
f"that a server is running at {self.base_url}."
|
||||
)
|
||||
|
||||
response.encoding = "utf8"
|
||||
|
||||
for raw_chunk in response.iter_lines(decode_unicode=True):
|
||||
content = self._get_chunk_content(raw_chunk)
|
||||
chunk = GenerationChunk(text=content)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token=chunk.text)
|
||||
yield chunk
|
||||
|
||||
def _get_chunk_content(self, chunk: str) -> str:
|
||||
"""When streaming is turned on, llamafile server returns lines like:
|
||||
|
||||
'data: {"content":" They","multimodal":true,"slot_id":0,"stop":false}'
|
||||
|
||||
Here, we convert this to a dict and return the value of the 'content'
|
||||
field
|
||||
"""
|
||||
|
||||
if chunk.startswith("data:"):
|
||||
cleaned = chunk.lstrip("data: ")
|
||||
data = json.loads(cleaned)
|
||||
return data["content"]
|
||||
else:
|
||||
return chunk
|
||||
Reference in New Issue
Block a user