initial commit

This commit is contained in:
2026-05-11 12:36:20 +05:30
commit 384cbe8019
15377 changed files with 2360544 additions and 0 deletions

View File

@@ -0,0 +1,323 @@
"""**Utilities** are the integrations with third-party systems and packages.
Other LangChain classes use **Utilities** to interact with third-party systems
and packages.
"""
import importlib
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.utilities.alpha_vantage import (
AlphaVantageAPIWrapper,
)
from langchain_community.utilities.apify import (
ApifyWrapper,
)
from langchain_community.utilities.arcee import (
ArceeWrapper,
)
from langchain_community.utilities.arxiv import (
ArxivAPIWrapper,
)
from langchain_community.utilities.asknews import (
AskNewsAPIWrapper,
)
from langchain_community.utilities.awslambda import (
LambdaWrapper,
)
from langchain_community.utilities.bibtex import (
BibtexparserWrapper,
)
from langchain_community.utilities.bing_search import (
BingSearchAPIWrapper,
)
from langchain_community.utilities.brave_search import (
BraveSearchWrapper,
)
from langchain_community.utilities.dataherald import DataheraldAPIWrapper
from langchain_community.utilities.dria_index import (
DriaAPIWrapper,
)
from langchain_community.utilities.duckduckgo_search import (
DuckDuckGoSearchAPIWrapper,
)
from langchain_community.utilities.golden_query import (
GoldenQueryAPIWrapper,
)
from langchain_community.utilities.google_books import (
GoogleBooksAPIWrapper,
)
from langchain_community.utilities.google_finance import (
GoogleFinanceAPIWrapper,
)
from langchain_community.utilities.google_jobs import (
GoogleJobsAPIWrapper,
)
from langchain_community.utilities.google_lens import (
GoogleLensAPIWrapper,
)
from langchain_community.utilities.google_places_api import (
GooglePlacesAPIWrapper,
)
from langchain_community.utilities.google_scholar import (
GoogleScholarAPIWrapper,
)
from langchain_community.utilities.google_search import (
GoogleSearchAPIWrapper,
)
from langchain_community.utilities.google_serper import (
GoogleSerperAPIWrapper,
)
from langchain_community.utilities.google_trends import (
GoogleTrendsAPIWrapper,
)
from langchain_community.utilities.graphql import (
GraphQLAPIWrapper,
)
from langchain_community.utilities.infobip import (
InfobipAPIWrapper,
)
from langchain_community.utilities.jira import (
JiraAPIWrapper,
)
from langchain_community.utilities.max_compute import (
MaxComputeAPIWrapper,
)
from langchain_community.utilities.merriam_webster import (
MerriamWebsterAPIWrapper,
)
from langchain_community.utilities.metaphor_search import (
MetaphorSearchAPIWrapper,
)
from langchain_community.utilities.mojeek_search import (
MojeekSearchAPIWrapper,
)
from langchain_community.utilities.nasa import (
NasaAPIWrapper,
)
from langchain_community.utilities.nvidia_riva import (
AudioStream,
NVIDIARivaASR,
NVIDIARivaStream,
NVIDIARivaTTS,
RivaASR,
RivaTTS,
)
from langchain_community.utilities.openweathermap import (
OpenWeatherMapAPIWrapper,
)
from langchain_community.utilities.oracleai import (
OracleSummary,
)
from langchain_community.utilities.outline import (
OutlineAPIWrapper,
)
from langchain_community.utilities.passio_nutrition_ai import (
NutritionAIAPI,
)
from langchain_community.utilities.portkey import (
Portkey,
)
from langchain_community.utilities.powerbi import (
PowerBIDataset,
)
from langchain_community.utilities.pubmed import (
PubMedAPIWrapper,
)
from langchain_community.utilities.rememberizer import RememberizerAPIWrapper
from langchain_community.utilities.requests import (
Requests,
RequestsWrapper,
TextRequestsWrapper,
)
from langchain_community.utilities.scenexplain import (
SceneXplainAPIWrapper,
)
from langchain_community.utilities.searchapi import (
SearchApiAPIWrapper,
)
from langchain_community.utilities.searx_search import (
SearxSearchWrapper,
)
from langchain_community.utilities.serpapi import (
SerpAPIWrapper,
)
from langchain_community.utilities.spark_sql import (
SparkSQL,
)
from langchain_community.utilities.sql_database import (
SQLDatabase,
)
from langchain_community.utilities.stackexchange import (
StackExchangeAPIWrapper,
)
from langchain_community.utilities.steam import (
SteamWebAPIWrapper,
)
from langchain_community.utilities.tensorflow_datasets import (
TensorflowDatasets,
)
from langchain_community.utilities.twilio import (
TwilioAPIWrapper,
)
from langchain_community.utilities.wikipedia import (
WikipediaAPIWrapper,
)
from langchain_community.utilities.wolfram_alpha import (
WolframAlphaAPIWrapper,
)
from langchain_community.utilities.you import (
YouSearchAPIWrapper,
)
from langchain_community.utilities.zapier import (
ZapierNLAWrapper,
)
# Public names of this package. Every entry has a matching key in
# `_module_lookup`, which the module-level `__getattr__` uses to import the
# defining submodule on first access.
__all__ = [
"AlphaVantageAPIWrapper",
"ApifyWrapper",
"ArceeWrapper",
"ArxivAPIWrapper",
"AskNewsAPIWrapper",
"AudioStream",
"BibtexparserWrapper",
"BingSearchAPIWrapper",
"BraveSearchWrapper",
"DataheraldAPIWrapper",
"DriaAPIWrapper",
"DuckDuckGoSearchAPIWrapper",
"GoldenQueryAPIWrapper",
"GoogleBooksAPIWrapper",
"GoogleFinanceAPIWrapper",
"GoogleJobsAPIWrapper",
"GoogleLensAPIWrapper",
"GooglePlacesAPIWrapper",
"GoogleScholarAPIWrapper",
"GoogleSearchAPIWrapper",
"GoogleSerperAPIWrapper",
"GoogleTrendsAPIWrapper",
"GraphQLAPIWrapper",
"InfobipAPIWrapper",
"JiraAPIWrapper",
"LambdaWrapper",
"MaxComputeAPIWrapper",
"MerriamWebsterAPIWrapper",
"MetaphorSearchAPIWrapper",
"MojeekSearchAPIWrapper",
"NVIDIARivaASR",
"NVIDIARivaStream",
"NVIDIARivaTTS",
"NasaAPIWrapper",
"NutritionAIAPI",
"OpenWeatherMapAPIWrapper",
"OracleSummary",
"OutlineAPIWrapper",
"Portkey",
"PowerBIDataset",
"PubMedAPIWrapper",
"RememberizerAPIWrapper",
"Requests",
"RequestsWrapper",
"RivaASR",
"RivaTTS",
"SceneXplainAPIWrapper",
"SearchApiAPIWrapper",
"SQLDatabase",
"SearxSearchWrapper",
"SerpAPIWrapper",
"SparkSQL",
"StackExchangeAPIWrapper",
"SteamWebAPIWrapper",
"TensorflowDatasets",
"TextRequestsWrapper",
"TwilioAPIWrapper",
"WikipediaAPIWrapper",
"WolframAlphaAPIWrapper",
"YouSearchAPIWrapper",
"ZapierNLAWrapper",
]
# Maps each lazily exported name to the dotted path of the submodule that
# defines it. The module-level `__getattr__` imports that submodule and reads
# the attribute of the same name from it.
_module_lookup = {
"AlphaVantageAPIWrapper": "langchain_community.utilities.alpha_vantage",
"ApifyWrapper": "langchain_community.utilities.apify",
"ArceeWrapper": "langchain_community.utilities.arcee",
"ArxivAPIWrapper": "langchain_community.utilities.arxiv",
"AskNewsAPIWrapper": "langchain_community.utilities.asknews",
"AudioStream": "langchain_community.utilities.nvidia_riva",
"BibtexparserWrapper": "langchain_community.utilities.bibtex",
"BingSearchAPIWrapper": "langchain_community.utilities.bing_search",
"BraveSearchWrapper": "langchain_community.utilities.brave_search",
"DataheraldAPIWrapper": "langchain_community.utilities.dataherald",
"DriaAPIWrapper": "langchain_community.utilities.dria_index",
"DuckDuckGoSearchAPIWrapper": "langchain_community.utilities.duckduckgo_search",
"GoldenQueryAPIWrapper": "langchain_community.utilities.golden_query",
"GoogleBooksAPIWrapper": "langchain_community.utilities.google_books",
"GoogleFinanceAPIWrapper": "langchain_community.utilities.google_finance",
"GoogleJobsAPIWrapper": "langchain_community.utilities.google_jobs",
"GoogleLensAPIWrapper": "langchain_community.utilities.google_lens",
"GooglePlacesAPIWrapper": "langchain_community.utilities.google_places_api",
"GoogleScholarAPIWrapper": "langchain_community.utilities.google_scholar",
"GoogleSearchAPIWrapper": "langchain_community.utilities.google_search",
"GoogleSerperAPIWrapper": "langchain_community.utilities.google_serper",
"GoogleTrendsAPIWrapper": "langchain_community.utilities.google_trends",
"GraphQLAPIWrapper": "langchain_community.utilities.graphql",
"InfobipAPIWrapper": "langchain_community.utilities.infobip",
"JiraAPIWrapper": "langchain_community.utilities.jira",
"LambdaWrapper": "langchain_community.utilities.awslambda",
"MaxComputeAPIWrapper": "langchain_community.utilities.max_compute",
"MerriamWebsterAPIWrapper": "langchain_community.utilities.merriam_webster",
"MetaphorSearchAPIWrapper": "langchain_community.utilities.metaphor_search",
"MojeekSearchAPIWrapper": "langchain_community.utilities.mojeek_search",
"NVIDIARivaASR": "langchain_community.utilities.nvidia_riva",
"NVIDIARivaStream": "langchain_community.utilities.nvidia_riva",
"NVIDIARivaTTS": "langchain_community.utilities.nvidia_riva",
"NasaAPIWrapper": "langchain_community.utilities.nasa",
"NutritionAIAPI": "langchain_community.utilities.passio_nutrition_ai",
"OpenWeatherMapAPIWrapper": "langchain_community.utilities.openweathermap",
"OracleSummary": "langchain_community.utilities.oracleai",
"OutlineAPIWrapper": "langchain_community.utilities.outline",
"Portkey": "langchain_community.utilities.portkey",
"PowerBIDataset": "langchain_community.utilities.powerbi",
"PubMedAPIWrapper": "langchain_community.utilities.pubmed",
"RememberizerAPIWrapper": "langchain_community.utilities.rememberizer",
"Requests": "langchain_community.utilities.requests",
"RequestsWrapper": "langchain_community.utilities.requests",
"RivaASR": "langchain_community.utilities.nvidia_riva",
"RivaTTS": "langchain_community.utilities.nvidia_riva",
"SQLDatabase": "langchain_community.utilities.sql_database",
"SceneXplainAPIWrapper": "langchain_community.utilities.scenexplain",
"SearchApiAPIWrapper": "langchain_community.utilities.searchapi",
"SearxSearchWrapper": "langchain_community.utilities.searx_search",
"SerpAPIWrapper": "langchain_community.utilities.serpapi",
"SparkSQL": "langchain_community.utilities.spark_sql",
"StackExchangeAPIWrapper": "langchain_community.utilities.stackexchange",
"SteamWebAPIWrapper": "langchain_community.utilities.steam",
"TensorflowDatasets": "langchain_community.utilities.tensorflow_datasets",
"TextRequestsWrapper": "langchain_community.utilities.requests",
"TwilioAPIWrapper": "langchain_community.utilities.twilio",
"WikipediaAPIWrapper": "langchain_community.utilities.wikipedia",
"WolframAlphaAPIWrapper": "langchain_community.utilities.wolfram_alpha",
"YouSearchAPIWrapper": "langchain_community.utilities.you",
"ZapierNLAWrapper": "langchain_community.utilities.zapier",
}
# Names that used to be exported from this package and no longer are, mapped
# to the explanation raised (as an AssertionError) by the module-level
# `__getattr__` when one of them is requested.
REMOVED = {
"PythonREPL": (
"PythonREPL has been deprecated from langchain_community "
"due to being flagged by security scanners. See: "
"https://github.com/langchain-ai/langchain/issues/14345 "
"If you need to use it, please use the version "
"from langchain_experimental. "
"from langchain_experimental.utilities.python import PythonREPL."
)
}
def __getattr__(name: str) -> Any:
    """Resolve a package attribute lazily on first access.

    Args:
        name: The attribute requested on this module.

    Returns:
        The object called ``name`` from the submodule registered for it in
        ``_module_lookup``.

    Raises:
        AssertionError: If ``name`` is listed in ``REMOVED``; the message is
            the removal notice stored there.
        AttributeError: If ``name`` is neither removed nor registered.
    """
    removal_notice = REMOVED.get(name)
    if removal_notice is not None:
        raise AssertionError(removal_notice)

    module_path = _module_lookup.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__} has no attribute {name}")

    # Import only the submodule that defines the requested name.
    return getattr(importlib.import_module(module_path), name)

View File

@@ -0,0 +1,176 @@
"""Util that calls AlphaVantage for Currency Exchange Rate."""
from typing import Any, Dict, List, Optional
import requests
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, model_validator
class AlphaVantageAPIWrapper(BaseModel):
    """Wrapper for the AlphaVantage API.

    ``run`` returns currency exchange rates; the other methods query symbol
    search, news sentiment, quotes, daily/weekly time series and top
    gainers/losers.

    Docs for using:

    1. Go to AlphaVantage and sign up for an API key
    2. Save your API KEY into ALPHAVANTAGE_API_KEY env variable
    """

    alphavantage_api_key: Optional[str] = None

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that api key exists in environment."""
        values["alphavantage_api_key"] = get_from_dict_or_env(
            values, "alphavantage_api_key", "ALPHAVANTAGE_API_KEY"
        )
        return values

    def _query(self, function: str, **params: str) -> Dict[str, Any]:
        """Call one AlphaVantage ``function`` and return the decoded JSON.

        Args:
            function: Value sent as the ``function`` query parameter.
            **params: Additional query parameters for that function.

        Returns:
            The JSON body of the response.

        Raises:
            requests.HTTPError: If the response status is an error
                (via ``raise_for_status``).
            ValueError: If the JSON body contains an ``"Error Message"`` key.
        """
        response = requests.get(
            "https://www.alphavantage.co/query/",
            params={
                "function": function,
                **params,
                "apikey": self.alphavantage_api_key,
            },
        )
        response.raise_for_status()
        data = response.json()

        # Errors can arrive inside a successful response body, so the status
        # check above is not enough on its own.
        if "Error Message" in data:
            raise ValueError(f"API Error: {data['Error Message']}")

        return data

    def search_symbols(self, keywords: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to search for symbols."""
        return self._query("SYMBOL_SEARCH", keywords=keywords)

    def _get_market_news_sentiment(self, symbol: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get market news sentiment for a
        given symbol."""
        return self._query("NEWS_SENTIMENT", symbol=symbol)

    def _get_time_series_daily(self, symbol: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get the daily time series."""
        return self._query("TIME_SERIES_DAILY", symbol=symbol)

    def _get_quote_endpoint(self, symbol: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get the
        latest price and volume information."""
        return self._query("GLOBAL_QUOTE", symbol=symbol)

    def _get_time_series_weekly(self, symbol: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API
        to get the Weekly Time Series."""
        return self._query("TIME_SERIES_WEEKLY", symbol=symbol)

    def _get_top_gainers_losers(self) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get the top gainers, losers,
        and most actively traded tickers in the US market."""
        return self._query("TOP_GAINERS_LOSERS")

    def _get_exchange_rate(
        self, from_currency: str, to_currency: str
    ) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get the exchange rate."""
        return self._query(
            "CURRENCY_EXCHANGE_RATE",
            from_currency=from_currency,
            to_currency=to_currency,
        )

    @property
    def standard_currencies(self) -> List[str]:
        """Currency codes that ``run`` accepts as a target without swapping."""
        return ["USD", "EUR", "GBP", "JPY", "CHF", "CAD", "AUD", "NZD"]

    def run(self, from_currency: str, to_currency: str) -> str:
        """Get the current exchange rate for a specified currency pair.

        If ``to_currency`` is not one of ``standard_currencies`` the two
        arguments are swapped before the request is made.

        Returns:
            The value stored under ``"Realtime Currency Exchange Rate"`` in
            the API response.
        """
        if to_currency not in self.standard_currencies:
            from_currency, to_currency = to_currency, from_currency

        data = self._get_exchange_rate(from_currency, to_currency)
        return data["Realtime Currency Exchange Rate"]

View File

@@ -0,0 +1,27 @@
from typing import Any, List
def _get_anthropic_client() -> Any:
try:
import anthropic
except ImportError:
raise ImportError(
"Could not import anthropic python package. "
"This is needed in order to accurately tokenize the text "
"for anthropic models. Please install it with `pip install anthropic`."
)
return anthropic.Anthropic()
def get_num_tokens_anthropic(text: str) -> int:
    """Count the tokens in ``text`` using the Anthropic client.

    Args:
        text: The string to measure.

    Returns:
        Whatever the client's ``count_tokens`` reports for ``text``.
    """
    return _get_anthropic_client().count_tokens(text=text)
def get_token_ids_anthropic(text: str) -> List[int]:
    """Encode ``text`` with the Anthropic client's tokenizer.

    Args:
        text: The string to tokenize.

    Returns:
        The ``ids`` attribute of the tokenizer's encoding of ``text``.
    """
    anthropic_tokenizer = _get_anthropic_client().get_tokenizer()
    return anthropic_tokenizer.encode(text).ids

View File

@@ -0,0 +1,227 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from langchain_core._api import deprecated
from langchain_core.documents import Document
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from langchain_community.document_loaders import ApifyDatasetLoader
@deprecated(
    since="0.3.18",
    message=(
        "This class is deprecated and will be removed in a future version. "
        "You can swap to using the `ApifyWrapper`"
        " implementation in `langchain_apify` package. "
        "See <https://github.com/apify/langchain-apify>"
    ),
    alternative_import="langchain_apify.ApifyWrapper",
)
class ApifyWrapper(BaseModel):
    """Wrapper around Apify.

    To use, you should have the ``apify-client`` python package installed,
    and the environment variable ``APIFY_API_TOKEN`` set with your API key, or pass
    `apify_api_token` as a named parameter to the constructor.
    """

    # Sync and async Apify clients; both are created by `validate_environment`.
    apify_client: Any
    apify_client_async: Any
    apify_api_token: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate environment.

        Validate that an Apify API token is set and the apify-client
        Python package exists in the current environment, then build the
        sync and async clients and store them in ``values``.

        Raises:
            ImportError: If ``apify-client`` is not installed.
        """
        apify_api_token = get_from_dict_or_env(
            values, "apify_api_token", "APIFY_API_TOKEN"
        )

        # Only the import is guarded, so a failure while building the clients
        # is not misreported as a missing package.
        try:
            from apify_client import ApifyClient, ApifyClientAsync
        except ImportError as e:
            raise ImportError(
                "Could not import apify-client Python package. "
                "Please install it with `pip install apify-client`."
            ) from e

        # Tag outgoing requests as coming from langchain. The default of None
        # lets the tagging be skipped when the client does not expose the
        # underlying httpx client.
        client = ApifyClient(apify_api_token)
        if httpx_client := getattr(client.http_client, "httpx_client", None):
            httpx_client.headers["user-agent"] += "; Origin/langchain"

        async_client = ApifyClientAsync(apify_api_token)
        if httpx_async_client := getattr(
            async_client.http_client, "httpx_async_client", None
        ):
            httpx_async_client.headers["user-agent"] += "; Origin/langchain"

        values["apify_client"] = client
        values["apify_client_async"] = async_client
        return values

    def call_actor(
        self,
        actor_id: str,
        run_input: Dict,
        dataset_mapping_function: Callable[[Dict], Document],
        *,
        build: Optional[str] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
    ) -> "ApifyDatasetLoader":
        """Run an Actor on the Apify platform and wait for results to be ready.

        Args:
            actor_id (str): The ID or name of the Actor on the Apify platform.
            run_input (Dict): The input object of the Actor that you're trying to run.
            dataset_mapping_function (Callable): A function that takes a single
                dictionary (an Apify dataset item) and converts it to an
                instance of the Document class.
            build (str, optional): Optionally specifies the actor build to run.
                It can be either a build tag or build number.
            memory_mbytes (int, optional): Optional memory limit for the run,
                in megabytes.
            timeout_secs (int, optional): Optional timeout for the run, in seconds.

        Returns:
            ApifyDatasetLoader: A loader that will fetch the records from the
                Actor run's default dataset.
        """
        # Imported here rather than at module level; the module-level import
        # is guarded by TYPE_CHECKING.
        from langchain_community.document_loaders import ApifyDatasetLoader

        actor_call = self.apify_client.actor(actor_id).call(
            run_input=run_input,
            build=build,
            memory_mbytes=memory_mbytes,
            timeout_secs=timeout_secs,
        )

        return ApifyDatasetLoader(
            dataset_id=actor_call["defaultDatasetId"],
            dataset_mapping_function=dataset_mapping_function,
        )

    async def acall_actor(
        self,
        actor_id: str,
        run_input: Dict,
        dataset_mapping_function: Callable[[Dict], Document],
        *,
        build: Optional[str] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
    ) -> "ApifyDatasetLoader":
        """Run an Actor on the Apify platform and wait for results to be ready.

        Async counterpart of ``call_actor``; uses the async client.

        Args:
            actor_id (str): The ID or name of the Actor on the Apify platform.
            run_input (Dict): The input object of the Actor that you're trying to run.
            dataset_mapping_function (Callable): A function that takes a single
                dictionary (an Apify dataset item) and converts it to
                an instance of the Document class.
            build (str, optional): Optionally specifies the actor build to run.
                It can be either a build tag or build number.
            memory_mbytes (int, optional): Optional memory limit for the run,
                in megabytes.
            timeout_secs (int, optional): Optional timeout for the run, in seconds.

        Returns:
            ApifyDatasetLoader: A loader that will fetch the records from the
                Actor run's default dataset.
        """
        from langchain_community.document_loaders import ApifyDatasetLoader

        actor_call = await self.apify_client_async.actor(actor_id).call(
            run_input=run_input,
            build=build,
            memory_mbytes=memory_mbytes,
            timeout_secs=timeout_secs,
        )

        return ApifyDatasetLoader(
            dataset_id=actor_call["defaultDatasetId"],
            dataset_mapping_function=dataset_mapping_function,
        )

    def call_actor_task(
        self,
        task_id: str,
        task_input: Dict,
        dataset_mapping_function: Callable[[Dict], Document],
        *,
        build: Optional[str] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
    ) -> "ApifyDatasetLoader":
        """Run a saved Actor task on Apify and wait for results to be ready.

        Args:
            task_id (str): The ID or name of the task on the Apify platform.
            task_input (Dict): The input object of the task that you're trying to run.
                Overrides the task's saved input.
            dataset_mapping_function (Callable): A function that takes a single
                dictionary (an Apify dataset item) and converts it to an
                instance of the Document class.
            build (str, optional): Optionally specifies the actor build to run.
                It can be either a build tag or build number.
            memory_mbytes (int, optional): Optional memory limit for the run,
                in megabytes.
            timeout_secs (int, optional): Optional timeout for the run, in seconds.

        Returns:
            ApifyDatasetLoader: A loader that will fetch the records from the
                task run's default dataset.
        """
        from langchain_community.document_loaders import ApifyDatasetLoader

        task_call = self.apify_client.task(task_id).call(
            task_input=task_input,
            build=build,
            memory_mbytes=memory_mbytes,
            timeout_secs=timeout_secs,
        )

        return ApifyDatasetLoader(
            dataset_id=task_call["defaultDatasetId"],
            dataset_mapping_function=dataset_mapping_function,
        )

    async def acall_actor_task(
        self,
        task_id: str,
        task_input: Dict,
        dataset_mapping_function: Callable[[Dict], Document],
        *,
        build: Optional[str] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
    ) -> "ApifyDatasetLoader":
        """Run a saved Actor task on Apify and wait for results to be ready.

        Async counterpart of ``call_actor_task``; uses the async client.

        Args:
            task_id (str): The ID or name of the task on the Apify platform.
            task_input (Dict): The input object of the task that you're trying to run.
                Overrides the task's saved input.
            dataset_mapping_function (Callable): A function that takes a single
                dictionary (an Apify dataset item) and converts it to an
                instance of the Document class.
            build (str, optional): Optionally specifies the actor build to run.
                It can be either a build tag or build number.
            memory_mbytes (int, optional): Optional memory limit for the run,
                in megabytes.
            timeout_secs (int, optional): Optional timeout for the run, in seconds.

        Returns:
            ApifyDatasetLoader: A loader that will fetch the records from the
                task run's default dataset.
        """
        from langchain_community.document_loaders import ApifyDatasetLoader

        task_call = await self.apify_client_async.task(task_id).call(
            task_input=task_input,
            build=build,
            memory_mbytes=memory_mbytes,
            timeout_secs=timeout_secs,
        )

        return ApifyDatasetLoader(
            dataset_id=task_call["defaultDatasetId"],
            dataset_mapping_function=dataset_mapping_function,
        )

View File

@@ -0,0 +1,256 @@
# This module contains utility classes and functions for interacting with Arcee API.
# For more information and updates, refer to the Arcee utils page:
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]
from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
import requests
from langchain_core.retrievers import Document
from pydantic import BaseModel, SecretStr, model_validator
class ArceeRoute(str, Enum):
    """Endpoint paths of the Arcee API, as string-valued enum members."""

    # Text generation endpoint.
    generate = "models/generate"
    # Context retrieval endpoint.
    retrieve = "models/retrieve"
    # Template: fill `{id_or_name}` with `str.format` before use.
    model_training_status = "models/status/{id_or_name}"
class DALMFilterType(str, Enum):
    """Kinds of filter a DALM retrieval accepts, as string-valued enum members."""

    fuzzy_search = "fuzzy_search"
    strict_search = "strict_search"
class DALMFilter(BaseModel):
    """Filters available for a DALM retrieval and generation.

    Arguments:
        field_name: The field to filter on. Can be 'document' or 'name' to filter
            on your document's raw text or title. Any other field will be presumed
            to be a metadata field you included when uploading your context data
        filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
            'fuzzy_search' means a fuzzy search on the provided field is performed.
            The exact string doesn't need to exist in the document
            for this to find a match.
            Very useful for scanning a document for some keyword terms.
            'strict_search' means that the exact string must appear
            in the provided field.
            This is NOT an exact eq filter. ie a document with content
            "the happy dog crossed the street" will match on a strict_search of
            "dog" but won't match on "the dog".
            Python equivalent of `return search_string in full_string`.
        value: The actual value to search for in the context data/metadata
    """

    field_name: str
    filter_type: DALMFilterType
    value: str
    # Private attribute: True when `field_name` is not a reserved key.
    _is_metadata: bool = False

    @model_validator(mode="after")
    def set_meta(self) -> "DALMFilter":
        """document and name are reserved arcee keys. Anything else is metadata.

        Runs after field validation and records the result on the declared
        private attribute ``_is_metadata``.
        """
        # NOTE(review): the previous validator wrote the key "_is_meta" into
        # the input mapping, which does not match the declared attribute
        # `_is_metadata`; confirm no consumer relied on that key.
        self._is_metadata = self.field_name not in ("document", "name")
        return self
class ArceeDocumentSource(BaseModel):
    """The ``source`` payload nested inside an ``ArceeDocument``."""

    # Becomes the `page_content` of the adapted langchain Document.
    document: str
    # Exposed as metadata "name" by ArceeDocumentAdapter.
    name: str
    # Exposed as metadata "source_id" by ArceeDocumentAdapter.
    id: str
class ArceeDocument(BaseModel):
    """A single retrieval result as modelled by this wrapper."""

    # All three are copied into the adapted Document's metadata unchanged.
    index: str
    id: str
    score: float
    # Holds the text, name and id of the underlying source document.
    source: ArceeDocumentSource
class ArceeDocumentAdapter:
    """Converts Arcee documents into langchain ``Document`` objects."""

    @classmethod
    def adapt(cls, arcee_document: ArceeDocument) -> Document:
        """Build a langchain ``Document`` from an ``ArceeDocument``.

        Args:
            arcee_document: The Arcee result to convert.

        Returns:
            A ``Document`` whose content is the source text and whose metadata
            carries the source name/id plus the result's index, id and score.
        """
        source = arcee_document.source
        metadata = {
            # fields taken from the nested source
            "name": source.name,
            "source_id": source.id,
            # fields taken from the result itself
            "index": arcee_document.index,
            "id": arcee_document.id,
            "score": arcee_document.score,
        }
        return Document(page_content=source.document, metadata=metadata)
class ArceeWrapper:
    """Wrapper for Arcee API.

    For more details, see: https://www.arcee.ai/
    """

    def __init__(
        self,
        arcee_api_key: Union[str, SecretStr],
        arcee_api_url: str,
        arcee_api_version: str,
        model_kwargs: Optional[Dict[str, Any]],
        model_name: str,
    ):
        """Initialize ArceeWrapper.

        Besides storing the settings, this issues a GET request for the
        model's training status and keeps the returned ``model_id`` and
        ``status``.

        Arguments:
            arcee_api_key: API key for Arcee API.
            arcee_api_url: URL for Arcee API.
            arcee_api_version: Version of Arcee API.
            model_kwargs: Keyword arguments for Arcee API.
            model_name: Name of an Arcee model.

        Raises:
            ValueError: If the training-status request fails for any reason.
        """
        # Plain strings are wrapped so the key is always held as a SecretStr.
        if isinstance(arcee_api_key, str):
            arcee_api_key_ = SecretStr(arcee_api_key)
        else:
            arcee_api_key_ = arcee_api_key

        self.arcee_api_key: SecretStr = arcee_api_key_
        self.model_kwargs = model_kwargs
        self.arcee_api_url = arcee_api_url
        self.arcee_api_version = arcee_api_version

        try:
            route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
            response = self._make_request("get", route)
            self.model_id = response.get("model_id")
            self.model_training_status = response.get("status")
        except Exception as e:
            raise ValueError(
                f"Error while validating model training status for '{model_name}': {e}"
            ) from e

    def validate_model_training_status(self) -> None:
        """Raise if the status fetched at construction is not ``training_complete``."""
        if self.model_training_status != "training_complete":
            raise Exception(
                f"Model {self.model_id} is not ready. "
                "Please wait for training to complete."
            )

    def _make_request(
        self,
        method: Literal["post", "get"],
        route: Union[ArceeRoute, str],
        body: Optional[Mapping[str, Any]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> dict:
        """Make a request to the Arcee API

        Args:
            method: The HTTP method to use
            route: The route to call
            body: The body of the request
            params: The query params of the request
            headers: The headers of the request

        Returns:
            The decoded JSON body of the response.

        Raises:
            Exception: If the response status code is neither 200 nor 201.
        """
        headers = self._make_request_headers(headers=headers)
        url = self._make_request_url(route=route)

        # `method` names the module-level function on `requests` to call.
        req_type = getattr(requests, method)

        response = req_type(url, json=body, params=params, headers=headers)
        if response.status_code not in (200, 201):
            raise Exception(f"Failed to make request. Response: {response.text}")
        return response.json()

    def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
        """Return request headers with the auth token and content type applied.

        Caller-supplied ``headers`` are copied, not modified; the
        ``X-Token`` and ``Content-Type`` entries override any caller values.

        Raises:
            TypeError: If ``self.arcee_api_key`` is not a ``SecretStr``.
        """
        if not isinstance(self.arcee_api_key, SecretStr):
            raise TypeError(
                f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
            )

        # A new dict is built so the caller's mapping never receives the
        # secret token as a side effect.
        return {
            **(headers or {}),
            "X-Token": self.arcee_api_key.get_secret_value(),
            "Content-Type": "application/json",
        }

    def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
        """Join base URL, API version and route into the request URL."""
        return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}"

    def _make_request_body_for_models(
        self, prompt: str, **kwargs: Any
    ) -> Mapping[str, Any]:
        """Make the request body for generate/retrieve models endpoint

        Per-call ``kwargs`` override the ``model_kwargs`` given at
        construction. Only ``size`` (default 3) and ``filters`` are read.
        """
        _model_kwargs = self.model_kwargs or {}
        _params = {**_model_kwargs, **kwargs}

        # NOTE(review): the filters are placed in the body as DALMFilter
        # instances and the body is later passed as `json=`; confirm that a
        # non-empty filter list serializes as the API expects.
        filters = [DALMFilter(**f) for f in _params.get("filters", [])]
        return dict(
            model_id=self.model_id,
            query=prompt,
            size=_params.get("size", 3),
            filters=filters,
            id=self.model_id,
        )

    def generate(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Generate text from Arcee DALM.

        Args:
            prompt: Prompt to generate text from.
            size: The max number of context results to retrieve. Defaults to 3.
                (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            The ``"text"`` entry of the API response.
        """
        response = self._make_request(
            method="post",
            route=ArceeRoute.generate.value,
            body=self._make_request_body_for_models(
                prompt=prompt,
                **kwargs,
            ),
        )
        return response["text"]

    def retrieve(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve {size} contexts with your retriever for a given query

        Args:
            query: Query to submit to the model
            size: The max number of context results to retrieve. Defaults to 3.
                (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            One langchain ``Document`` per entry of the response's
            ``"results"`` list.
        """
        response = self._make_request(
            method="post",
            route=ArceeRoute.retrieve.value,
            body=self._make_request_body_for_models(
                prompt=query,
                **kwargs,
            ),
        )
        return [
            ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
            for doc in response["results"]
        ]

View File

@@ -0,0 +1,255 @@
"""Util that calls Arxiv."""
import logging
import os
import re
from typing import Any, Dict, Iterator, List, Optional
from langchain_core.documents import Document
from pydantic import BaseModel, model_validator
logger = logging.getLogger(__name__)
class ArxivAPIWrapper(BaseModel):
    """Wrapper around ArxivAPI.

    To use, you should have the ``arxiv`` python package installed.
    https://lukasschwab.me/arxiv.py/index.html
    This wrapper will use the Arxiv API to conduct searches and
    fetch document summaries. By default, it will return the document summaries
    of the top-k results.
    If the query is in the form of arxiv identifier
    (see https://info.arxiv.org/help/find/index.html), it will return the paper
    corresponding to the arxiv identifier.
    It limits the Document content by doc_content_chars_max.
    Set doc_content_chars_max=None if you don't want to limit the content size.

    Attributes:
        top_k_results: number of the top-scored document used for the arxiv tool
        ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
        continue_on_failure (bool): If True, continue loading other URLs on failure.
        load_max_docs: a limit to the number of loaded documents
        load_all_available_meta:
            if True: the `metadata` of the loaded Documents contains all available
            meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
            if False: the `metadata` contains only the published date, title,
            authors and summary.
        doc_content_chars_max: an optional cut limit for the length of a document's
            content

    Example:
        .. code-block:: python

            from langchain_community.utilities.arxiv import ArxivAPIWrapper
            arxiv = ArxivAPIWrapper(
                top_k_results = 3,
                ARXIV_MAX_QUERY_LENGTH = 300,
                load_max_docs = 3,
                load_all_available_meta = False,
                doc_content_chars_max = 40000
            )
            arxiv.run("tree of thought llm")
    """

    arxiv_search: Any  #: :meta private:
    arxiv_exceptions: Any  # :meta private:
    top_k_results: int = 3
    ARXIV_MAX_QUERY_LENGTH: int = 300
    continue_on_failure: bool = False
    load_max_docs: int = 100
    load_all_available_meta: bool = False
    doc_content_chars_max: Optional[int] = 4000

    def is_arxiv_identifier(self, query: str) -> bool:
        """Check if a query is an arxiv identifier.

        The query is truncated to ``ARXIV_MAX_QUERY_LENGTH`` characters and split
        on whitespace; every resulting item must match the identifier pattern in
        full. A query with no items (empty or whitespace only) is not an
        identifier.
        """
        arxiv_identifier_pattern = r"\d{2}(0[1-9]|1[0-2])\.\d{4,5}(v\d+|)|\d{7}.*"
        query_items = query[: self.ARXIV_MAX_QUERY_LENGTH].split()
        if not query_items:
            # Without this guard the loop below is vacuous and "" would be
            # reported as an identifier.
            return False
        for query_item in query_items:
            match_result = re.match(arxiv_identifier_pattern, query_item)
            # The match is anchored at the start only, so also require that it
            # consumed the whole item.
            if match_result is None or match_result.group(0) != query_item:
                return False
        return True

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that the python package exists in environment.

        Stores the ``arxiv.Search`` class, the tuple of arxiv exception classes
        and ``arxiv.Result`` in ``values``.

        Raises:
            ImportError: if the ``arxiv`` package cannot be imported.
        """
        try:
            import arxiv

            values["arxiv_search"] = arxiv.Search
            values["arxiv_exceptions"] = (
                arxiv.ArxivError,
                arxiv.UnexpectedEmptyPageError,
                arxiv.HTTPError,
            )
            values["arxiv_result"] = arxiv.Result
        except ImportError:
            raise ImportError(
                "Could not import arxiv python package. "
                "Please install it with `pip install arxiv`."
            )
        return values

    def _fetch_results(self, query: str) -> Any:
        """Helper function to fetch arxiv results based on query.

        Identifier queries are looked up through ``id_list``; anything else is
        sent as a text query truncated to ``ARXIV_MAX_QUERY_LENGTH``. Both paths
        request at most ``top_k_results`` results.
        """
        if self.is_arxiv_identifier(query):
            return self.arxiv_search(
                id_list=query.split(), max_results=self.top_k_results
            ).results()
        return self.arxiv_search(
            query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
        ).results()

    def get_summaries_as_docs(self, query: str) -> List[Document]:
        """
        Performs an arxiv search and returns list of
        documents, with summaries as the content.

        If an arxiv error occurs, a single Document carrying the error text
        is returned instead. Wrapper for
        https://lukasschwab.me/arxiv.py/index.html#Search

        Args:
            query: a plaintext search query
        """
        try:
            # Materialize inside the ``try``: the result iterator may be lazy, in
            # which case request errors only surface while it is consumed.
            results = list(self._fetch_results(query))
        except self.arxiv_exceptions as ex:
            logger.error("Arxiv exception: %s", ex)
            return [Document(page_content=f"Arxiv exception: {ex}")]
        docs = [
            Document(
                page_content=result.summary,
                metadata={
                    "Entry ID": result.entry_id,
                    "Published": result.updated.date(),
                    "Title": result.title,
                    "Authors": ", ".join(a.name for a in result.authors),
                },
            )
            for result in results
        ]
        return docs

    def run(self, query: str) -> str:
        """
        Performs an arxiv search and returns a single string
        with the publish date, title, authors, and summary
        for each article separated by two newlines.

        If an error occurs or no documents found, error text
        is returned instead. Wrapper for
        https://lukasschwab.me/arxiv.py/index.html#Search

        Args:
            query: a plaintext search query
        """
        try:
            # Materialize inside the ``try`` so errors raised while iterating a
            # lazy result iterator are reported as text rather than propagating.
            results = list(self._fetch_results(query))
        except self.arxiv_exceptions as ex:
            logger.error("Arxiv exception: %s", ex)
            return f"Arxiv exception: {ex}"
        docs = [
            f"Published: {result.updated.date()}\n"
            f"Title: {result.title}\n"
            f"Authors: {', '.join(a.name for a in result.authors)}\n"
            f"Summary: {result.summary}"
            for result in results
        ]
        if docs:
            # Slicing with None (no limit configured) keeps the whole string.
            return "\n\n".join(docs)[: self.doc_content_chars_max]
        else:
            return "No good Arxiv Result was found"

    def load(self, query: str) -> List[Document]:
        """
        Run Arxiv search and get the article texts plus the article meta information.
        See https://lukasschwab.me/arxiv.py/index.html#Search

        Returns: a list of documents with the document.page_content in text format

        Performs an arxiv search, downloads the top k results as PDFs, loads
        them as Documents, and returns them in a List.

        Args:
            query: a plaintext search query
        """
        return list(self.lazy_load(query))

    def lazy_load(self, query: str) -> Iterator[Document]:
        """
        Run Arxiv search and get the article texts plus the article meta information.
        See https://lukasschwab.me/arxiv.py/index.html#Search

        Returns: documents with the document.page_content in text format

        Performs an arxiv search, downloads the top k results as PDFs, loads
        them as Documents, and returns them. Each downloaded PDF is deleted as
        soon as its text has been extracted (or extraction has failed).

        Args:
            query: a plaintext search query

        Raises:
            ImportError: if PyMuPDF (``fitz``) is not installed.
            Exception: any unexpected download/parse error, unless
                ``continue_on_failure`` is True.
        """
        try:
            import fitz
        except ImportError:
            raise ImportError(
                "PyMuPDF package not found, please install it with "
                "`pip install pymupdf`"
            )
        try:
            # Remove the ":" and "-" from the query, as they can cause search problems
            query = query.replace(":", "").replace("-", "")
            # Materialize here so errors from a lazy result iterator are caught.
            results = list(self._fetch_results(query))
        except self.arxiv_exceptions as ex:
            logger.debug("Error on arxiv: %s", ex)
            return
        for result in results:
            doc_file_name: Optional[str] = None
            try:
                doc_file_name = result.download_pdf()
                with fitz.open(doc_file_name) as doc_file:
                    text: str = "".join(page.get_text() for page in doc_file)
            except (FileNotFoundError, fitz.fitz.FileDataError) as f_ex:
                logger.debug(f_ex)
                continue
            except Exception as e:
                if self.continue_on_failure:
                    logger.error(e)
                    continue
                else:
                    raise e
            finally:
                # Clean up before yielding, so the file is removed even when
                # parsing failed or the consumer stops iterating early.
                if doc_file_name is not None:
                    try:
                        os.remove(doc_file_name)
                    except OSError as rm_ex:
                        logger.debug(
                            "Could not remove %s: %s", doc_file_name, rm_ex
                        )
            if self.load_all_available_meta:
                extra_metadata = {
                    "entry_id": result.entry_id,
                    "published_first_time": str(result.published.date()),
                    "comment": result.comment,
                    "journal_ref": result.journal_ref,
                    "doi": result.doi,
                    "primary_category": result.primary_category,
                    "categories": result.categories,
                    "links": [link.href for link in result.links],
                }
            else:
                extra_metadata = {}
            metadata = {
                "Published": str(result.updated.date()),
                "Title": result.title,
                "Authors": ", ".join(a.name for a in result.authors),
                "Summary": result.summary,
                **extra_metadata,
            }
            yield Document(
                page_content=text[: self.doc_content_chars_max], metadata=metadata
            )

View File

@@ -0,0 +1,115 @@
"""Util that calls AskNews api."""
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, model_validator
class AskNewsAPIWrapper(BaseModel):
    """Wrapper for AskNews API."""

    asknews_sync: Any = None  #: :meta private:
    asknews_async: Any = None  #: :meta private:
    asknews_client_id: Optional[str] = None
    """Client ID for the AskNews API."""
    asknews_client_secret: Optional[str] = None
    """Client Secret for the AskNews API."""

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that api credentials and python package exists in environment.

        Credentials are read from ``values`` or, failing that, from the
        ``ASKNEWS_CLIENT_ID`` / ``ASKNEWS_CLIENT_SECRET`` environment variables.
        A sync and an async SDK client are created and stored in ``values``.

        Raises:
            ImportError: if the ``asknews_sdk`` package cannot be imported.
        """
        asknews_client_id = get_from_dict_or_env(
            values, "asknews_client_id", "ASKNEWS_CLIENT_ID"
        )
        asknews_client_secret = get_from_dict_or_env(
            values, "asknews_client_secret", "ASKNEWS_CLIENT_SECRET"
        )
        try:
            import asknews_sdk
        except ImportError:
            raise ImportError(
                "AskNews python package not found. "
                "Please install it with `pip install asknews`."
            )
        an_sync = asknews_sdk.AskNewsSDK(
            client_id=asknews_client_id,
            client_secret=asknews_client_secret,
            scopes=["news"],
        )
        an_async = asknews_sdk.AsyncAskNewsSDK(
            client_id=asknews_client_id,
            client_secret=asknews_client_secret,
            scopes=["news"],
        )
        values["asknews_sync"] = an_sync
        values["asknews_async"] = an_async
        values["asknews_client_id"] = asknews_client_id
        values["asknews_client_secret"] = asknews_client_secret
        return values

    def _search_kwargs(
        self, query: str, max_results: int, hours_back: int
    ) -> Dict[str, Any]:
        """Build the keyword arguments shared by the sync and async searches."""
        if hours_back > 48:
            # Look-backs beyond 48 hours switch to a keyword ("kw") search over
            # historical data with an explicit time window.
            # NOTE(review): the reason for the 48-hour threshold is not visible
            # here - confirm against the AskNews API.
            now = datetime.now()  # single reading so both bounds are consistent
            method = "kw"
            historical = True
            start: Optional[int] = int(
                (now - timedelta(hours=hours_back)).timestamp()
            )
            stop: Optional[int] = int(now.timestamp())
        else:
            historical = False
            method = "nl"
            start = None
            stop = None
        return {
            "query": query,
            "n_articles": max_results,
            "method": method,
            "historical": historical,
            "start_timestamp": start,
            "end_timestamp": stop,
            "return_type": "string",
        }

    def search_news(
        self, query: str, max_results: int = 10, hours_back: int = 0
    ) -> str:
        """Search news in AskNews API synchronously.

        Args:
            query: the search query.
            max_results: number of articles to request.
            hours_back: how far back to search, in hours.

        Returns:
            The ``as_string`` attribute of the SDK response.
        """
        response = self.asknews_sync.news.search_news(
            **self._search_kwargs(query, max_results, hours_back)
        )
        return response.as_string

    async def asearch_news(
        self, query: str, max_results: int = 10, hours_back: int = 0
    ) -> str:
        """Search news in AskNews API asynchronously.

        Args:
            query: the search query.
            max_results: number of articles to request.
            hours_back: how far back to search, in hours.

        Returns:
            The ``as_string`` attribute of the SDK response.
        """
        response = await self.asknews_async.news.search_news(
            **self._search_kwargs(query, max_results, hours_back)
        )
        return response.as_string

View File

@@ -0,0 +1,171 @@
from __future__ import annotations
import asyncio
import inspect
from asyncio import InvalidStateError, Task
from enum import Enum
from typing import TYPE_CHECKING, Awaitable, Optional, Union
if TYPE_CHECKING:
from astrapy.db import (
AstraDB,
AsyncAstraDB,
)
class SetupMode(Enum):
    """Enumerates the ways an AstraDBEnvironment can perform its setup."""

    # Collection is created inside the constructor, blocking the caller.
    SYNC = 1
    # Collection creation is scheduled as an asyncio task.
    ASYNC = 2
    # No collection setup is performed.
    OFF = 3
class _AstraDBEnvironment:
    """Holds a matching pair of sync and async Astra DB clients.

    The pair is built either from ``token`` + ``api_endpoint`` or from an
    already-constructed client; when only one client flavour is supplied, the
    other one is created from its connection attributes.
    """

    def __init__(
        self,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
    ) -> None:
        """Resolve ``self.astra_db`` and ``self.async_astra_db``.

        Raises:
            ImportError: if ``astrapy.db`` cannot be imported.
            ValueError: if clients and credentials are both passed, or if
                nothing usable is passed.
        """
        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace
        sync_client = astra_db_client
        async_client = async_astra_db_client

        try:
            from astrapy.db import (
                AstraDB,
                AsyncAstraDB,
            )
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import a recent astrapy python package. "
                "Please install it with `pip install --upgrade astrapy`."
            )

        # Ready-made clients and raw credentials are mutually exclusive.
        client_given = not (astra_db_client is None and async_astra_db_client is None)
        credential_given = not (token is None and api_endpoint is None)
        if client_given and credential_given:
            raise ValueError(
                "You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
                "AstraDBEnvironment if passing 'token' and 'api_endpoint'."
            )

        if token and api_endpoint:
            sync_client = AstraDB(
                token=token,
                api_endpoint=api_endpoint,
                namespace=self.namespace,
            )
            async_client = AsyncAstraDB(
                token=token,
                api_endpoint=api_endpoint,
                namespace=self.namespace,
            )

        if sync_client:
            self.astra_db = sync_client
            if async_client:
                self.async_astra_db = async_client
            else:
                # Mirror the sync client's connection settings.
                self.async_astra_db = AsyncAstraDB(
                    token=self.astra_db.token,
                    api_endpoint=self.astra_db.base_url,
                    api_path=self.astra_db.api_path,
                    api_version=self.astra_db.api_version,
                    namespace=self.astra_db.namespace,
                )
            return

        if async_client:
            self.async_astra_db = async_client
            # Mirror the async client's connection settings.
            self.astra_db = AstraDB(
                token=self.async_astra_db.token,
                api_endpoint=self.async_astra_db.base_url,
                api_path=self.async_astra_db.api_path,
                api_version=self.async_astra_db.api_version,
                namespace=self.async_astra_db.namespace,
            )
            return

        raise ValueError(
            "Must provide 'astra_db_client' or 'async_astra_db_client' or "
            "'token' and 'api_endpoint'"
        )
class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
    """Astra DB environment bound to one collection, with optional setup.

    On top of the client pair resolved by the base class, this builds sync and
    async collection handles and, depending on ``setup_mode``, creates the
    collection synchronously, schedules its creation as an asyncio task, or does
    nothing.
    """

    def __init__(
        self,
        collection_name: str,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
        embedding_dimension: Union[int, Awaitable[int], None] = None,
        metric: Optional[str] = None,
    ) -> None:
        """Create the collection handles and run the requested setup.

        Args:
            collection_name: name of the collection to bind to.
            token, api_endpoint, astra_db_client, async_astra_db_client,
                namespace: forwarded to ``_AstraDBEnvironment``.
            setup_mode: SYNC creates the collection immediately, ASYNC schedules
                an asyncio task, any other value skips setup.
            pre_delete_collection: delete the collection before creating it.
            embedding_dimension: passed as ``dimension`` on creation; may be an
                awaitable only in ASYNC mode.
            metric: passed as ``metric`` on creation.

        Raises:
            ValueError: in SYNC mode when ``embedding_dimension`` is awaitable.
        """
        from astrapy.db import AstraDBCollection, AsyncAstraDBCollection

        super().__init__(
            token, api_endpoint, astra_db_client, async_astra_db_client, namespace
        )
        self.collection_name = collection_name
        self.collection = AstraDBCollection(
            collection_name=collection_name,
            astra_db=self.astra_db,
        )
        self.async_collection = AsyncAstraDBCollection(
            collection_name=collection_name,
            astra_db=self.async_astra_db,
        )
        self.async_setup_db_task: Optional[Task] = None
        if setup_mode == SetupMode.ASYNC:
            async_astra_db = self.async_astra_db

            async def _setup_db() -> None:
                if pre_delete_collection:
                    await async_astra_db.delete_collection(collection_name)
                if inspect.isawaitable(embedding_dimension):
                    dimension: Optional[int] = await embedding_dimension
                else:
                    dimension = embedding_dimension
                await async_astra_db.create_collection(
                    collection_name, dimension=dimension, metric=metric
                )

            # create_task needs a running event loop, so ASYNC mode must be
            # requested from within one.
            self.async_setup_db_task = asyncio.create_task(_setup_db())
        elif setup_mode == SetupMode.SYNC:
            # Validate first: rejecting the arguments must not be preceded by
            # the destructive delete below.
            if inspect.isawaitable(embedding_dimension):
                raise ValueError(
                    "Cannot use an awaitable embedding_dimension with setup_mode "
                    "set to SetupMode.SYNC"
                )
            if pre_delete_collection:
                self.astra_db.delete_collection(collection_name)
            self.astra_db.create_collection(
                collection_name,
                dimension=embedding_dimension,
                metric=metric,
            )

    def ensure_db_setup(self) -> None:
        """Check, without blocking, that an asynchronous setup has completed.

        Does nothing when no async setup task exists. Otherwise reads the task
        result, which re-raises any exception the setup raised.

        Raises:
            ValueError: if the async setup task has not finished yet.
        """
        if self.async_setup_db_task:
            try:
                self.async_setup_db_task.result()
            except InvalidStateError:
                raise ValueError(
                    "Asynchronous setup of the DB not finished. "
                    "NB: AstraDB components sync methods shouldn't be called from the "
                    "event loop. Consider using their async equivalents."
                )

    async def aensure_db_setup(self) -> None:
        """Await the asynchronous setup task, if there is one."""
        if self.async_setup_db_task:
            await self.async_setup_db_task

View File

@@ -0,0 +1,81 @@
"""Util that calls Lambda."""
import json
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, model_validator
class LambdaWrapper(BaseModel):
    """Wrapper for AWS Lambda SDK.

    To use, you should have the ``boto3`` package installed
    and a lambda functions built from the AWS Console or
    CLI. Set up your AWS credentials with ``aws configure``

    Example:
        .. code-block:: bash

            pip install boto3

            aws configure
    """

    lambda_client: Any = None  #: :meta private:
    """The configured boto3 client"""
    function_name: Optional[str] = None
    """The name of your lambda function"""
    awslambda_tool_name: Optional[str] = None
    """If passing to an agent as a tool, the tool name"""
    awslambda_tool_description: Optional[str] = None
    """If passing to an agent as a tool, the description"""

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that python package exists in environment.

        Stores a ``boto3`` "lambda" client in ``values["lambda_client"]``.

        Raises:
            ImportError: if ``boto3`` cannot be imported.
        """
        try:
            import boto3
        except ImportError:
            raise ImportError(
                "boto3 is not installed. Please install it with `pip install boto3`"
            )
        values["lambda_client"] = boto3.client("lambda")
        return values

    def run(self, query: str) -> str:
        """
        Invokes the lambda function and returns the
        result.

        Args:
            query: an input to passed to the lambda
                function as the ``body`` of a JSON
                object.

        Returns:
            ``"Result: <body>"`` on success, ``"Request failed."`` when the
            response body is empty, or a parse-failure message when the
            response payload cannot be decoded.
        """
        res = self.lambda_client.invoke(
            FunctionName=self.function_name,
            InvocationType="RequestResponse",
            Payload=json.dumps({"body": query}),
        )
        try:
            payload_stream = res["Payload"]
            payload_string = payload_stream.read().decode("utf-8")
            answer = json.loads(payload_string)["body"]
        except (StopIteration, ValueError, KeyError, TypeError):
            # ValueError covers invalid JSON and invalid UTF-8; KeyError and
            # TypeError cover a payload without a usable "body" entry.
            # StopIteration is kept from the previous handler.
            return "Failed to parse response from Lambda"
        if answer is None or answer == "":
            # We don't want to return the assumption alone if answer is empty
            return "Request failed."
        else:
            return f"Result: {answer}"

View File

@@ -0,0 +1,88 @@
"""Util that calls bibtexparser."""
import logging
from typing import Any, Dict, List, Mapping
from pydantic import BaseModel, ConfigDict, model_validator
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Extra bibtex entry fields that BibtexparserWrapper.get_metadata copies into
# the metadata only when called with load_extra=True.
OPTIONAL_FIELDS = [
"annotate",
"booktitle",
"editor",
"howpublished",
"journal",
"keywords",
"note",
"organization",
"publisher",
"school",
"series",
"type",
"doi",
"issn",
"isbn",
]
class BibtexparserWrapper(BaseModel):
    """Wrapper around bibtexparser.

    To use, you should have the ``bibtexparser`` python package installed.
    https://bibtexparser.readthedocs.io/en/master/

    This wrapper will use bibtexparser to load a collection of references from
    a bibtex file and fetch document summaries.
    """

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Ensure ``bibtexparser`` is importable; return ``values`` untouched."""
        try:
            import bibtexparser  # noqa
        except ImportError:
            raise ImportError(
                "Could not import bibtexparser python package. "
                "Please install it with `pip install bibtexparser`."
            )
        return values

    def load_bibtex_entries(self, path: str) -> List[Dict[str, Any]]:
        """Parse the bibtex file at ``path`` and return its entries."""
        import bibtexparser

        with open(path) as handle:
            return bibtexparser.load(handle).entries

    def get_metadata(
        self, entry: Mapping[str, Any], load_extra: bool = False
    ) -> Dict[str, Any]:
        """Build a metadata dict for one bibtex entry.

        Args:
            entry: a parsed bibtex entry.
            load_extra: also copy every field named in ``OPTIONAL_FIELDS``.

        Returns:
            The metadata with all ``None`` values dropped.
        """
        # An explicit URL wins; otherwise derive one from the DOI if present.
        url = None
        if "url" in entry:
            url = entry["url"]
        elif "doi" in entry:
            url = f"https://doi.org/{entry['doi']}"

        collected = {
            "id": entry.get("ID"),
            "published_year": entry.get("year"),
            "title": entry.get("title"),
            "publication": entry.get("journal") or entry.get("booktitle"),
            "authors": entry.get("author"),
            "abstract": entry.get("abstract"),
            "url": url,
        }
        if load_extra:
            collected.update((name, entry.get(name)) for name in OPTIONAL_FIELDS)

        return {key: value for key, value in collected.items() if value is not None}

View File

@@ -0,0 +1,117 @@
"""Util that calls Bing Search."""
from typing import Any, Dict, List
import requests
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, Field, model_validator
# DEFAULT_BING_SEARCH_ENDPOINT is the default endpoint for the Bing Web Search API.
# Currently there are two web-based Bing Search services available on Azure,
# i.e. Bing Web Search[1] and Bing Custom Search[2]. Both services provide a
# wide range of search results, but Bing Custom Search additionally requires you
# to provide a custom search instance, `customConfig`.
# Both services are available for BingSearchAPIWrapper.
# History of Azure Bing Search API:
# Before being listed in Azure Marketplace as a separate service, Bing Search
# APIs were part of Azure Cognitive Services, where each resource had its own
# endpoint and the user had to specify it when making a request. After the
# transition to Azure Marketplace, the endpoint is standardized and the user
# does not need to specify it[3].
# Reference:
# 1. https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
# 2. https://learn.microsoft.com/en-us/bing/search-apis/bing-custom-search/overview
# 3. https://azure.microsoft.com/en-in/updates/bing-search-apis-will-transition-from-azure-cognitive-services-to-azure-marketplace-on-31-october-2023/
DEFAULT_BING_SEARCH_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
class BingSearchAPIWrapper(BaseModel):
    """Wrapper for Bing Web Search API."""

    bing_subscription_key: str
    bing_search_url: str
    k: int = 10
    search_kwargs: dict = Field(default_factory=dict)
    """Additional keyword arguments to pass to the search request."""

    model_config = ConfigDict(
        extra="forbid",
    )

    def _bing_search_results(self, search_term: str, count: int) -> List[dict]:
        """Call the search endpoint and return the web-page results, if any."""
        query_params = {
            "q": search_term,
            "count": count,
            "textDecorations": True,
            "textFormat": "HTML",
            # User-supplied arguments come last so they can override the above.
            **self.search_kwargs,
        }
        response = requests.get(
            self.bing_search_url,
            headers={"Ocp-Apim-Subscription-Key": self.bing_subscription_key},
            params=query_params,
        )
        response.raise_for_status()
        payload = response.json()
        if "webPages" not in payload:
            return []
        return payload["webPages"]["value"]

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Fill in the API key and endpoint from ``values`` or the environment.

        The endpoint falls back to ``DEFAULT_BING_SEARCH_ENDPOINT`` when neither
        ``values`` nor ``BING_SEARCH_URL`` provides one.
        """
        values["bing_subscription_key"] = get_from_dict_or_env(
            values, "bing_subscription_key", "BING_SUBSCRIPTION_KEY"
        )
        values["bing_search_url"] = get_from_dict_or_env(
            values,
            "bing_search_url",
            "BING_SEARCH_URL",
            default=DEFAULT_BING_SEARCH_ENDPOINT,
        )
        return values

    def run(self, query: str) -> str:
        """Search for ``query`` and return the result snippets joined by spaces."""
        hits = self._bing_search_results(query, count=self.k)
        if len(hits) == 0:
            return "No good Bing Search Result was found"
        return " ".join([hit["snippet"] for hit in hits])

    def results(self, query: str, num_results: int) -> List[Dict]:
        """Run query through BingSearch and return metadata.

        Args:
            query: The query to search for.
            num_results: The number of results to return.

        Returns:
            A list of dictionaries with the following keys:
                snippet - The description of the result.
                title - The title of the result.
                link - The link to the result.
        """
        hits = self._bing_search_results(query, count=num_results)
        if len(hits) == 0:
            return [{"Result": "No good Bing Search Result was found"}]
        return [
            {
                "snippet": hit["snippet"],
                "title": hit["name"],
                "link": hit["url"],
            }
            for hit in hits
        ]

View File

@@ -0,0 +1,83 @@
import json
from typing import List
import requests
from langchain_core.documents import Document
from langchain_core.utils import secret_from_env
from pydantic import BaseModel, Field, SecretStr
class BraveSearchWrapper(BaseModel):
    """Wrapper around the Brave search engine."""

    api_key: SecretStr = Field(
        default_factory=secret_from_env(["BRAVE_SEARCH_API_KEY"])
    )
    """The API key to use for the Brave search engine."""
    search_kwargs: dict = Field(default_factory=dict)
    """Additional keyword arguments to pass to the search request."""
    base_url: str = "https://api.search.brave.com/res/v1/web/search"
    """The base URL for the Brave search engine."""

    def run(self, query: str) -> str:
        """Query the Brave search engine and return the results as a JSON string.

        Args:
            query: The query to search for.

        Returns: The results as a JSON string.
        """
        entries = []
        for item in self._search_request(query=query):
            title = item.get("title")
            link = item.get("url")
            # The description and any extra snippets form one text; empty
            # pieces are dropped.
            pieces = [item.get("description"), *item.get("extra_snippets", [])]
            entries.append(
                {
                    "title": title,
                    "link": link,
                    "snippet": " ".join(filter(None, pieces)),
                }
            )
        return json.dumps(entries)

    def download_documents(self, query: str) -> List[Document]:
        """Query the Brave search engine and return the results as a list of Documents.

        Args:
            query: The query to search for.

        Returns: The results as a list of Documents.
        """
        documents = []
        for item in self._search_request(query):
            pieces = [item.get("description"), *item.get("extra_snippets", [])]
            documents.append(
                Document(
                    page_content=" ".join(filter(None, pieces)),
                    metadata={"title": item.get("title"), "link": item.get("url")},
                )
            )
        return documents

    def _search_request(self, query: str) -> List[dict]:
        """Send the search request and return the ``web.results`` list.

        Raises:
            ValueError: if URL preparation yields no URL.
            Exception: if the HTTP response status is not OK.
        """
        headers = {
            "X-Subscription-Token": self.api_key.get_secret_value(),
            "Accept": "application/json",
        }
        # The fixed parameters are applied last so they override search_kwargs.
        params = dict(self.search_kwargs)
        params.update({"q": query, "extra_snippets": True})
        prepared = requests.PreparedRequest()
        prepared.prepare_url(self.base_url, params)
        if prepared.url is None:
            raise ValueError("prepared url is None, this should not happen")

        response = requests.get(prepared.url, headers=headers)
        if response.ok:
            return response.json().get("web", {}).get("results", [])
        raise Exception(f"HTTP error {response.status_code}")

View File

@@ -0,0 +1,55 @@
from __future__ import annotations
import asyncio
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
from cassandra.cluster import ResponseFuture, Session
async def wrapped_response_future(
    func: Callable[..., ResponseFuture], *args: Any, **kwargs: Any
) -> Any:
    """Wrap a Cassandra response future in an asyncio future.

    Args:
        func: The Cassandra function to call.
        *args: The arguments to pass to the Cassandra function.
        **kwargs: The keyword arguments to pass to the Cassandra function.

    Returns:
        The result of the Cassandra function.

    Raises:
        BaseException: whatever exception the response future reports through
            its error callback.
    """
    # Inside a coroutine the running loop is the one to hand results back to.
    loop = asyncio.get_running_loop()
    asyncio_future = loop.create_future()
    response_future = func(*args, **kwargs)

    def _settle(setter: Callable[[Any], None], value: Any) -> None:
        # Executed on the event loop. If the awaiting task was cancelled in the
        # meantime the future is already done and setting it would raise.
        if not asyncio_future.done():
            setter(value)

    def success_handler(_: Any) -> None:
        # The callback may run on another thread, so hop back onto the loop.
        loop.call_soon_threadsafe(
            _settle, asyncio_future.set_result, response_future.result()
        )

    def error_handler(exc: BaseException) -> None:
        loop.call_soon_threadsafe(_settle, asyncio_future.set_exception, exc)

    response_future.add_callbacks(success_handler, error_handler)
    return await asyncio_future
async def aexecute_cql(session: Session, query: str, **kwargs: Any) -> Any:
    """Run a CQL query without blocking the event loop.

    Args:
        session: The Cassandra session to use.
        query: The CQL query to execute.
        kwargs: Additional keyword arguments forwarded to
            ``session.execute_async``.

    Returns:
        The result of the query.
    """
    outcome = await wrapped_response_future(session.execute_async, query, **kwargs)
    return outcome
class SetupMode(Enum):
    """Setup mode enumerator with three options: SYNC, ASYNC and OFF."""

    # NOTE(review): members are not used in the visible part of this module;
    # presumably they select how dependent components initialise - confirm.
    SYNC = 1
    ASYNC = 2
    OFF = 3

View File

@@ -0,0 +1,662 @@
"""Apache Cassandra database wrapper."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self
if TYPE_CHECKING:
from cassandra.cluster import ResultSet, Session
# Keyspaces that CassandraDatabase leaves out when listing keyspaces without an
# explicit include list (stored as its `_exclude_keyspaces`).
IGNORED_KEYSPACES = [
"system",
"system_auth",
"system_distributed",
"system_schema",
"system_traces",
"system_views",
"datastax_sla",
"data_endpoint_auth",
]
class CassandraDatabase:
"""Apache Cassandra® database wrapper."""
def __init__(
self,
session: Optional[Session] = None,
exclude_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
cassio_init_kwargs: Optional[Dict[str, Any]] = None,
):
_session = self._resolve_session(session, cassio_init_kwargs)
if not _session:
raise ValueError("Session not provided and cannot be resolved")
self._session = _session
self._exclude_keyspaces = IGNORED_KEYSPACES
self._exclude_tables = exclude_tables or []
self._include_tables = include_tables or []
def run(
    self,
    query: str,
    fetch: str = "all",
    **kwargs: Any,
) -> Union[list, Dict[str, Any], ResultSet]:
    """Execute a CQL query and return the results.

    Args:
        query: the CQL query to execute.
        fetch: "all" for a list of rows, "one" for a single row as a dict,
            "cursor" for the raw result set.
        **kwargs: forwarded to the selected fetch method.

    Raises:
        ValueError: if ``fetch`` is none of the three supported values.
    """
    if fetch == "cursor":
        return self._fetch(query, **kwargs)
    if fetch == "one":
        return self.fetch_one(query, **kwargs)
    if fetch == "all":
        return self.fetch_all(query, **kwargs)
    raise ValueError("Fetch parameter must be either 'one', 'all', or 'cursor'")
def _fetch(self, query: str, **kwargs: Any) -> ResultSet:
clean_query = self._validate_cql(query, "SELECT")
return self._session.execute(clean_query, **kwargs)
def fetch_all(self, query: str, **kwargs: Any) -> list:
    """Execute ``query`` and return every row of the result as a list."""
    rows = self._fetch(query, **kwargs)
    return list(rows)
def fetch_one(self, query: str, **kwargs: Any) -> Dict[str, Any]:
    """Execute ``query`` and return its first row as a dict.

    An empty dict is returned when the result is falsy.
    """
    result = self._fetch(query, **kwargs)
    if not result:
        return {}
    return result.one()._asdict()
def get_keyspace_tables(self, keyspace: str) -> List[Table]:
    """Return the Table objects resolved for ``keyspace``.

    An empty list is returned when the resolved schema has no entry for it.
    """
    schema = self._resolve_schema([keyspace])
    return schema[keyspace] if keyspace in schema else []
# The query below is assembled by plain string formatting, without a query
# builder or prepared statements.
# TODO: Refactor to use prepared statements
def get_table_data(
    self, keyspace: str, table: str, predicate: str, limit: int
) -> str:
    """Select rows from ``keyspace.table`` and return them as text.

    Args:
        keyspace: keyspace that holds the table.
        table: table to read from.
        predicate: optional WHERE clause body; skipped when empty.
        limit: optional LIMIT value; skipped when zero.

    Returns:
        One ``str(row)`` per line.
    """
    pieces = [f"SELECT * FROM {keyspace}.{table}"]
    if predicate:
        pieces.append(f" WHERE {predicate}")
    if limit:
        pieces.append(f" LIMIT {limit}")
    pieces.append(";")

    rows = self.fetch_all("".join(pieces))
    return "\n".join(map(str, rows))
def get_context(self) -> Dict[str, Any]:
    """Return db context for an agent prompt: the comma-separated keyspace names."""
    return {"keyspaces": ", ".join(self._fetch_keyspaces())}
def format_keyspace_to_markdown(
    self, keyspace: str, tables: Optional[List[Table]] = None
) -> str:
    """
    Generates a markdown representation of the schema for a specific keyspace
    by iterating over all tables within that keyspace and calling their
    as_markdown method.

    Args:
        keyspace: The name of the keyspace to generate markdown documentation for.
        tables: list of tables in the keyspace; it will be resolved if not provided.

    Returns:
        A string containing the markdown representation of the specified
        keyspace schema, or an empty string when the keyspace has no tables.
    """
    if not tables:
        tables = self.get_keyspace_tables(keyspace)
    if not tables:
        # Nothing to document. The old "No tables present" text sat behind a
        # second, identical check and could never be emitted.
        return ""

    sections = [f"## Keyspace: {keyspace}\n\n"]
    for table in tables:
        sections.append(table.as_markdown(include_keyspace=False, header_level=3))
        sections.append("\n\n")
    return "".join(sections)
def format_schema_to_markdown(self) -> str:
    """
    Render the resolved schema of this CassandraDatabase instance as markdown.

    The document starts with a top-level heading, followed by one section per
    resolved keyspace. Each section is produced by format_keyspace_to_markdown
    from that keyspace's tables and is followed by a blank line.

    Returns:
        A markdown string documenting all resolved keyspaces and their tables.
    """
    resolved = self._resolve_schema()
    chunks = ["# Cassandra Database Schema\n\n"]
    chunks.extend(
        f"{self.format_keyspace_to_markdown(name, members)}\n\n"
        for name, members in resolved.items()
    )
    return "".join(chunks)
def _validate_cql(self, cql: str, type: str = "SELECT") -> str:
"""
Validates a CQL query string for basic formatting and safety checks.
Ensures that `cql` starts with the specified type (e.g., SELECT) and does
not contain content that could indicate CQL injection vulnerabilities.
Args:
cql: The CQL query string to be validated.
type: The expected starting keyword of the query, used to verify
that the query begins with the correct operation type
(e.g., "SELECT", "UPDATE"). Defaults to "SELECT".
Returns:
The trimmed and validated CQL query string without a trailing semicolon.
Raises:
ValueError: If the value of `type` is not supported
DatabaseError: If `cql` is considered unsafe
"""
SUPPORTED_TYPES = ["SELECT"]
if type and type.upper() not in SUPPORTED_TYPES:
raise ValueError(
f"""Unsupported CQL type: {type}. Supported types:
{SUPPORTED_TYPES}"""
)
# Basic sanity checks
cql_trimmed = cql.strip()
if not cql_trimmed.upper().startswith(type.upper()):
raise DatabaseError(f"CQL must start with {type.upper()}.")
# Allow a trailing semicolon, but remove (it is optional with the Python driver)
cql_trimmed = cql_trimmed.rstrip(";")
# Consider content within matching quotes to be "safe"
# Remove single-quoted strings
cql_sanitized = re.sub(r"'.*?'", "", cql_trimmed)
# Remove double-quoted strings
cql_sanitized = re.sub(r'".*?"', "", cql_sanitized)
# Find unsafe content in the remaining CQL
if ";" in cql_sanitized:
raise DatabaseError(
"""Potentially unsafe CQL, as it contains a ; at a
place other than the end or within quotation marks."""
)
# The trimmed query, before modifications
return cql_trimmed
def _fetch_keyspaces(self, keyspaces: Optional[List[str]] = None) -> List[str]:
"""
Fetches a list of keyspace names from the Cassandra database. The list can be
filtered by a provided list of keyspace names or by excluding predefined
keyspaces.
Args:
keyspaces: A list of keyspace names to specifically include.
If provided and not empty, the method returns only the keyspaces
present in this list.
If not provided or empty, the method returns all keyspaces except those
specified in the _exclude_keyspaces attribute.
Returns:
A list of keyspace names according to the filtering criteria.
"""
all_keyspaces = self.fetch_all(
"SELECT keyspace_name FROM system_schema.keyspaces"
)
# Filtering keyspaces based on 'keyspace_list' and '_exclude_keyspaces'
filtered_keyspaces = []
for ks in all_keyspaces:
if not isinstance(ks, Dict):
continue # Skip if the row is not a dictionary.
keyspace_name = ks["keyspace_name"]
if keyspaces and keyspace_name in keyspaces:
filtered_keyspaces.append(keyspace_name)
elif not keyspaces and keyspace_name not in self._exclude_keyspaces:
filtered_keyspaces.append(keyspace_name)
return filtered_keyspaces
def _format_keyspace_query(self, query: str, keyspaces: List[str]) -> str:
    """Append a ``WHERE keyspace_name IN (...)`` filter to ``query``.

    Args:
        query: The CQL statement to extend; it must not already contain a
            ``WHERE`` clause.
        keyspaces: Keyspace names to list in the ``IN`` clause.

    Returns:
        ``query`` followed by the keyspace filter.
    """
    # CQL escapes a single quote inside a string literal by doubling it, so a
    # keyspace name can never terminate its literal early.
    keyspace_in_clause = ", ".join(
        "'{}'".format(ks.replace("'", "''")) for ks in keyspaces
    )
    return f"{query} WHERE keyspace_name IN ({keyspace_in_clause})"
def _fetch_tables_data(self, keyspaces: List[str]) -> list:
    """Fetch table metadata for the given keyspaces in a single query.

    Args:
        keyspaces: Keyspace names whose tables should be looked up.

    Returns:
        Whatever ``self.fetch_all`` yields for the query: one row per table
        carrying the keyspace name, table name and comment.
    """
    base_query = "SELECT keyspace_name, table_name, comment FROM system_schema.tables"
    filtered_query = self._format_keyspace_query(base_query, keyspaces)
    return self.fetch_all(filtered_query)
def _fetch_columns_data(self, keyspaces: List[str]) -> list:
    """Fetch column metadata for the given keyspaces in a single query.

    Args:
        keyspaces: Keyspace names whose columns should be looked up.

    Returns:
        Whatever ``self.fetch_all`` yields for the query: one row per column
        carrying keyspace name, table name, column name, type, kind,
        clustering order and position.
    """
    columns_query = (
        "\n"
        "SELECT keyspace_name, table_name, column_name, type, kind,\n"
        "clustering_order, position\n"
        "FROM system_schema.columns\n"
    )
    filtered_query = self._format_keyspace_query(columns_query, keyspaces)
    return self.fetch_all(filtered_query)
def _fetch_indexes_data(self, keyspaces: List[str]) -> list:
    """Fetch index metadata for the given keyspaces in a single query.

    Args:
        keyspaces: Keyspace names whose indexes should be looked up.

    Returns:
        Whatever ``self.fetch_all`` yields for the query: one row per index
        carrying keyspace name, table name, index name, kind and options.
    """
    indexes_query = (
        "\n"
        "SELECT keyspace_name, table_name, index_name,\n"
        "kind, options\n"
        "FROM system_schema.indexes\n"
    )
    filtered_query = self._format_keyspace_query(indexes_query, keyspaces)
    return self.fetch_all(filtered_query)
def _resolve_schema(
    self, keyspaces: Optional[List[str]] = None
) -> Dict[str, List[Table]]:
    """Build ``Table`` objects for every table in the selected keyspaces.

    Table, column and index metadata are each fetched with one query and then
    joined in memory. Tables are filtered through ``self._include_tables`` and
    ``self._exclude_tables`` when those are set.

    Args:
        keyspaces: Keyspace names to inspect. When omitted or empty, the
            result of ``self._fetch_keyspaces()`` is used.

    Returns:
        A dictionary mapping each keyspace name to the list of ``Table``
        objects found in it.
    """
    if not keyspaces:
        keyspaces = self._fetch_keyspaces()

    tables_data = self._fetch_tables_data(keyspaces)
    columns_data = self._fetch_columns_data(keyspaces)
    indexes_data = self._fetch_indexes_data(keyspaces)

    # Group the column and index rows by (keyspace, table) once, rather than
    # rescanning the complete row lists for every table.
    columns_by_table: dict = {}
    for column_row in columns_data:
        table_key = (column_row.keyspace_name, column_row.table_name)
        columns_by_table.setdefault(table_key, []).append(column_row)

    indexes_by_table: dict = {}
    for index_row in indexes_data:
        table_key = (index_row.keyspace_name, index_row.table_name)
        indexes_by_table.setdefault(table_key, []).append(index_row)

    keyspace_dict: dict = {}
    for table_data in tables_data:
        keyspace = table_data.keyspace_name
        table_name = table_data.table_name

        if self._include_tables and table_name not in self._include_tables:
            continue
        if self._exclude_tables and table_name in self._exclude_tables:
            continue

        column_rows = columns_by_table.get((keyspace, table_name), [])
        table_columns = [(c.column_name, c.type) for c in column_rows]

        # Key columns are ordered by their position within the key, matching
        # Table._resolve_columns; row order alone does not carry that.
        partition_keys = [
            c.column_name
            for c in sorted(
                (c for c in column_rows if c.kind == "partition_key"),
                key=lambda c: c.position,
            )
        ]
        clustering_keys = [
            (c.column_name, c.clustering_order)
            for c in sorted(
                (c for c in column_rows if c.kind == "clustering"),
                key=lambda c: c.position,
            )
        ]

        # Table.indexes holds string triples, so stringify non-string options
        # the same way Table._resolve_indexes does.
        table_indexes = [
            (
                i.index_name,
                i.kind,
                i.options if isinstance(i.options, str) else str(i.options),
            )
            for i in indexes_by_table.get((keyspace, table_name), [])
        ]

        table_obj = Table(
            keyspace=keyspace,
            table_name=table_name,
            comment=table_data.comment,
            columns=table_columns,
            partition=partition_keys,
            clustering=clustering_keys,
            indexes=table_indexes,
        )
        keyspace_dict.setdefault(keyspace, []).append(table_obj)

    return keyspace_dict
@staticmethod
def _resolve_session(
    session: Optional[Session] = None,
    cassio_init_kwargs: Optional[Dict[str, Any]] = None,
) -> Optional[Session]:
    """Pick the session to use for database operations.

    Candidates are tried in this order:

    1. the ``session`` argument, when truthy;
    2. the session already registered with ``cassio``;
    3. a new ``cassio`` session initialised from ``cassio_init_kwargs``;
    4. ``None`` when none of the above is available.

    Args:
        session: An optional session to use directly.
        cassio_init_kwargs: Optional keyword arguments forwarded to
            ``cassio.init``.

    Returns:
        The resolved session, or ``None`` if no session could be resolved.

    Raises:
        ValueError: If ``cassio`` is not installed, or if
            ``cassio_init_kwargs`` is truthy but not a dictionary.
    """
    if session:
        return session

    # cassio is imported lazily, only when no explicit session was supplied.
    try:
        import cassio.config
    except ImportError:
        raise ValueError(
            "cassio package not found, please install with `pip install cassio`"
        )

    existing_session = cassio.config.resolve_session()
    if existing_session:
        return existing_session

    # Nothing registered yet: initialise cassio ourselves if we were told how.
    if not cassio_init_kwargs:
        return None
    if not isinstance(cassio_init_kwargs, dict):
        raise ValueError("cassio_init_kwargs must be a keyword dictionary")

    cassio.init(**cassio_init_kwargs)
    return cassio.config.check_resolve_session()
class DatabaseError(Exception):
    """Error raised for problems with the database schema or a query.

    Attributes:
        message: Human-readable explanation of the error.
    """

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
class Table(BaseModel):
    """Schema description of a single Cassandra table.

    Instances are immutable (``frozen=True``) and must carry at least one
    column and one partition key.
    """

    keyspace: str
    """The keyspace in which the table exists."""

    table_name: str
    """The name of the table."""

    comment: Optional[str] = None
    """The comment associated with the table."""

    columns: List[Tuple[str, str]] = Field(default_factory=list)
    """``(column_name, type)`` pairs, one per column."""

    partition: List[str] = Field(default_factory=list)
    """Names of the partition key columns."""

    clustering: List[Tuple[str, str]] = Field(default_factory=list)
    """``(column_name, clustering_order)`` pairs for the clustering columns."""

    indexes: List[Tuple[str, str, str]] = Field(default_factory=list)
    """``(index_name, kind, options)`` triples, one per index."""

    model_config = ConfigDict(
        frozen=True,
    )

    @model_validator(mode="after")
    def check_required_fields(self) -> Self:
        """Reject tables that have no columns or no partition key.

        Raises:
            ValueError: If ``columns`` or ``partition`` is empty.
        """
        if not self.columns:
            raise ValueError("non-empty column list must be provided")
        if not self.partition:
            raise ValueError("non-empty partition list must be provided")
        return self

    @classmethod
    def from_database(
        cls, keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Table:
        """Build a ``Table`` by querying ``system_schema`` through ``db``.

        Args:
            keyspace: Keyspace that contains the table.
            table_name: Name of the table to describe.
            db: Database wrapper whose ``run`` method executes the queries.

        Returns:
            A ``Table`` populated with comment, columns, keys and indexes.
        """
        columns, partition, clustering = cls._resolve_columns(keyspace, table_name, db)
        return cls(
            keyspace=keyspace,
            table_name=table_name,
            comment=cls._resolve_comment(keyspace, table_name, db),
            columns=columns,
            partition=partition,
            clustering=clustering,
            indexes=cls._resolve_indexes(keyspace, table_name, db),
        )

    def as_markdown(
        self, include_keyspace: bool = True, header_level: Optional[int] = None
    ) -> str:
        """Render the table schema as Markdown.

        Args:
            include_keyspace: If True, include the keyspace in the output.
            header_level: Markdown header level for the table name line. If
                None, the table name is written without a header prefix.

        Returns:
            A Markdown string listing the table name, keyspace (optional),
            comment (when set), columns, partition keys, clustering keys and
            indexes (when present). Every line ends with a newline.
        """
        output = ""
        if header_level is not None:
            output += f"{'#' * header_level} "
        output += f"Table Name: {self.table_name}\n"

        if include_keyspace:
            output += f"- Keyspace: {self.keyspace}\n"
        if self.comment:
            output += f"- Comment: {self.comment}\n"

        output += "- Columns\n"
        for column, column_type in self.columns:
            output += f"  - {column} ({column_type})\n"

        output += f"- Partition Keys: ({', '.join(self.partition)})\n"
        output += "- Clustering Keys: "
        if self.clustering:
            cluster_list = []
            for column, clustering_order in self.clustering:
                # A clustering order of "none" is omitted from the listing.
                if clustering_order.lower() == "none":
                    cluster_list.append(column)
                else:
                    cluster_list.append(f"{column} {clustering_order}")
            output += f"({', '.join(cluster_list)})"
        # Terminate the line even when there are no clustering keys, so the
        # following section never ends up on the same line.
        output += "\n"

        if self.indexes:
            output += "- Indexes\n"
            for name, kind, options in self.indexes:
                output += f"  - {name} : kind={kind}, options={options}\n"

        return output

    @staticmethod
    def _resolve_comment(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Optional[str]:
        """Look up the table comment.

        Returns:
            The comment, or ``None`` when it is missing or empty.

        Raises:
            ValueError: If ``db.run`` does not return a dictionary.
        """
        result = db.run(
            "SELECT comment\n"
            "FROM system_schema.tables\n"
            f"WHERE keyspace_name = '{keyspace}'\n"
            f"AND table_name = '{table_name}';",
            fetch="one",
        )

        if not isinstance(result, dict):
            raise ValueError(
                "Unexpected result type from db.run:\n"
                f"{type(result).__name__}"
            )
        # An empty comment is reported as None.
        return result.get("comment") or None

    @staticmethod
    def _resolve_columns(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Tuple[List[Tuple[str, str]], List[str], List[Tuple[str, str]]]:
        """Look up the columns, partition keys and clustering keys.

        Returns:
            A tuple ``(columns, partition, clustering)`` where ``columns`` is
            a list of ``(name, type)`` pairs, ``partition`` is the list of
            partition key names ordered by position, and ``clustering`` is a
            list of ``(name, clustering_order)`` pairs ordered by position.

        Raises:
            TypeError: If ``db.run`` does not return a sequence.
        """
        columns = []
        partition_info = []
        cluster_info = []
        results = db.run(
            "SELECT column_name, type, kind, clustering_order, position\n"
            "FROM system_schema.columns\n"
            f"WHERE keyspace_name = '{keyspace}'\n"
            f"AND table_name = '{table_name}';"
        )
        if not isinstance(results, Sequence):
            raise TypeError("Expected a sequence of dictionaries from 'run' method.")

        for row in results:
            if not isinstance(row, Dict):
                continue  # Skip if the row is not a dictionary.

            columns.append((row["column_name"], row["type"]))
            if row["kind"] == "partition_key":
                partition_info.append((row["column_name"], row["position"]))
            elif row["kind"] == "clustering":
                cluster_info.append(
                    (
                        row["column_name"],
                        row["clustering_order"],
                        row["position"],
                    )
                )

        # Key order is defined by ``position``, not by row order.
        partition = [
            column_name for column_name, _ in sorted(partition_info, key=lambda x: x[1])
        ]
        cluster = [
            (column_name, clustering_order)
            for column_name, clustering_order, _ in sorted(
                cluster_info, key=lambda x: x[2]
            )
        ]
        return columns, partition, cluster

    @staticmethod
    def _resolve_indexes(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> List[Tuple[str, str, str]]:
        """Look up the indexes defined on the table.

        Returns:
            A list of ``(index_name, kind, options)`` triples; ``options`` is
            converted with ``str`` when it is not already a string.

        Raises:
            TypeError: If ``db.run`` does not return a sequence.
        """
        indexes = []
        results = db.run(
            "SELECT index_name, kind, options\n"
            "FROM system_schema.indexes\n"
            f"WHERE keyspace_name = '{keyspace}'\n"
            f"AND table_name = '{table_name}';"
        )
        if not isinstance(results, Sequence):
            raise TypeError("Expected a sequence of dictionaries from 'run' method.")

        for row in results:
            if not isinstance(row, Dict):
                continue  # Skip if the row is not a dictionary.

            index_options = row["options"]
            if not isinstance(index_options, str):
                # The ``indexes`` field stores options as a string.
                index_options = str(index_options)
            indexes.append((row["index_name"], row["kind"], index_options))

        return indexes

View File

@@ -0,0 +1,626 @@
"""Util that calls clickup."""
import json
import warnings
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union
import requests
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, model_validator
DEFAULT_URL = "https://api.clickup.com/api/v2"
@dataclass
class Component:
    """Common parent of the ClickUp component dataclasses."""

    @classmethod
    def from_data(cls, data: Dict[str, Any]) -> "Component":
        """Build a component from raw data; subclasses must override this."""
        raise NotImplementedError
@dataclass
class Task(Component):
    """Flattened view of a ClickUp task."""

    id: int
    name: str
    text_content: str
    description: str
    status: str
    creator_id: int
    creator_username: str
    creator_email: str
    assignees: List[Dict[str, Any]]
    watchers: List[Dict[str, Any]]
    priority: Optional[str]
    due_date: Optional[str]
    start_date: Optional[str]
    points: int
    team_id: int
    project_id: int

    @classmethod
    def from_data(cls, data: Dict[str, Any]) -> "Task":
        """Build a ``Task`` from a raw task dictionary.

        Nested ``status``, ``creator``, ``priority`` and ``project`` entries
        are flattened into scalar fields; a ``None`` priority stays ``None``.
        """
        raw_priority = data["priority"]
        priority = raw_priority["priority"] if raw_priority is not None else None

        # Keys are read in field order, as the keyword arguments are built.
        task_fields = {
            "id": data["id"],
            "name": data["name"],
            "text_content": data["text_content"],
            "description": data["description"],
            "status": data["status"]["status"],
            "creator_id": data["creator"]["id"],
            "creator_username": data["creator"]["username"],
            "creator_email": data["creator"]["email"],
            "assignees": data["assignees"],
            "watchers": data["watchers"],
            "priority": priority,
            "due_date": data["due_date"],
            "start_date": data["start_date"],
            "points": data["points"],
            "team_id": data["team_id"],
            "project_id": data["project"]["id"],
        }
        return cls(**task_fields)
@dataclass
class CUList(Component):
    """Component describing a ClickUp list."""

    folder_id: float
    name: str
    content: Optional[str] = None
    due_date: Optional[int] = None
    due_date_time: Optional[bool] = None
    priority: Optional[int] = None
    assignee: Optional[int] = None
    status: Optional[str] = None

    @classmethod
    def from_data(cls, data: dict) -> "CUList":
        """Build a ``CUList`` from a raw dictionary.

        ``folder_id`` and ``name`` are required keys; every other field falls
        back to ``None`` when its key is absent.
        """
        optional_keys = (
            "content",
            "due_date",
            "due_date_time",
            "priority",
            "assignee",
            "status",
        )
        return cls(
            folder_id=data["folder_id"],
            name=data["name"],
            **{key: data.get(key) for key in optional_keys},
        )
@dataclass
class Member(Component):
    """Component describing a team member."""

    id: int
    username: str
    email: str
    initials: str

    @classmethod
    def from_data(cls, data: Dict) -> "Member":
        """Build a ``Member`` from the ``user`` entry of a raw dictionary."""
        user = data["user"]
        return cls(
            id=user["id"],
            username=user["username"],
            email=user["email"],
            initials=user["initials"],
        )
@dataclass
class Team(Component):
    """Component describing a team and its members."""

    id: int
    name: str
    members: List[Member]

    @classmethod
    def from_data(cls, data: Dict) -> "Team":
        """Build a ``Team``, converting each raw member into a ``Member``."""
        parsed_members = []
        for member_data in data["members"]:
            parsed_members.append(Member.from_data(member_data))
        return cls(id=data["id"], name=data["name"], members=parsed_members)
@dataclass
class Space(Component):
    """Component describing a space."""

    id: int
    name: str
    private: bool
    enabled_features: Dict[str, Any]

    @classmethod
    def from_data(cls, data: Dict[str, Any]) -> "Space":
        """Build a ``Space`` from the first entry of ``data["spaces"]``.

        Only features whose ``enabled`` flag is truthy are kept.
        """
        first_space = data["spaces"][0]

        enabled_features = {}
        for feature, value in first_space["features"].items():
            if value["enabled"]:
                enabled_features[feature] = value

        return cls(
            id=first_space["id"],
            name=first_space["name"],
            private=first_space["private"],
            enabled_features=enabled_features,
        )
def parse_dict_through_component(
    data: dict, component: Type[Component], fault_tolerant: bool = False
) -> Dict:
    """Round-trip a dictionary through a component dataclass.

    The component is built with ``component.from_data(data)`` and turned back
    into a dictionary, which extracts and formats the data according to the
    component's schema.

    Args:
        data: The raw dictionary to parse.
        component: The component class used for parsing.
        fault_tolerant: If True, a parsing failure emits a warning and the
            input ``data`` is returned unchanged; otherwise the exception
            propagates.

    Returns:
        The parsed dictionary, or ``data`` itself after a tolerated failure.
    """
    try:
        parsed_component = component.from_data(data)
        return asdict(parsed_component)
    except Exception as e:
        if not fault_tolerant:
            raise e
        warnings.warn(
            f"Error encountered while trying to parse\n{str(data)}: {str(e)}\n"
            " Falling back to returning input data."
        )
        return data
def extract_dict_elements_from_component_fields(
    data: dict, component: Type[Component]
) -> dict:
    """Keep only the entries of ``data`` named after a field of ``component``.

    Args:
        data: The dictionary to filter.
        component: The component dataclass whose field names act as the
            allow-list.

    Returns:
        A new ``dict`` holding the entries of ``data`` whose keys match a
        field of ``component``, in field order.
    """
    return {
        attribute.name: data[attribute.name]
        for attribute in fields(component)
        if attribute.name in data
    }
def load_query(
    query: str, fault_tolerant: bool = False
) -> Tuple[Optional[Dict], Optional[str]]:
    """Parse a JSON string.

    Args:
        query: The JSON string to parse.
        fault_tolerant: If True, a malformed input is reported through the
            returned error message instead of raising.

    Returns:
        A tuple ``(parsed, error)``: the parsed object and ``None`` on
        success, or ``None`` and an error message when parsing fails in
        fault-tolerant mode.

    Raises:
        json.JSONDecodeError: If the input is not valid JSON and
            ``fault_tolerant`` is False.
    """
    try:
        return json.loads(query), None
    except json.JSONDecodeError as e:
        if not fault_tolerant:
            raise
        # One clean sentence pair: the earlier message carried a stray quote
        # and a line break in the middle.
        return (
            None,
            f"Input must be a valid JSON. Got the following error: {e}. "
            "Please reformat and try again.",
        )
def fetch_first_id(data: dict, key: str) -> Optional[int]:
    """Return the ``id`` of the first entry under ``data[key]``.

    A warning is emitted when more than one entry is present. ``None`` is
    returned when the key is missing or its list is empty.
    """
    if key not in data:
        return None

    entries = data[key]
    entry_count = len(entries)
    if entry_count == 0:
        return None
    if entry_count > 1:
        warnings.warn(f"Found multiple {key}: {entries}. Defaulting to first.")
    return entries[0]["id"]
def fetch_data(url: str, access_token: str, query: Optional[dict] = None) -> dict:
    """GET ``url`` with the access token and return the decoded JSON body.

    Args:
        url: The URL to request.
        access_token: Value sent in the ``Authorization`` header.
        query: Optional query-string parameters.

    Raises:
        requests.HTTPError: Via ``raise_for_status`` on an error response.
    """
    response = requests.get(
        url, headers={"Authorization": access_token}, params=query
    )
    response.raise_for_status()
    return response.json()
def fetch_team_id(access_token: str) -> Optional[int]:
    """Return the id of the first team returned for the token, if any."""
    teams_payload = fetch_data(f"{DEFAULT_URL}/team", access_token)
    return fetch_first_id(teams_payload, "teams")
def fetch_space_id(team_id: int, access_token: str) -> Optional[int]:
    """Return the id of the first non-archived space of the team, if any."""
    spaces_payload = fetch_data(
        f"{DEFAULT_URL}/team/{team_id}/space",
        access_token,
        query={"archived": "false"},
    )
    return fetch_first_id(spaces_payload, "spaces")
def fetch_folder_id(space_id: int, access_token: str) -> Optional[int]:
    """Return the id of the first non-archived folder of the space, if any."""
    folders_payload = fetch_data(
        f"{DEFAULT_URL}/space/{space_id}/folder",
        access_token,
        query={"archived": "false"},
    )
    return fetch_first_id(folders_payload, "folders")
def fetch_list_id(space_id: int, folder_id: int, access_token: str) -> Optional[int]:
    """Return a list id from the folder, or from the space when folderless.

    Args:
        space_id: Space queried when ``folder_id`` is falsy.
        folder_id: Folder queried when truthy.
        access_token: Token passed to ``fetch_data``.
    """
    parent = f"folder/{folder_id}" if folder_id else f"space/{space_id}"
    data = fetch_data(
        f"{DEFAULT_URL}/{parent}/list", access_token, query={"archived": "false"}
    )

    # A folder response carrying a top-level "id" is used directly; otherwise
    # the first entry of "lists" is taken.
    if folder_id and "id" in data:
        return data["id"]
    return fetch_first_id(data, "lists")
class ClickupAPIWrapper(BaseModel):
    """Wrapper for Clickup API.

    On construction the access token is read from the arguments or the
    ``CLICKUP_ACCESS_TOKEN`` environment variable, and default team, space,
    folder and list ids are fetched from the API.
    """

    access_token: Optional[str] = None
    team_id: Optional[str] = None
    space_id: Optional[str] = None
    folder_id: Optional[str] = None
    list_id: Optional[str] = None

    model_config = ConfigDict(
        extra="forbid",
    )

    @classmethod
    def get_access_code_url(
        cls, oauth_client_id: str, redirect_uri: str = "https://google.com"
    ) -> str:
        """Get the URL to get an access code."""
        url = f"https://app.clickup.com/api?client_id={oauth_client_id}"
        return f"{url}&redirect_uri={redirect_uri}"

    @classmethod
    def get_access_token(
        cls, oauth_client_id: str, oauth_client_secret: str, code: str
    ) -> Optional[str]:
        """Exchange an OAuth code for an access token.

        Returns:
            The access token, or ``None`` (after printing the error) when the
            response does not contain one.
        """
        url = f"{DEFAULT_URL}/oauth/token"
        params = {
            "client_id": oauth_client_id,
            "client_secret": oauth_client_secret,
            "code": code,
        }
        response = requests.post(url, params=params)
        data = response.json()

        if "access_token" not in data:
            print(f"Error: {data}")  # noqa: T201
            if "ECODE" in data and data["ECODE"] == "OAUTH_014":
                url = ClickupAPIWrapper.get_access_code_url(oauth_client_id)
                print(  # noqa: T201
                    "You already used this code once. Generate a new one.",
                    f"Our best guess for the url to get a new code is:\n{url}",
                )
            return None

        return data["access_token"]

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Resolve the access token and fetch the default ids from the API."""
        values["access_token"] = get_from_dict_or_env(
            values, "access_token", "CLICKUP_ACCESS_TOKEN"
        )
        values["team_id"] = fetch_team_id(values["access_token"])
        values["space_id"] = fetch_space_id(values["team_id"], values["access_token"])
        values["folder_id"] = fetch_folder_id(
            values["space_id"], values["access_token"]
        )
        values["list_id"] = fetch_list_id(
            values["space_id"], values["folder_id"], values["access_token"]
        )
        return values

    def attempt_parse_teams(self, input_dict: dict) -> Dict[str, List[dict]]:
        """Parse the teams in ``input_dict``, skipping any that fail to parse."""
        parsed_teams: Dict[str, List[dict]] = {"teams": []}
        for team in input_dict["teams"]:
            try:
                team = parse_dict_through_component(team, Team, fault_tolerant=False)
                parsed_teams["teams"].append(team)
            except Exception as e:
                warnings.warn(f"Error parsing a team {e}")
        return parsed_teams

    def get_headers(
        self,
    ) -> Mapping[str, Union[str, bytes]]:
        """Get the headers for the request.

        Raises:
            TypeError: If the access token is not a string.
        """
        if not isinstance(self.access_token, str):
            raise TypeError(f"Access Token: {self.access_token}, must be str.")

        headers = {
            "Authorization": str(self.access_token),
            "Content-Type": "application/json",
        }
        return headers

    def get_default_params(self) -> Dict:
        """Return the default query parameters (non-archived items only)."""
        return {"archived": "false"}

    def get_authorized_teams(self) -> Dict[Any, Any]:
        """Get all teams for the user."""
        url = f"{DEFAULT_URL}/team"
        response = requests.get(url, headers=self.get_headers())
        data = response.json()
        parsed_teams = self.attempt_parse_teams(data)
        return parsed_teams

    def get_folders(self) -> Dict:
        """
        Get all the folders for the team.
        """
        url = f"{DEFAULT_URL}/team/" + str(self.team_id) + "/space"
        params = self.get_default_params()
        response = requests.get(url, headers=self.get_headers(), params=params)
        return {"response": response}

    def get_task(self, query: str, fault_tolerant: bool = True) -> Dict:
        """
        Retrieve a specific task.

        Args:
            query: JSON string containing a ``task_id`` entry.
            fault_tolerant: Passed to the task parser; when True an
                unparseable response is returned as-is.
        """
        params, error = load_query(query, fault_tolerant=True)
        if params is None:
            return {"Error": error}

        url = f"{DEFAULT_URL}/task/{params['task_id']}"
        params = {
            "custom_task_ids": "true",
            "team_id": self.team_id,
            "include_subtasks": "true",
        }
        response = requests.get(url, headers=self.get_headers(), params=params)
        data = response.json()
        parsed_task = parse_dict_through_component(
            data, Task, fault_tolerant=fault_tolerant
        )
        return parsed_task

    def get_lists(self) -> Dict:
        """
        Get all available lists.
        """
        url = f"{DEFAULT_URL}/folder/{self.folder_id}/list"
        params = self.get_default_params()
        response = requests.get(url, headers=self.get_headers(), params=params)
        return {"response": response}

    def query_tasks(self, query: str) -> Dict:
        """
        Query the tasks of the list named by ``list_id`` in the query.
        """
        params, error = load_query(query, fault_tolerant=True)
        if params is None:
            return {"Error": error}

        url = f"{DEFAULT_URL}/list/{params['list_id']}/task"
        params = self.get_default_params()
        response = requests.get(url, headers=self.get_headers(), params=params)
        return {"response": response}

    def get_spaces(self) -> Dict:
        """
        Get all spaces for the team.
        """
        url = f"{DEFAULT_URL}/team/{self.team_id}/space"
        response = requests.get(
            url, headers=self.get_headers(), params=self.get_default_params()
        )
        data = response.json()
        parsed_spaces = parse_dict_through_component(data, Space, fault_tolerant=True)
        return parsed_spaces

    def get_task_attribute(self, query: str) -> Dict:
        """
        Retrieve a single attribute of a specified task.

        Args:
            query: JSON string containing ``task_id`` and ``attribute_name``.

        Returns:
            ``{attribute_name: value}``, or an ``{"Error": ...}`` dictionary
            when the query is invalid or the attribute is not found.
        """
        task = self.get_task(query, fault_tolerant=True)
        params, error = load_query(query, fault_tolerant=True)
        if not isinstance(params, dict):
            return {"Error": error}

        attribute_name = params["attribute_name"]
        if attribute_name not in task:
            return {
                "Error": f"attribute_name = {attribute_name} was not\n"
                f"found in task keys {task.keys()}. "
                "Please call again with one of the key names."
            }
        return {attribute_name: task[attribute_name]}

    def update_task(self, query: str) -> Dict:
        """
        Update an attribute of a specified task.

        Args:
            query: JSON string containing ``task_id``, ``attribute_name`` and
                ``value``.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        url = f"{DEFAULT_URL}/task/{query_dict['task_id']}"
        params = {
            "custom_task_ids": "true",
            "team_id": self.team_id,
            "include_subtasks": "true",
        }
        headers = self.get_headers()
        payload = {query_dict["attribute_name"]: query_dict["value"]}
        response = requests.put(url, headers=headers, params=params, json=payload)
        return {"response": response}

    def update_task_assignees(self, query: str) -> Dict:
        """
        Add or remove assignees of a specified task.

        Args:
            query: JSON string containing ``task_id``, ``operation`` (``add``
                or ``rem``) and ``users`` (a list of integer user ids).

        Raises:
            ValueError: If ``operation`` is neither ``add`` nor ``rem``.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        for user in query_dict["users"]:
            if not isinstance(user, int):
                return {
                    "Error": "All users must be integers, not strings! "
                    f"Got user {user} of type {type(user)}"
                }

        url = f"{DEFAULT_URL}/task/{query_dict['task_id']}"
        headers = self.get_headers()

        if query_dict["operation"] == "add":
            assigne_payload = {"add": query_dict["users"], "rem": []}
        elif query_dict["operation"] == "rem":
            assigne_payload = {"add": [], "rem": query_dict["users"]}
        else:
            # A single message string: passing two arguments would make the
            # exception render as a tuple.
            raise ValueError(
                f"Invalid operation ({query_dict['operation']}). "
                "Valid options ['add', 'rem']."
            )

        params = {
            "custom_task_ids": "true",
            "team_id": self.team_id,
            "include_subtasks": "true",
        }
        payload = {"assignees": assigne_payload}
        response = requests.put(url, headers=headers, params=params, json=payload)
        return {"response": response}

    def create_task(self, query: str) -> Dict:
        """
        Creates a new task in the current list.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        list_id = self.list_id
        url = f"{DEFAULT_URL}/list/{list_id}/task"
        params = {"custom_task_ids": "true", "team_id": self.team_id}
        payload = extract_dict_elements_from_component_fields(query_dict, Task)
        headers = self.get_headers()
        response = requests.post(url, json=payload, headers=headers, params=params)
        data: Dict = response.json()
        return parse_dict_through_component(data, Task, fault_tolerant=True)

    def create_list(self, query: str) -> Dict:
        """
        Creates a new list and makes it the current list.

        The list is created in the current folder when one is set, otherwise
        directly in the current space.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        # Folderless lists live under the space endpoint, not the folder one
        # (same split as fetch_list_id).
        if self.folder_id:
            url = f"{DEFAULT_URL}/folder/{self.folder_id}/list"
        else:
            url = f"{DEFAULT_URL}/space/{self.space_id}/list"

        # Filter the payload by the list schema rather than the task schema.
        payload = extract_dict_elements_from_component_fields(query_dict, CUList)
        headers = self.get_headers()
        response = requests.post(url, json=payload, headers=headers)
        data = response.json()
        parsed_list = parse_dict_through_component(data, CUList, fault_tolerant=True)

        # Read the id from the raw response: CUList has no ``id`` field, so a
        # successfully parsed list would never carry it.
        if isinstance(data, dict) and "id" in data:
            self.list_id = data["id"]
        return parsed_list

    def create_folder(self, query: str) -> Dict:
        """
        Creates a new folder in the current space and makes it the current
        folder.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        space_id = self.space_id
        url = f"{DEFAULT_URL}/space/{space_id}/folder"
        payload = {
            "name": query_dict["name"],
        }
        headers = self.get_headers()
        response = requests.post(url, json=payload, headers=headers)
        data = response.json()

        # The returned id identifies a folder, so it belongs in folder_id;
        # storing it in list_id would point task creation at a folder.
        if "id" in data:
            self.folder_id = data["id"]
        return data

    def run(self, mode: str, query: str) -> str:
        """Run the operation named by ``mode`` and return its output as text.

        The output is JSON when it can be serialised, otherwise ``str(output)``.
        """
        if mode == "get_task":
            output = self.get_task(query)
        elif mode == "get_task_attribute":
            output = self.get_task_attribute(query)
        elif mode == "get_teams":
            output = self.get_authorized_teams()
        elif mode == "create_task":
            output = self.create_task(query)
        elif mode == "create_list":
            output = self.create_list(query)
        elif mode == "create_folder":
            output = self.create_folder(query)
        elif mode == "get_lists":
            output = self.get_lists()
        elif mode == "get_folders":
            output = self.get_folders()
        elif mode == "get_spaces":
            output = self.get_spaces()
        elif mode == "update_task":
            output = self.update_task(query)
        elif mode == "update_task_assignees":
            output = self.update_task_assignees(query)
        else:
            output = {"ModeError": f"Got unexpected mode {mode}."}

        try:
            return json.dumps(output)
        except Exception:
            # Some outputs hold objects json cannot serialise.
            return str(output)

View File

@@ -0,0 +1,160 @@
"""Utility that calls OpenAI's Dall-E Image Generator."""
import logging
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from langchain_core.utils import (
from_env,
get_pydantic_field_names,
secret_from_env,
)
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_community.utils.openai import is_openai_v1
logger = logging.getLogger(__name__)
class DallEAPIWrapper(BaseModel):
    """Wrapper around OpenAI's DALL-E image generation API.

    Reference:
    https://platform.openai.com/docs/guides/images/generations?context=node

    Setup:
        1. `pip install openai`
        2. Store your OPENAI_API_KEY in an environment variable.
    """

    client: Any = None  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model_name: str = Field(default="dall-e-2", alias="model")
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    openai_api_key: Optional[SecretStr] = Field(
        alias="api_key",
        default_factory=secret_from_env(
            "OPENAI_API_KEY",
            default=None,
        ),
    )
    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
    openai_api_base: Optional[str] = Field(
        alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
    )
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    openai_organization: Optional[str] = Field(
        alias="organization",
        default_factory=from_env(
            ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
        ),
    )
    """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
    # to support explicit proxy for OpenAI
    openai_proxy: str = Field(default_factory=from_env("OPENAI_PROXY", default=""))
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    n: int = 1
    """Number of images to generate"""
    size: str = "1024x1024"
    """Size of image to generate"""
    separator: str = "\n"
    """Separator to use when multiple URLs are returned."""
    quality: Optional[str] = None
    """Quality of the image that will be generated"""
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client."""

    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Move unrecognised keyword arguments into ``model_kwargs``.

        Raises:
            ValueError: If a name is supplied both directly and inside
                ``model_kwargs``, or if ``model_kwargs`` contains a declared
                field name.
        """
        known_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})

        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name in known_field_names:
                continue
            logger.warning(
                f"WARNING! {field_name} is not default parameter.\n"
                f"{field_name} was transferred to model_kwargs.\n"
                f"Please confirm that {field_name} is what you intended."
            )
            extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = known_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Check that ``openai`` is importable and create any missing clients.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        try:
            import openai
        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )

        if is_openai_v1():
            api_key = (
                self.openai_api_key.get_secret_value() if self.openai_api_key else None
            )
            client_params = {
                "api_key": api_key,
                "organization": self.openai_organization,
                "base_url": self.openai_api_base,
                "timeout": self.request_timeout,
                "max_retries": self.max_retries,
                "default_headers": self.default_headers,
                "default_query": self.default_query,
                "http_client": self.http_client,
            }
            # Clients supplied by the caller are left untouched.
            if not self.client:
                self.client = openai.OpenAI(**client_params).images
            if not self.async_client:
                self.async_client = openai.AsyncOpenAI(**client_params).images
        elif not self.client:
            self.client = openai.Image
        return self

    def run(self, query: str) -> str:
        """Generate images for ``query`` and return their URLs.

        Returns:
            The image URLs joined with ``separator``, or
            ``"No image was generated"`` when the joined string is empty.
        """
        if not is_openai_v1():
            response = self.client.create(
                prompt=query, n=self.n, size=self.size, model=self.model_name
            )
            urls = [item["url"] for item in response["data"]]
        else:
            kwargs = {
                "prompt": query,
                "n": self.n,
                "size": self.size,
                "model": self.model_name,
            }
            # ``quality`` is only sent when explicitly configured.
            if self.quality is not None:
                kwargs["quality"] = self.quality
            response = self.client.generate(**kwargs)
            urls = [item.url for item in response.data]

        image_urls = self.separator.join(urls)
        return image_urls if image_urls else "No image was generated"

View File

@@ -0,0 +1,195 @@
import base64
from typing import Any, Dict, Optional
from urllib.parse import quote
import aiohttp
import requests
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, Field, model_validator
# Pydantic model holding DataForSEO credentials, request parameters and
# result-filtering options used by the request/parsing methods of this class.
class DataForSeoAPIWrapper(BaseModel):
"""Wrapper around the DataForSeo API."""
# Unknown constructor fields are rejected (extra="forbid"). Arbitrary types
# are allowed, presumably so the aiohttp session field below can be declared.
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
# Baseline request body. _prepare_request merges it under `params`, and its
# "se_name" / "se_type" entries select the endpoint path.
default_params: dict = Field(
default={
"location_name": "United States",
"language_code": "en",
"depth": 10,
"se_name": "google",
"se_type": "organic",
}
)
"""Default parameters to use for the DataForSEO SERP API."""
# Per-instance overrides; these win over default_params on key collisions.
params: dict = Field(default={})
"""Additional parameters to pass to the DataForSEO SERP API."""
# Both credentials fall back to DATAFORSEO_LOGIN / DATAFORSEO_PASSWORD in
# validate_environment and are sent as HTTP Basic auth.
api_login: Optional[str] = None
"""The API login to use for the DataForSEO SERP API."""
api_password: Optional[str] = None
"""The API password to use for the DataForSEO SERP API."""
# When set, _filter_results keeps only items whose "type" is in this list.
json_result_types: Optional[list] = None
"""The JSON result types."""
# When set, _cleanup_unnecessary_items drops non-dict keys not in this list.
json_result_fields: Optional[list] = None
"""The JSON result fields."""
# Upper bound consulted by _filter_results while collecting items.
top_count: Optional[int] = None
"""The number of top results to return."""
# Optional caller-owned session; _aresponse_json opens a temporary one when
# this is unset.
aiosession: Optional[aiohttp.ClientSession] = None
"""The aiohttp session to use for the DataForSEO SERP API."""
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
    """Resolve the DataForSEO login and password before model validation.

    Each credential is taken from ``values`` when present, otherwise from
    the ``DATAFORSEO_LOGIN`` / ``DATAFORSEO_PASSWORD`` environment variables.

    Args:
        values: Raw constructor arguments.

    Returns:
        ``values`` with ``api_login`` and ``api_password`` filled in.
    """
    credential_sources = (
        ("api_login", "DATAFORSEO_LOGIN"),
        ("api_password", "DATAFORSEO_PASSWORD"),
    )
    # Resolve both credentials first, then write them back together.
    resolved = {
        field: get_from_dict_or_env(values, field, env_name)
        for field, env_name in credential_sources
    }
    for field, secret in resolved.items():
        values[field] = secret
    return values
async def arun(self, url: str) -> str:
    """Asynchronously query the SERP API and return a short text answer.

    Args:
        url: Forwarded unchanged to ``_aresponse_json``, which hands it to
            ``_prepare_request`` as the search keyword.

    Returns:
        The string produced by ``_process_response``.
    """
    payload = await self._aresponse_json(url)
    return self._process_response(payload)
def run(self, url: str) -> str:
    """Synchronously query the SERP API and return a short text answer.

    Args:
        url: Forwarded unchanged to ``_response_json``, which hands it to
            ``_prepare_request`` as the search keyword.

    Returns:
        The string produced by ``_process_response``.
    """
    payload = self._response_json(url)
    return self._process_response(payload)
def results(self, url: str) -> list:
    """Synchronously query the SERP API and return the filtered result items.

    Args:
        url: Forwarded unchanged to ``_response_json``.

    Returns:
        The list built by ``_filter_results`` from the JSON response.
    """
    return self._filter_results(self._response_json(url))
async def aresults(self, url: str) -> list:
    """Asynchronously query the SERP API and return the filtered result items.

    Args:
        url: Forwarded unchanged to ``_aresponse_json``.

    Returns:
        The list built by ``_filter_results`` from the JSON response.
    """
    return self._filter_results(await self._aresponse_json(url))
def _prepare_request(self, keyword: str) -> dict:
    """Build the URL, headers and JSON body for one SERP API call.

    Args:
        keyword: Search keyword; it is percent-encoded before being placed
            in the request body.

    Returns:
        A dict with ``"url"``, ``"headers"`` and ``"data"`` (a one-element
        list holding the request object).

    Raises:
        ValueError: If either credential is missing.
    """
    if self.api_login is None or self.api_password is None:
        raise ValueError("api_login or api_password is not provided")

    # HTTP Basic auth: base64 of "login:password".
    raw_credentials = f"{self.api_login}:{self.api_password}"
    token = base64.b64encode(raw_credentials.encode("utf-8")).decode("utf-8")

    # Later layers override earlier ones: keyword < default_params < params.
    payload = {"keyword": quote(keyword)}
    payload.update(self.default_params)
    payload.update(self.params)

    endpoint = (
        f"https://api.dataforseo.com/v3/serp/{payload['se_name']}"
        f"/{payload['se_type']}/live/advanced"
    )
    return {
        "url": endpoint,
        "headers": {
            "Authorization": f"Basic {token}",
            "Content-Type": "application/json",
        },
        "data": [payload],
    }
def _check_response(self, response: dict) -> dict:
    """Return ``response`` unchanged when its status code signals success.

    Args:
        response: Decoded JSON body from the SERP API.

    Returns:
        The same ``response`` object.

    Raises:
        ValueError: If ``status_code`` is not 20000; the message carries
            the response's ``status_message``.
    """
    if response.get("status_code") == 20000:
        return response
    raise ValueError(
        f"Got error from DataForSEO SERP API: {response.get('status_message')}"
    )
def _response_json(self, url: str) -> dict:
    """POST the prepared request with ``requests`` and return the checked JSON.

    Args:
        url: Passed to ``_prepare_request`` as the search keyword.

    Returns:
        The decoded response after ``_check_response`` accepted it.
    """
    spec = self._prepare_request(url)
    reply = requests.post(spec["url"], headers=spec["headers"], json=spec["data"])
    # Surface HTTP-level failures before inspecting the API status code.
    reply.raise_for_status()
    body = reply.json()
    return self._check_response(body)
async def _aresponse_json(self, url: str) -> dict:
    """POST the prepared request with aiohttp and return the checked JSON.

    Uses ``self.aiosession`` when one is configured; otherwise a session is
    opened for this single call.

    Args:
        url: Passed to ``_prepare_request`` as the search keyword.

    Returns:
        The decoded response after ``_check_response`` accepted it.
    """
    spec = self._prepare_request(url)
    if self.aiosession:
        async with self.aiosession.post(
            spec["url"], headers=spec["headers"], json=spec["data"]
        ) as reply:
            body = await reply.json()
    else:
        async with aiohttp.ClientSession() as temporary_session:
            async with temporary_session.post(
                spec["url"], headers=spec["headers"], json=spec["data"]
            ) as reply:
                body = await reply.json()
    return self._check_response(body)
def _filter_results(self, res: dict) -> list:
    """Collect response items, filtered by type and capped at ``top_count``.

    Every item under ``tasks -> result -> items`` is considered. Items whose
    ``"type"`` is listed in ``self.json_result_types`` (or all items when no
    types are configured) are cleaned with ``_cleanup_unnecessary_items`` and
    kept if anything remains.

    Collection stops for the whole response as soon as ``self.top_count``
    items are gathered. Previously only the innermost loop was left, so the
    cap could be exceeded across results and tasks. ``tasks``, ``result`` and
    ``items`` entries that are missing or ``None`` are treated as empty.

    Args:
        res: Decoded JSON response.

    Returns:
        The list of kept (and cleaned, in place) item dicts.
    """
    wanted_types = self.json_result_types or []
    limit = self.top_count
    output: list = []
    for task in res.get("tasks") or []:
        for result in task.get("result") or []:
            for item in result.get("items") or []:
                if not wanted_types or item.get("type", "") in wanted_types:
                    self._cleanup_unnecessary_items(item)
                    if item:
                        output.append(item)
                if limit is not None and len(output) >= limit:
                    # The cap applies to the whole response, so return
                    # instead of breaking out of one loop level.
                    return output
    return output
def _cleanup_unnecessary_items(self, d: dict) -> dict:
    """Prune ``d`` in place and return it.

    When ``self.json_result_fields`` is a non-empty list, nested dicts are
    cleaned recursively and removed if they end up empty, and every other
    key not in that list is deleted. The layout keys ``xpath``, ``position``
    and ``rectangle`` are then always removed, and the remaining nested
    dicts are cleaned once more without removing emptied ones.

    Args:
        d: The item dictionary to prune.

    Returns:
        The same dictionary object, mutated.
    """
    keep = [] if self.json_result_fields is None else self.json_result_fields
    if len(keep) > 0:
        # Iterate over a snapshot because entries are deleted on the way.
        for key, value in list(d.items()):
            if isinstance(value, dict):
                self._cleanup_unnecessary_items(value)
                if len(value) == 0:
                    del d[key]
            elif key not in keep:
                del d[key]

    for layout_key in ("xpath", "position", "rectangle"):
        d.pop(layout_key, None)

    for value in list(d.values()):
        if isinstance(value, dict):
            self._cleanup_unnecessary_items(value)
    return d
def _process_response(self, res: dict) -> str:
    """Extract a single short answer from a SERP API response.

    Results are scanned in order. Within each result the item types are
    tried by priority: ``answer_box`` text, ``knowledge_graph`` description,
    ``featured_snippet`` description, ``shopping`` price, ``organic``
    description. The first non-empty value found is returned.

    Fixes over the previous implementation:
    * a missing or ``None`` ``item_types`` / ``items`` / ``result`` no longer
      raises ``TypeError``;
    * a type listed in ``item_types`` without a matching item no longer
      leaks ``StopIteration``;
    * an answer found in an earlier task is no longer overwritten by a later
      task, and results after the first one in a task are examined;
    * an empty preferred field falls through to the next type instead of
      producing ``None``.

    Args:
        res: Decoded JSON response.

    Returns:
        The extracted answer as a string, or
        ``"No good search result found"`` when nothing usable exists.
    """
    # (item type, field holding the answer), highest priority first.
    answer_fields = (
        ("answer_box", "text"),
        ("knowledge_graph", "description"),
        ("featured_snippet", "description"),
        ("shopping", "price"),
        ("organic", "description"),
    )
    for task in res.get("tasks") or []:
        for result in task.get("result") or []:
            item_types = result.get("item_types") or []
            items = result.get("items") or []
            for item_type, field in answer_fields:
                if item_type not in item_types:
                    continue
                match = next(
                    (item for item in items if item.get("type") == item_type),
                    None,
                )
                answer = match.get(field) if match else None
                if answer:
                    # Some fields may hold non-string data; honour the
                    # declared ``str`` return type.
                    return answer if isinstance(answer, str) else str(answer)
    return "No good search result found"

View File

@@ -0,0 +1,68 @@
"""Util that calls Dataherald."""
from typing import Any, Dict, Optional
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, model_validator
# Pydantic model that owns a Dataherald client and the database connection id
# used for SQL generation requests.
class DataheraldAPIWrapper(BaseModel):
"""Wrapper for Dataherald.
Docs for using:
1. Go to dataherald and sign up
2. Create an API key
3. Save your API key into DATAHERALD_API_KEY env variable
4. pip install dataherald
"""
# Populated by validate_environment with a dataherald.Dataherald instance.
dataherald_client: Any = None #: :meta private:
# Required; attached to every Prompt built in run().
db_connection_id: str
# Falls back to the DATAHERALD_API_KEY environment variable.
dataherald_api_key: Optional[str] = None
# Unknown constructor fields are rejected.
model_config = ConfigDict(
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
    """Resolve the API key and attach a Dataherald client to ``values``.

    Args:
        values: Raw constructor arguments.

    Returns:
        ``values`` with ``dataherald_api_key`` and ``dataherald_client`` set.

    Raises:
        ImportError: If the ``dataherald`` package is not installed.
    """
    api_key = get_from_dict_or_env(
        values, "dataherald_api_key", "DATAHERALD_API_KEY"
    )
    values["dataherald_api_key"] = api_key

    try:
        import dataherald
    except ImportError:
        raise ImportError(
            "dataherald is not installed. "
            "Please install it with `pip install dataherald`"
        )

    values["dataherald_client"] = dataherald.Dataherald(api_key=api_key)
    return values
def run(self, prompt: str) -> str:
    """Ask Dataherald to generate SQL for ``prompt`` and format the result.

    Args:
        prompt: Natural-language question to translate into SQL.

    Returns:
        ``"Answer: <sql>"`` when SQL was produced, ``"No answer"`` when the
        generated SQL is empty, or a fallback message if ``StopIteration``
        is raised while reading the result.
    """
    from dataherald.types.sql_generation_create_params import Prompt

    request = Prompt(text=prompt, db_connection_id=self.db_connection_id)
    generation = self.dataherald_client.sql_generations.create(prompt=request)
    try:
        sql = generation.sql
        if sql:
            return f"Answer: {sql}"
        # An empty answer is reported explicitly rather than returned as-is.
        return "No answer"
    except StopIteration:
        return "Dataherald wasn't able to answer it"

View File

@@ -0,0 +1,95 @@
import logging
from typing import Any, Dict, List, Optional, Union
logger = logging.getLogger(__name__)
# Plain class (no pydantic base): all instance attributes are assigned in
# __init__.
# NOTE(review): when the `dria` import fails, __init__ logs an error and
# returns before assigning any attribute, so later method calls would fail
# on missing attributes -- confirm this best-effort behaviour is intended.
class DriaAPIWrapper:
"""Wrapper around Dria API.
This wrapper facilitates interactions with Dria's vector search
and retrieval services, including creating knowledge bases, inserting data,
and fetching search results.
Attributes:
api_key: Your API key for accessing Dria.
contract_id: The contract ID of the knowledge base to interact with.
top_n: Number of top results to fetch for a search.
"""
def __init__(
    self, api_key: str, contract_id: Optional[str] = None, top_n: int = 10
):
    """Create the Dria client and optionally bind it to a knowledge base.

    If the ``dria`` package cannot be imported, an error is logged and the
    constructor returns without setting any attribute.

    Args:
        api_key: API key handed to the Dria client.
        contract_id: Knowledge-base contract to select, if any.
        top_n: Number of results requested by searches and queries.
    """
    try:
        from dria import Dria, Models
    except ImportError:
        logger.error(
            """Dria is not installed. Please install Dria to use this wrapper.
You can install Dria using the following command:
pip install dria
"""
        )
        return

    self.api_key = api_key
    self.models = Models
    self.contract_id = contract_id
    self.top_n = top_n
    client = Dria(api_key=self.api_key)
    self.dria_client = client
    if self.contract_id:
        client.set_contract(self.contract_id)
def create_knowledge_base(
    self,
    name: str,
    description: str,
    category: str,
    embedding: str,
) -> str:
    """Create a knowledge base and remember its contract id on the wrapper.

    Args:
        name: Name of the knowledge base.
        description: Description of the knowledge base.
        category: Category of the knowledge base.
        embedding: Embedding identifier passed to the Dria client.

    Returns:
        The contract id returned by the client.
    """
    new_contract_id = self.dria_client.create(
        name=name, embedding=embedding, category=category, description=description
    )
    logger.info(f"Knowledge base created with ID: {new_contract_id}")
    self.contract_id = new_contract_id
    return new_contract_id
def insert_data(self, data: List[Dict[str, Any]]) -> str:
    """Insert text records through the Dria client and log the outcome.

    Args:
        data: Records handed to the client's ``insert_text``.

    Returns:
        Whatever the client returns for the insertion.
    """
    outcome = self.dria_client.insert_text(data)
    logger.info(f"Data inserted: {outcome}")
    return outcome
def search(self, query: str) -> List[Dict[str, Any]]:
    """Run a text search limited to ``self.top_n`` results and log the hits.

    Args:
        query: Text to search for.

    Returns:
        The search results as returned by the Dria client.
    """
    hits = self.dria_client.search(query, top_n=self.top_n)
    logger.info(f"Search results: {hits}")
    return hits
def query_with_vector(self, vector: List[float]) -> List[Dict[str, Any]]:
    """Run a vector query limited to ``self.top_n`` results and log the hits.

    Args:
        vector: Query embedding handed to the client's ``query``.

    Returns:
        The query results as returned by the Dria client.
    """
    hits = self.dria_client.query(vector, top_n=self.top_n)
    logger.info(f"Vector query results: {hits}")
    return hits
def run(self, query: Union[str, List[float]]) -> Optional[List[Dict[str, Any]]]:
    """Dispatch to a text search or a vector query based on ``query``'s type.

    A string triggers ``search``. A list of real numbers triggers
    ``query_with_vector``; integers are accepted and converted to floats.
    Previously a vector such as ``[1, 0.5]`` was rejected because only
    ``float`` instances were allowed. Booleans are not treated as numbers.

    Args:
        query: A string for text-based search or a list of numbers for a
            vector-based query.

    Returns:
        The search or query results from Dria, or ``None`` (after logging an
        error) when ``query`` has an unsupported type.
    """
    if isinstance(query, str):
        return self.search(query)

    def _is_number(item: Any) -> bool:
        # bool is a subclass of int but is never a meaningful vector entry.
        return isinstance(item, (int, float)) and not isinstance(item, bool)

    if isinstance(query, list) and all(_is_number(item) for item in query):
        return self.query_with_vector([float(item) for item in query])

    logger.error(
        """Invalid query type. Please provide a string for text search or a
list of floats for vector query."""
    )
    return None

View File

@@ -0,0 +1,178 @@
"""Util that calls DuckDuckGo Search.
No setup required. Free.
https://pypi.org/project/duckduckgo-search/
"""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, model_validator
# Pydantic model holding the search options forwarded to the `ddgs` client by
# the _ddgs_text / _ddgs_news / _ddgs_images helpers.
class DuckDuckGoSearchAPIWrapper(BaseModel):
"""Wrapper for DuckDuckGo Search API.
Free and does not require any setup.
"""
# Forwarded as `region` to all three search kinds.
region: Optional[str] = "wt-wt"
"""
See https://pypi.org/project/duckduckgo-search/#regions
"""
# Forwarded as `safesearch` to all three search kinds.
safesearch: str = "moderate"
"""
Options: strict, moderate, off
"""
# Forwarded as `timelimit` to text and news searches; the image helper does
# not pass it.
time: Optional[str] = "y"
"""
Options: d, w, m, y
"""
# Used by the helpers whenever the caller gives no (or a falsy) max_results.
max_results: int = 5
# Forwarded only by the text search helper.
backend: str = "auto"
"""
Options: auto, html, lite
"""
# Selects which helper run() uses, and is the default for results().
source: str = "text"
"""
Options: text, news, images
"""
# Unknown constructor fields are rejected.
model_config = ConfigDict(
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
    """Fail early when the ``ddgs`` package is unavailable.

    Args:
        values: Raw constructor arguments.

    Returns:
        ``values`` unchanged.

    Raises:
        ImportError: If ``ddgs`` cannot be imported.
    """
    try:
        from ddgs import DDGS  # noqa: F401
    except ImportError:
        raise ImportError(
            "Could not import ddgs python package. "
            "Please install it with `pip install -U ddgs`."
        )
    else:
        return values
def _ddgs_text(
    self, query: str, max_results: Optional[int] = None
) -> List[Dict[str, str]]:
    """Run a DuckDuckGo text search and return the hits as a list.

    Args:
        query: Search terms.
        max_results: Result cap; ``self.max_results`` is used when this is
            ``None`` or otherwise falsy.

    Returns:
        The hits, or an empty list when the client yields nothing.
    """
    from ddgs import DDGS

    with DDGS() as client:
        hits = client.text(
            query,
            region=self.region,
            safesearch=self.safesearch,
            timelimit=self.time,
            max_results=max_results or self.max_results,
            backend=self.backend,
        )
        if hits:
            return list(hits)
    return []
def _ddgs_news(
    self, query: str, max_results: Optional[int] = None
) -> List[Dict[str, str]]:
    """Run a DuckDuckGo news search and return the hits as a list.

    Args:
        query: Search terms.
        max_results: Result cap; ``self.max_results`` is used when this is
            ``None`` or otherwise falsy.

    Returns:
        The hits, or an empty list when the client yields nothing.
    """
    from ddgs import DDGS

    with DDGS() as client:
        hits = client.news(
            query,
            region=self.region,
            safesearch=self.safesearch,
            timelimit=self.time,
            max_results=max_results or self.max_results,
        )
        if hits:
            return list(hits)
    return []
def _ddgs_images(
    self, query: str, max_results: Optional[int] = None
) -> List[Dict[str, str]]:
    """Run a DuckDuckGo image search and return the hits as a list.

    Unlike the text and news helpers, no time limit is forwarded.

    Args:
        query: Search terms.
        max_results: Result cap; ``self.max_results`` is used when this is
            ``None`` or otherwise falsy.

    Returns:
        The hits, or an empty list when the client yields nothing.
    """
    from ddgs import DDGS

    with DDGS() as client:
        hits = client.images(
            query,
            region=self.region,
            safesearch=self.safesearch,
            max_results=max_results or self.max_results,
        )
        if hits:
            return list(hits)
    return []
def run(self, query: str) -> str:
    """Run ``query`` against the configured source and join the snippets.

    Text and news hits contribute their ``"body"``. Hits without a body
    contribute their ``"title"`` instead; image hits are the case visible
    in ``results``, which reads no ``"body"`` key for them. Previously
    ``source == "images"`` raised ``KeyError: 'body'``.

    Args:
        query: Search terms.

    Returns:
        The snippets joined by single spaces, or
        ``"No good DuckDuckGo Search Result was found"`` when the source is
        unknown or nothing usable came back.
    """
    if self.source == "text":
        hits = self._ddgs_text(query)
    elif self.source == "news":
        hits = self._ddgs_news(query)
    elif self.source == "images":
        hits = self._ddgs_images(query)
    else:
        hits = []

    # Prefer the body text; fall back to the title for body-less hits.
    snippets = [hit.get("body") or hit.get("title") or "" for hit in hits]
    snippets = [snippet for snippet in snippets if snippet]
    if not snippets:
        return "No good DuckDuckGo Search Result was found"
    return " ".join(snippets)
def results(
    self, query: str, max_results: int, source: Optional[str] = None
) -> List[Dict[str, str]]:
    """Run ``query`` and return per-hit metadata dictionaries.

    Args:
        query: The query to search for.
        max_results: The number of results to request.
        source: ``"text"``, ``"news"`` or ``"images"``; defaults to
            ``self.source``. Any other value yields an empty list.

    Returns:
        One dictionary per hit. Text hits carry ``snippet``, ``title`` and
        ``link``; news hits add ``date`` and ``source``; image hits carry
        ``title``, ``thumbnail``, ``image``, ``url``, ``height``, ``width``
        and ``source``.
    """
    chosen = source or self.source

    # Output key -> key in the raw hit, per source.
    field_maps = {
        "text": {"snippet": "body", "title": "title", "link": "href"},
        "news": {
            "snippet": "body",
            "title": "title",
            "link": "url",
            "date": "date",
            "source": "source",
        },
        "images": {
            "title": "title",
            "thumbnail": "thumbnail",
            "image": "image",
            "url": "url",
            "height": "height",
            "width": "width",
            "source": "source",
        },
    }

    mapping = field_maps.get(chosen)
    if mapping is None:
        results = []
    else:
        fetch = getattr(self, f"_ddgs_{chosen}")
        results = [
            {out_key: hit[raw_key] for out_key, raw_key in mapping.items()}
            for hit in fetch(query, max_results=max_results)
        ]

    # Kept from the original; `results` is always a list at this point.
    if results is None:
        results = [{"Result": "No good DuckDuckGo Search Result was found"}]
    return results

View File

@@ -0,0 +1,147 @@
"""
Util that calls several of financial datasets stock market REST APIs.
Docs: https://docs.financialdatasets.ai/
"""
import json
from typing import Any, List, Optional
import requests
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel
FINANCIAL_DATASETS_BASE_URL = "https://api.financialdatasets.ai/"
# Pydantic model; the key is resolved in __init__ from the constructor
# arguments or the FINANCIAL_DATASETS_API_KEY environment variable.
class FinancialDatasetsAPIWrapper(BaseModel):
"""Wrapper for financial datasets API."""
# Sent as the X-API-KEY header by the get_* request methods.
financial_datasets_api_key: Optional[str] = None
def __init__(self, **data: Any):
    """Initialise the model, then resolve the API key.

    The key is read from ``data`` when given, otherwise from the
    ``FINANCIAL_DATASETS_API_KEY`` environment variable.

    Args:
        **data: Field values forwarded to the pydantic constructor.
    """
    super().__init__(**data)
    resolved_key = get_from_dict_or_env(
        data, "financial_datasets_api_key", "FINANCIAL_DATASETS_API_KEY"
    )
    self.financial_datasets_api_key = resolved_key
@property
def _api_key(self) -> str:
    """Return the configured API key.

    Raises:
        ValueError: If no key is set; the message explains both ways of
            supplying one.
    """
    key = self.financial_datasets_api_key
    if key is not None:
        return key
    raise ValueError(
        "API key is required for the FinancialDatasetsAPIWrapper. "
        "Please provide the API key by either:\n"
        "1. Manually specifying it when initializing the wrapper: "
        "FinancialDatasetsAPIWrapper(financial_datasets_api_key='your_api_key')\n"
        "2. Setting it as an environment variable: FINANCIAL_DATASETS_API_KEY"
    )
def get_income_statements(
    self,
    ticker: str,
    period: str,
    limit: Optional[int],
) -> Optional[dict]:
    """
    Get the income statements for a stock `ticker` over a `period` of time.

    The query string is built by ``requests`` from a parameter mapping, so
    values are URL-encoded. Previously they were interpolated verbatim, so
    a ticker containing ``&`` or ``#`` corrupted the request.

    :param ticker: the stock ticker
    :param period: the period of time to get the income statements for.
        Possible values are: annual, quarterly, ttm.
    :param limit: the number of results to return; 10 is used when this is
        ``None`` or 0
    :return: the ``income_statements`` entry of the response, or ``None``
        when the response has no such key
    """
    url = f"{FINANCIAL_DATASETS_BASE_URL}financials/income-statements/"
    query = {
        "ticker": ticker,
        "period": period,
        "limit": limit if limit else 10,
    }

    # The API key travels in a header, not in the URL.
    headers = {"X-API-KEY": self._api_key}

    response = requests.get(url, params=query, headers=headers)
    data = response.json()

    return data.get("income_statements", None)
def get_balance_sheets(
    self,
    ticker: str,
    period: str,
    limit: Optional[int],
) -> List[dict]:
    """
    Get the balance sheets for a stock `ticker` over a `period` of time.

    The query string is built by ``requests`` from a parameter mapping, so
    values are URL-encoded. Previously they were interpolated verbatim, so
    a ticker containing ``&`` or ``#`` corrupted the request.

    :param ticker: the stock ticker
    :param period: the period of time to get the balance sheets for.
        Possible values are: annual, quarterly, ttm.
    :param limit: the number of results to return; 10 is used when this is
        ``None`` or 0
    :return: the ``balance_sheets`` entry of the response; ``None`` when the
        response has no such key, despite the annotation
    """
    url = f"{FINANCIAL_DATASETS_BASE_URL}financials/balance-sheets/"
    query = {
        "ticker": ticker,
        "period": period,
        "limit": limit if limit else 10,
    }

    # The API key travels in a header, not in the URL.
    headers = {"X-API-KEY": self._api_key}

    response = requests.get(url, params=query, headers=headers)
    data = response.json()

    return data.get("balance_sheets", None)
def get_cash_flow_statements(
    self,
    ticker: str,
    period: str,
    limit: Optional[int],
) -> List[dict]:
    """
    Get the cash flow statements for a stock `ticker` over a `period` of time.

    The query string is built by ``requests`` from a parameter mapping, so
    values are URL-encoded. Previously they were interpolated verbatim, so
    a ticker containing ``&`` or ``#`` corrupted the request.

    :param ticker: the stock ticker
    :param period: the period of time to get the cash flow statements for.
        Possible values are: annual, quarterly, ttm.
    :param limit: the number of results to return; 10 is used when this is
        ``None`` or 0
    :return: the ``cash_flow_statements`` entry of the response; ``None``
        when the response has no such key, despite the annotation
    """
    url = f"{FINANCIAL_DATASETS_BASE_URL}financials/cash-flow-statements/"
    query = {
        "ticker": ticker,
        "period": period,
        "limit": limit if limit else 10,
    }

    # The API key travels in a header, not in the URL.
    headers = {"X-API-KEY": self._api_key}

    response = requests.get(url, params=query, headers=headers)
    data = response.json()

    return data.get("cash_flow_statements", None)
def run(self, mode: str, ticker: str, **kwargs: Any) -> str:
    """Call the getter named by ``mode`` and return its result as JSON text.

    Args:
        mode: One of ``get_income_statements``, ``get_balance_sheets`` or
            ``get_cash_flow_statements``.
        ticker: The stock ticker.
        **kwargs: Optional ``period`` (default ``"annual"``) and ``limit``
            (default ``10``).

    Returns:
        The getter's return value serialised with ``json.dumps``.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    supported_modes = (
        "get_income_statements",
        "get_balance_sheets",
        "get_cash_flow_statements",
    )
    if mode not in supported_modes:
        raise ValueError(f"Invalid mode {mode} for financial datasets API.")

    period = kwargs.get("period", "annual")
    limit = kwargs.get("limit", 10)
    # Each supported mode is also the name of the method implementing it.
    fetch = getattr(self, mode)
    return json.dumps(fetch(ticker, period, limit))

View File

@@ -0,0 +1,901 @@
"""Util that calls GitHub."""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import requests
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, model_validator
if TYPE_CHECKING:
from github.Issue import Issue
from github.PullRequest import PullRequest
def _import_tiktoken() -> Any:
    """Import ``tiktoken`` lazily and return the module.

    Raises:
        ImportError: If ``tiktoken`` is not installed; the message includes
            the install command.
    """
    try:
        import tiktoken as tiktoken_module
    except ImportError:
        raise ImportError(
            "tiktoken is not installed. Please install it with `pip install tiktoken`"
        )
    else:
        return tiktoken_module
# Pydantic model wrapping a PyGithub client for a single repository. All
# fields are populated by validate_environment.
class GitHubAPIWrapper(BaseModel):
"""Wrapper for GitHub API."""
# Authenticated client obtained from the GitHub App installation.
github: Any = None #: :meta private:
# Repository object returned by `get_repo(github_repository)`.
github_repo_instance: Any = None #: :meta private:
# Resolved from the argument or the GITHUB_REPOSITORY environment variable.
github_repository: Optional[str] = None
# Resolved from the argument or the GITHUB_APP_ID environment variable.
github_app_id: Optional[str] = None
# Either a path to a key file or the key text itself; the validator tries to
# open it as a file first.
github_app_private_key: Optional[str] = None
# Branch the wrapper reads from and opens pull requests from; defaults to
# the repository's default branch.
active_branch: Optional[str] = None
# Target branch for pull requests; defaults to the repository's default
# branch.
github_base_branch: Optional[str] = None
# Unknown constructor fields are rejected.
model_config = ConfigDict(
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
    """Resolve GitHub App settings and attach an authenticated repo handle.

    The repository name, app id and private key are taken from ``values``
    or the ``GITHUB_REPOSITORY`` / ``GITHUB_APP_ID`` /
    ``GITHUB_APP_PRIVATE_KEY`` environment variables. The first installation
    of the app is used to create the client.

    Fixes over the previous implementation:
    * an app with no installation now raises the explanatory ``ValueError``.
      Indexing an empty installation list raises ``IndexError``, which the
      old ``except ValueError`` never caught;
    * the "not installed" message no longer glues the repository name onto
      the following sentence;
    * re-raised errors keep their cause.

    Args:
        values: Raw constructor arguments.

    Returns:
        ``values`` with the client, repository object, credentials and
        branch names filled in.

    Raises:
        ImportError: If PyGithub is not installed.
        ValueError: If the app has no installation or the parameters are
            rejected.
    """
    github_repository = get_from_dict_or_env(
        values, "github_repository", "GITHUB_REPOSITORY"
    )
    github_app_id = get_from_dict_or_env(values, "github_app_id", "GITHUB_APP_ID")
    github_app_private_key = get_from_dict_or_env(
        values, "github_app_private_key", "GITHUB_APP_PRIVATE_KEY"
    )

    try:
        from github import Auth, GithubIntegration
    except ImportError as e:
        raise ImportError(
            "PyGithub is not installed. "
            "Please install it with `pip install PyGithub`"
        ) from e

    try:
        # Interpret the key as a file path first ...
        with open(github_app_private_key, "r") as f:
            private_key = f.read()
    except Exception:
        # ... and fall back to treating it as the key text itself. The catch
        # stays broad on purpose: key text used as a path can fail in several
        # ways (missing file, over-long name, invalid characters).
        private_key = github_app_private_key

    auth = Auth.AppAuth(
        github_app_id,
        private_key,
    )
    gi = GithubIntegration(auth=auth)
    installations = gi.get_installations()

    not_installed_message = (
        f"Please make sure to install the created github app with id "
        f"{github_app_id} on the repo: {github_repository}. "
        "More instructions can be found at "
        "https://docs.github.com/en/apps/using-"
        "github-apps/installing-your-own-github-app"
    )
    if not installations:
        raise ValueError(not_installed_message)

    try:
        installation = installations[0]
    except IndexError as e:
        # The paginated result may be truthy even when it holds no entries.
        raise ValueError(not_installed_message) from e
    except ValueError as e:
        raise ValueError(
            f"Please make sure to give correct github parameters Error message: {e}"
        ) from e

    # Create a GitHub client scoped to this installation.
    g = installation.get_github_for_installation()
    repo = g.get_repo(github_repository)

    github_base_branch = get_from_dict_or_env(
        values,
        "github_base_branch",
        "GITHUB_BASE_BRANCH",
        default=repo.default_branch,
    )
    active_branch = get_from_dict_or_env(
        values,
        "active_branch",
        "ACTIVE_BRANCH",
        default=repo.default_branch,
    )

    values["github"] = g
    values["github_repo_instance"] = repo
    values["github_repository"] = github_repository
    values["github_app_id"] = github_app_id
    values["github_app_private_key"] = github_app_private_key
    values["active_branch"] = active_branch
    values["github_base_branch"] = github_base_branch

    return values
def parse_issues(self, issues: List[Issue]) -> List[dict]:
    """
    Summarise GitHub issues as plain dictionaries.

    Parameters:
        issues(List[Issue]): A list of Github Issue objects
    Returns:
        List[dict]: One dict per issue with ``title`` and ``number``, plus
        ``opened_by`` when the issue has a user.
    """
    summaries = []
    for issue in issues:
        entry = {"title": issue.title, "number": issue.number}
        author = issue.user.login if issue.user else None
        if author is not None:
            entry["opened_by"] = author
        summaries.append(entry)
    return summaries
def parse_pull_requests(self, pull_requests: List[PullRequest]) -> List[dict]:
    """
    Summarise GitHub pull requests as plain dictionaries.

    Parameters:
        pull_requests(List[PullRequest]): A list of Github PullRequest objects
    Returns:
        List[dict]: One dict per pull request with ``title``, ``number`` and
        the ``commits`` / ``comments`` counts rendered as strings.
    """
    return [
        {
            "title": pr.title,
            "number": pr.number,
            "commits": str(pr.commits),
            "comments": str(pr.comments),
        }
        for pr in pull_requests
    ]
def get_issues(self) -> str:
    """
    Report all open issues of the repo, excluding pull requests.

    Returns:
        str: A plaintext report with the number of issues followed by the
        parsed issue list, or a fixed message when there are none.
    """
    open_items = self.github_repo_instance.get_issues(state="open")
    # The issues endpoint also returns pull requests; drop them.
    real_issues = [item for item in open_items if not item.pull_request]
    if not real_issues:
        return "No open issues available"
    summaries = self.parse_issues(real_issues)
    return f"Found {len(summaries)} issues:\n{summaries}"
def list_open_pull_requests(self) -> str:
    """
    Report all open pull requests of the repo.

    Returns:
        str: A plaintext report with the number of PRs followed by the
        parsed PR list, or a fixed message when there are none.
    """
    open_prs = self.github_repo_instance.get_pulls(state="open")
    if not open_prs.totalCount > 0:
        return "No open pull requests available"
    summaries = self.parse_pull_requests(open_prs)
    return f"Found {len(summaries)} pull requests:\n{summaries}"
def list_files_in_main_branch(self) -> str:
    """
    Fetches all files in the main (base) branch of the repo.

    Sub-directories are walked on ``self.github_base_branch`` as well.
    Previously only the top level was read from the base branch, while
    nested directories were listed through ``_list_files``, which reads the
    *active* branch, so the report could mix files from two branches.

    Returns:
        str: A plaintext report containing the paths and names of the files,
        a fixed message when the branch is empty, or the error text if
        listing fails.
    """
    repo = self.github_repo_instance
    base_ref = self.github_base_branch

    def walk(path: str) -> List[str]:
        """Return every file path under ``path`` on the base branch."""
        found: List[str] = []
        for entry in repo.get_contents(path, ref=base_ref):
            if entry.type == "dir":
                found.extend(walk(entry.path))
            else:
                found.append(entry.path)
        return found

    try:
        files = walk("")
    except Exception as e:
        # Errors are reported as text, as before.
        return str(e)

    if not files:
        return "No files found in the main branch"
    files_str = "\n".join(files)
    return f"Found {len(files)} files in the main branch:\n{files_str}"
def set_active_branch(self, branch_name: str) -> str:
    """Switch the wrapper's active branch, like `git checkout branch_name`.

    The branch must already exist in the repository.

    Returns:
        A confirmation message, or an error message (as a string) listing
        the existing branches when ``branch_name`` is unknown.
    """
    existing = [branch.name for branch in self.github_repo_instance.get_branches()]
    if branch_name not in existing:
        return (
            f"Error {branch_name} does not exist,"
            f"in repo with current branches: {str(existing)}"
        )
    self.active_branch = branch_name
    return f"Switched to branch `{branch_name}`"
def list_branches_in_repo(self) -> str:
    """
    List the names of all branches in the repository.

    Returns:
        str: A plaintext report containing the branch names, a fixed message
        when there are none, or the error text if listing fails.
    """
    try:
        names = [branch.name for branch in self.github_repo_instance.get_branches()]
        if not names:
            return "No branches found in the repository"
        listing = "\n".join(names)
        return f"Found {len(names)} branches in the repository:\n{listing}"
    except Exception as e:
        return str(e)
def create_branch(self, proposed_branch_name: str) -> str:
    """
    Create a new branch from the repository's default branch and make it
    the active branch. Equivalent to `git switch -c proposed_branch_name`.

    If the proposed name is taken, ``_v1``, ``_v2``, ... are appended until
    a free name is found, for at most 1000 attempts in total.

    Returns:
        str: A plaintext success message, or a failure message once all
        attempts are used up.

    Raises:
        Exception: If branch creation fails for any reason other than the
            name already existing.
    """
    from github import GithubException

    repo = self.github_repo_instance
    source_branch = repo.get_branch(repo.default_branch)

    for attempt in range(1000):
        # Attempt 0 uses the proposed name; attempt k uses "<name>_vk".
        if attempt == 0:
            candidate = proposed_branch_name
        else:
            candidate = f"{proposed_branch_name}_v{attempt}"
        try:
            repo.create_git_ref(
                ref=f"refs/heads/{candidate}", sha=source_branch.commit.sha
            )
        except GithubException as e:
            name_taken = (
                e.status == 422 and "Reference already exists" in e.data["message"]
            )
            if not name_taken:
                print(f"Failed to create branch. Error: {e}")  # noqa: T201
                raise Exception(
                    "Unable to create branch name from proposed_branch_name: "
                    f"{proposed_branch_name}"
                )
        else:
            self.active_branch = candidate
            return (
                f"Branch '{candidate}' "
                "created successfully, and set as current active branch."
            )

    return (
        "Unable to create branch. "
        "At least 1000 branches exist with named derived from "
        f"proposed_branch_name: `{proposed_branch_name}`"
    )
def list_files_in_bot_branch(self) -> str:
    """
    List all files in the active branch, the branch the bot changes.

    Returns:
        str: A plaintext list of the file paths in the branch, a fixed
        message when it is empty, or ``"Error: ..."`` if listing fails.
    """
    try:
        collected: List[str] = []
        top_level = self.github_repo_instance.get_contents("", ref=self.active_branch)
        for entry in top_level:
            if entry.type == "dir":
                collected.extend(self._list_files(entry.path))
            else:
                collected.append(entry.path)

        if not collected:
            return f"No files found in branch: `{self.active_branch}`"

        listing = "\n".join(collected)
        return (
            f"Found {len(collected)} files in branch `{self.active_branch}`:\n"
            f"{listing}"
        )
    except Exception as e:
        return f"Error: {e}"
def get_files_from_directory(self, directory_path: str) -> str:
    """
    Recursively list the files below a directory of the repo.

    Parameters:
        directory_path (str): Path to the directory
    Returns:
        str: The string form of the list of file paths, or an error message
        when GitHub rejects the request.
    """
    from github import GithubException

    try:
        listing = self._list_files(directory_path)
        return str(listing)
    except GithubException as e:
        return f"Error: status code {e.status}, {e.message}"
def _list_files(self, directory_path: str) -> List[str]:
    """Return every file path under ``directory_path`` on the active branch.

    Directories are descended into recursively; paths come back in the
    order the repository API lists them.
    """
    entries = self.github_repo_instance.get_contents(
        directory_path, ref=self.active_branch
    )
    paths: List[str] = []
    for entry in entries:
        if entry.type == "dir":
            paths += self._list_files(entry.path)
        else:
            paths.append(entry.path)
    return paths
def get_issue(self, issue_number: int) -> Dict[str, Any]:
    """
    Fetches a specific issue and its first 10 comments.

    The comment list is capped at exactly 10 entries. Previously whole
    pages were appended and the loop only stopped once *more than* 10 had
    been collected, so a full first page was returned in its entirety.

    Parameters:
        issue_number(int): The number for the github issue
    Returns:
        `dict` containing the issue's number, title, body, comments as a
        string, and the username of the user who opened the issue
    """
    max_comments = 10
    issue = self.github_repo_instance.get_issue(number=issue_number)
    paginated_comments = issue.get_comments()

    comments: List[dict] = []
    page = 0
    while len(comments) < max_comments:
        comments_page = paginated_comments.get_page(page)
        if len(comments_page) == 0:
            break
        # Take only as many comments as are still missing from the cap.
        remaining = max_comments - len(comments)
        for comment in comments_page[:remaining]:
            comments.append({"body": comment.body, "user": comment.user.login})
        page += 1

    opened_by = None
    if issue.user and issue.user.login:
        opened_by = issue.user.login

    return {
        "number": issue_number,
        "title": issue.title,
        "body": issue.body,
        "comments": str(comments),
        "opened_by": str(opened_by),
    }
def list_pull_request_files(self, pr_number: int) -> List[Dict[str, Any]]:
"""Fetches the full text of all files in a PR. Truncates after first 3k tokens.
# TODO: Enhancement to summarize files with ctags if they're getting long.
Args:
pr_number(int): The number of the pull request on Github
Returns:
`dict` containing the issue's title, body, and comments as a string
"""
# NOTE(review): the "Returns" text above does not match the code. The function
# returns a list of dicts with "filename", "contents", "additions" and
# "deletions" for each file that fit into the token budget.
tiktoken = _import_tiktoken()
# Total token budget across all returned files.
MAX_TOKENS_FOR_FILES = 3_000
pr_files = []
pr = self.github_repo_instance.get_pull(number=int(pr_number))
total_tokens = 0
page = 0
# Walk every page of changed files; only an empty page ends the loop, so
# files keep being downloaded after the budget is used up.
while True: # or while (total_tokens + tiktoken()) < MAX_TOKENS_FOR_FILES:
files_page = pr.get_files().get_page(page)
if len(files_page) == 0:
break
for file in files_page:
try:
# Step 1: fetch the file metadata to learn its download URL.
# NOTE(review): these requests carry no auth headers -- presumably this
# only works for publicly readable repositories; confirm.
file_metadata_response = requests.get(file.contents_url)
if file_metadata_response.status_code == 200:
download_url = json.loads(file_metadata_response.text)[
"download_url"
]
else:
print(f"Failed to download file: {file.contents_url}, skipping") # noqa: T201
continue
# Step 2: download the file body itself.
file_content_response = requests.get(download_url)
if file_content_response.status_code == 200:
# Save the content as a UTF-8 string
file_content = file_content_response.text
else:
print( # noqa: T201
"Failed downloading file content "
f"(Error {file_content_response.status_code}). Skipping"
)
continue
# The token count covers the content, the filename and a fixed label string.
file_tokens = len(
tiktoken.get_encoding("cl100k_base").encode(
file_content + file.filename + "file_name file_contents"
)
)
# A file is kept only if it fits entirely; oversized files are skipped,
# not truncated, and smaller later files may still be added.
if (total_tokens + file_tokens) < MAX_TOKENS_FOR_FILES:
pr_files.append(
{
"filename": file.filename,
"contents": file_content,
"additions": file.additions,
"deletions": file.deletions,
}
)
total_tokens += file_tokens
# Any failure for one file is printed and that file is skipped.
except Exception as e:
print(f"Error when reading files from a PR on github. {e}") # noqa: T201
page += 1
return pr_files
def get_pull_request(self, pr_number: int) -> Dict[str, Any]:
"""
Fetches a specific pull request and its first 10 comments,
limited by max_tokens.
Parameters:
pr_number(int): The number for the Github pull
max_tokens(int): The maximum number of tokens in the response
Returns:
`dict` containing the pull's title, body, and comments as a string
"""
# NOTE(review): `max_tokens` is documented as a parameter but is a fixed
# local constant here.
max_tokens = 2_000
pull = self.github_repo_instance.get_pull(number=pr_number)
# Running token total shared by the helpers and loops below.
total_tokens = 0
# Count tokens with the cl100k_base encoding.
def get_tokens(text: str) -> int:
tiktoken = _import_tiktoken()
return len(tiktoken.get_encoding("cl100k_base").encode(text))
# Store `value` under `key` only if it still fits into the budget; otherwise
# the key is silently left out of the result.
def add_to_dict(data_dict: Dict[str, Any], key: str, value: str) -> None:
nonlocal total_tokens # Declare total_tokens as nonlocal
tokens = get_tokens(value)
if total_tokens + tokens <= max_tokens:
data_dict[key] = value
total_tokens += tokens # Now this will modify the outer variable
response_dict: Dict[str, str] = {}
add_to_dict(response_dict, "title", pull.title)
add_to_dict(response_dict, "number", str(pr_number))
add_to_dict(response_dict, "body", pull.body if pull.body else "")
comments: List[str] = []
page = 0
# NOTE(review): pages are consumed whole and the condition is `<= 10`, so
# more than 10 comments can be collected.
while len(comments) <= 10:
comments_page = pull.get_issue_comments().get_page(page)
if len(comments_page) == 0:
break
for comment in comments_page:
comment_str = str({"body": comment.body, "user": comment.user.login})
# This `break` leaves only the for-loop; the while-loop goes on to the
# next page.
if total_tokens + get_tokens(comment_str) > max_tokens:
break
comments.append(comment_str)
total_tokens += get_tokens(comment_str)
page += 1
# NOTE(review): the comments were already added to total_tokens one by one;
# add_to_dict counts the joined string again, so the "comments" key can be
# dropped even though each comment fit individually.
add_to_dict(response_dict, "comments", str(comments))
commits: List[str] = []
page = 0
# Same paging and budgeting scheme as for comments, applied to commit messages.
while len(commits) <= 10:
commits_page = pull.get_commits().get_page(page)
if len(commits_page) == 0:
break
for commit in commits_page:
commit_str = str({"message": commit.commit.message})
if total_tokens + get_tokens(commit_str) > max_tokens:
break
commits.append(commit_str)
total_tokens += get_tokens(commit_str)
page += 1
add_to_dict(response_dict, "commits", str(commits))
return response_dict
def create_pull_request(self, pr_query: str) -> str:
    """
    Makes a pull request from the bot's branch to the base branch

    Parameters:
        pr_query(str): a string which contains the PR title
        and the PR body. The title is the first line
        in the string, and the body are the rest of the string.
        For example, "Updated README\nmade changes to add info".
        A blank line between title and body is also accepted.
    Returns:
        str: A success or failure message
    """
    if self.github_base_branch == self.active_branch:
        return (
            "Cannot make a pull request because\n"
            "commits are already in the main or master branch."
        )
    try:
        title, _, body = pr_query.partition("\n")
        # Callers send either "title\nbody" or "title\n\nbody". A fixed
        # `len(title) + 2` offset cut the first body character in the
        # single-newline form, so drop at most one extra newline instead.
        if body.startswith("\n"):
            body = body[1:]
        pr = self.github_repo_instance.create_pull(
            title=title,
            body=body,
            head=self.active_branch,
            base=self.github_base_branch,
        )
        return f"Successfully created PR number {pr.number}"
    except Exception as e:
        return "Unable to make pull request due to error:\n" + str(e)
def comment_on_issue(self, comment_query: str) -> str:
    """
    Adds a comment to a github issue

    Parameters:
        comment_query(str): a string which contains the issue number,
        two newlines, and the comment.
        for example: "1\n\nWorking on it now"
        adds the comment "working on it now" to issue 1
    Returns:
        str: A success or failure message
    Raises:
        ValueError: if the text before the first blank line is not an
        integer issue number
    """
    # Split on the separator itself. Deriving the offset from
    # len(str(int(...))) mis-slices inputs such as "007\n\n..." whose raw
    # number text is longer than its canonical form.
    raw_number, _, comment = comment_query.partition("\n\n")
    issue_number = int(raw_number)
    try:
        issue = self.github_repo_instance.get_issue(number=issue_number)
        issue.create_comment(comment)
        return "Commented on issue " + str(issue_number)
    except Exception as e:
        return "Unable to make comment due to error:\n" + str(e)
def create_file(self, file_query: str) -> str:
    """
    Creates a new file on the Github repo

    Parameters:
        file_query(str): a string which contains the file path
        and the file contents. The file path is the first line
        in the string, and the contents are the rest of the string.
        For example, "hello_world.md\n# Hello World!"
    Returns:
        str: A success or failure message
    """
    if self.active_branch == self.github_base_branch:
        return (
            "You're attempting to commit directly to the "
            f"{self.github_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )

    # Everything after the first newline is the content. The previous
    # `len(file_path) + 2` offset silently dropped the first character.
    file_path, _, file_contents = file_query.partition("\n")

    try:
        existing = self.github_repo_instance.get_contents(
            file_path, ref=self.active_branch
        )
    except Exception:
        # expected behavior, file shouldn't exist yet
        existing = None
    if existing:
        return (
            f"File already exists at `{file_path}` "
            f"on branch `{self.active_branch}`. You must use "
            "`update_file` to modify it."
        )

    try:
        self.github_repo_instance.create_file(
            path=file_path,
            message="Create " + file_path,
            content=file_contents,
            branch=self.active_branch,
        )
        return "Created file " + file_path
    except Exception as e:
        return "Unable to make file due to error:\n" + str(e)
def read_file(self, file_path: str) -> str:
    """
    Return the text of a file on the active branch (``self.active_branch``).

    Parameters:
        file_path(str): the file path
    Returns:
        str: The file content decoded as UTF-8, or an error message if the
        lookup or decoding fails for any reason
    """
    try:
        blob = self.github_repo_instance.get_contents(
            file_path, ref=self.active_branch
        )
        raw_bytes = blob.decoded_content
        return raw_bytes.decode("utf-8")
    except Exception as err:
        # Any failure is reported as text rather than raised.
        return "File not found `{}` on branch`{}`. Error: {}".format(
            file_path, self.active_branch, str(err)
        )
def update_file(self, file_query: str) -> str:
    """
    Updates a file with new content.

    Parameters:
        file_query(str): Contains the file path and the file contents.
        The old file contents is wrapped in OLD <<<< and >>>> OLD
        The new file contents is wrapped in NEW <<<< and >>>> NEW
        For example:
        /test/hello.txt
        OLD <<<<
        Hello Earth!
        >>>> OLD
        NEW <<<<
        Hello Mars!
        >>>> NEW
    Returns:
        A success or failure message
    """
    if self.active_branch == self.github_base_branch:
        return (
            "You're attempting to commit directly "
            f"to the {self.github_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )
    try:
        file_path: str = file_query.split("\n")[0]
        old_file_contents = (
            file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip()
        )
        new_file_contents = (
            file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip()
        )
        if not old_file_contents:
            # str.replace("", new) would insert `new` between every
            # character of the file, so refuse an empty OLD block.
            return (
                "File content was not updated because no old content was "
                "provided. It may be helpful to use the read_file action "
                "to get the current file contents."
            )

        # Fetch the file once and use it for both content and sha. Going
        # through read_file() would turn a failed lookup into an error
        # *string* that is then treated as the file content.
        current_file = self.github_repo_instance.get_contents(
            file_path, ref=self.active_branch
        )
        file_content = current_file.decoded_content.decode("utf-8")
        updated_file_content = file_content.replace(
            old_file_contents, new_file_contents
        )

        if file_content == updated_file_content:
            return (
                "File content was not updated because old content was not found. "
                "It may be helpful to use the read_file action to get "
                "the current file contents."
            )

        self.github_repo_instance.update_file(
            path=file_path,
            message="Update " + str(file_path),
            content=updated_file_content,
            branch=self.active_branch,
            sha=current_file.sha,
        )
        return "Updated file " + str(file_path)
    except Exception as e:
        return "Unable to update file due to error:\n" + str(e)
def delete_file(self, file_path: str) -> str:
    """
    Deletes a file from the repo

    Parameters:
        file_path(str): Where the file is
    Returns:
        str: Success or failure message
    """
    if self.active_branch == self.github_base_branch:
        # The previous message read "commit to the directlyto the ...".
        return (
            "You're attempting to commit directly "
            f"to the {self.github_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )
    try:
        # The current blob sha is looked up first and passed to the delete
        # call as its `sha` argument.
        current_sha = self.github_repo_instance.get_contents(
            file_path, ref=self.active_branch
        ).sha
        self.github_repo_instance.delete_file(
            path=file_path,
            message="Delete " + file_path,
            branch=self.active_branch,
            sha=current_sha,
        )
        return "Deleted file " + file_path
    except Exception as e:
        return "Unable to delete file due to error:\n" + str(e)
def search_issues_and_prs(self, query: str) -> str:
    """
    Search the repository's issues and pull requests.

    Parameters:
        query(str): The search query
    Returns:
        str: A header line followed by one line (title, number, state)
        for each of the first 5 matches at most
    """
    hits = self.github.search_issues(query, repo=self.github_repository)
    shown = min(5, hits.totalCount)
    lines = [f"Top {shown} results:"]
    lines.extend(
        f"Title: {hit.title}, Number: {hit.number}, State: {hit.state}"
        for hit in hits[:shown]
    )
    return "\n".join(lines)
def search_code(self, query: str) -> str:
    """
    Search the repository's code and return matching files with contents.

    # Todo: limit total tokens returned...

    Parameters:
        query(str): The search query
    Returns:
        str: A string containing, at most, the top 5 search results
    """
    hits = self.github.search_code(query=query, repo=self.github_repository)
    total = hits.totalCount
    if total == 0:
        return "0 results found."

    limit = min(5, total)
    report = [f"Showing top {limit} of {hits.totalCount} results:"]
    for position, hit in enumerate(hits):
        if position >= limit:
            break
        # Each match is re-read from the active branch via get_contents.
        blob = self.github_repo_instance.get_contents(
            hit.path, ref=self.active_branch
        )
        text = blob.decoded_content.decode()
        report.append(
            f"Filepath: `{hit.path}`\nFile contents: {text}\n<END OF FILE>"
        )
    return "\n".join(report)
def create_review_request(self, reviewer_username: str) -> str:
    """
    Request a review on the first open pull request whose head branch is
    the current ``active_branch``.

    Parameters:
        reviewer_username(str): The username of the person who is being requested
    Returns:
        str: A message confirming the review request, or explaining why it
        could not be created
    """
    target = None
    for candidate in self.github_repo_instance.get_pulls(
        state="open", sort="created"
    ):
        # The first PR whose head ref is the active branch wins.
        if candidate.head.ref == self.active_branch:
            target = candidate
            break

    if target is None:
        return (
            "No open pull request found for the "
            f"current branch `{self.active_branch}`"
        )

    try:
        target.create_review_request(reviewers=[reviewer_username])
        return (
            f"Review request created for user {reviewer_username} "
            f"on PR #{target.number}"
        )
    except Exception as err:
        return f"Failed to create a review request with error {err}"
def get_latest_release(self) -> str:
    """
    Describe the repository's latest release.

    Returns:
        str: The title, tag name and body of the latest release
    """
    latest = self.github_repo_instance.get_latest_release()
    return "Latest title: {} tag: {} body: {}".format(
        latest.title, latest.tag_name, latest.body
    )
def get_releases(self) -> str:
    """
    List the repository's releases.

    Returns:
        str: A header line followed by one line (title, tag, body) for
        each of the first 5 releases at most
    """
    all_releases = self.github_repo_instance.get_releases()
    shown = min(5, all_releases.totalCount)
    lines = [f"Top {shown} results:"]
    lines.extend(
        f"Title: {item.title}, Tag: {item.tag_name}, Body: {item.body}"
        for item in all_releases[:shown]
    )
    return "\n".join(lines)
def get_release(self, tag_name: str) -> str:
    """
    Describe one release of the repository.

    Parameters:
        tag_name(str): The tag name of the release
    Returns:
        str: The title, tag name and body of the release
    """
    found = self.github_repo_instance.get_release(tag_name)
    return "Release: {} tag: {} body: {}".format(
        found.title, found.tag_name, found.body
    )
def run(self, mode: str, query: str) -> str:
    """
    Dispatch a tool invocation to the method named by ``mode``.

    Parameters:
        mode(str): Name of the operation; it equals the name of the
        method that implements it
        query(str): Raw tool input. Converted to ``int`` for the modes that
        take an issue/PR number and ignored by modes that take no input
    Returns:
        str: The method's result (JSON-encoded for the dict/list returning
        modes)
    Raises:
        ValueError: if ``mode`` is not a supported operation
    """
    # Modes that take a number and whose result is JSON-encoded.
    numeric_json_modes = (
        "get_issue",
        "get_pull_request",
        "list_pull_request_files",
    )
    # Modes that receive the query string unchanged.
    query_modes = (
        "comment_on_issue",
        "create_file",
        "create_pull_request",
        "read_file",
        "update_file",
        "delete_file",
        "set_active_branch",
        "create_branch",
        "get_files_from_directory",
        "search_issues_and_prs",
        "search_code",
        "create_review_request",
        "get_release",
    )
    # Modes that ignore the query.
    no_input_modes = (
        "get_issues",
        "list_open_pull_requests",
        "list_files_in_main_branch",
        "list_files_in_bot_branch",
        "list_branches_in_repo",
        "get_latest_release",
        "get_releases",
    )

    # Handlers are resolved by name only after the mode is validated, so
    # arbitrary attributes can never be invoked through `mode`.
    if mode in numeric_json_modes:
        return json.dumps(getattr(self, mode)(int(query)))
    if mode in query_modes:
        return getattr(self, mode)(query)
    if mode in no_input_modes:
        return getattr(self, mode)()
    # The previous message ran the words together ("Invalid modefoo").
    raise ValueError(f"Invalid mode: {mode}")

View File

@@ -0,0 +1,518 @@
"""Util that calls gitlab."""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, model_validator
if TYPE_CHECKING:
from gitlab.v4.objects import Issue
class GitLabAPIWrapper(BaseModel):
"""Wrapper for GitLab API."""
gitlab: Any = None #: :meta private:
gitlab_repo_instance: Any = None #: :meta private:
gitlab_url: Optional[str] = None
"""The url of the GitLab instance."""
gitlab_repository: Optional[str] = None
"""The name of the GitLab repository, in the form {username}/{repo-name}."""
gitlab_personal_access_token: Optional[str] = None
"""Personal access token for the GitLab service, used for authentication."""
gitlab_branch: Optional[str] = None
"""The specific branch in the GitLab repository where the bot will make
its commits. Defaults to 'main'.
"""
gitlab_base_branch: Optional[str] = None
"""The base branch in the GitLab repository, used for comparisons.
Usually 'main' or 'master'. Defaults to 'main'.
"""
model_config = ConfigDict(
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
    """Resolve settings, authenticate against GitLab and load the project.

    Each setting is read via ``get_from_dict_or_env`` from ``values`` or
    the matching ``GITLAB_*`` environment variable. The authenticated
    client and the project object are stored back into ``values``
    together with the resolved settings.

    Raises:
        ImportError: if the ``gitlab`` package cannot be imported
    """
    url = get_from_dict_or_env(
        values, "gitlab_url", "GITLAB_URL", default="https://gitlab.com"
    )
    repository = get_from_dict_or_env(
        values, "gitlab_repository", "GITLAB_REPOSITORY"
    )
    token = get_from_dict_or_env(
        values, "gitlab_personal_access_token", "GITLAB_PERSONAL_ACCESS_TOKEN"
    )
    branch = get_from_dict_or_env(
        values, "gitlab_branch", "GITLAB_BRANCH", default="main"
    )
    base_branch = get_from_dict_or_env(
        values, "gitlab_base_branch", "GITLAB_BASE_BRANCH", default="main"
    )

    # Imported lazily so the dependency is only needed when the wrapper
    # is actually used.
    try:
        import gitlab
    except ImportError:
        raise ImportError(
            "python-gitlab is not installed. "
            "Please install it with `pip install python-gitlab`"
        )

    client = gitlab.Gitlab(
        url=url,
        private_token=token,
        keep_base_url=True,
    )
    client.auth()

    values["gitlab"] = client
    values["gitlab_repo_instance"] = client.projects.get(repository)
    values.update(
        gitlab_url=url,
        gitlab_repository=repository,
        gitlab_personal_access_token=token,
        gitlab_branch=branch,
        gitlab_base_branch=base_branch,
    )
    return values
def parse_issues(self, issues: List[Issue]) -> List[dict]:
    """
    Reduce gitlab Issue objects to their title and number.

    Parameters:
        issues(List[Issue]): A list of gitlab Issue objects
    Returns:
        List[dict]: One ``{"title": ..., "number": ...}`` dict per issue,
        where the number is the issue's ``iid``
    """
    return [{"title": item.title, "number": item.iid} for item in issues]
def get_issues(self) -> str:
    """
    Report the open issues returned by the project's issue listing.

    Returns:
        str: A plaintext report containing the number of issues
        and each issue's title and number.
    """
    # NOTE(review): list() is called without pagination arguments, so this
    # presumably covers only the first page of results; confirm.
    open_issues = self.gitlab_repo_instance.issues.list(state="opened")
    if len(open_issues) == 0:
        return "No open issues available"
    summaries = self.parse_issues(open_issues)
    return f"Found {len(summaries)} issues:\n{summaries}"
def get_issue(self, issue_number: int) -> Dict[str, Any]:
    """
    Fetches a specific issue and its first 10 comments

    Parameters:
        issue_number(int): The number for the gitlab issue
    Returns:
        `dict` containing the issue's title, body, and comments as a string
    """
    max_comments = 10
    issue = self.gitlab_repo_instance.issues.get(issue_number)
    comments: List[dict] = []
    # GitLab page numbers start at 1; starting at 0 risks fetching the
    # first page twice and reporting duplicated comments.
    page = 1
    while len(comments) < max_comments:
        notes_page = issue.notes.list(page=page)
        if len(notes_page) == 0:
            break
        for listed_note in notes_page:
            if len(comments) >= max_comments:
                # Enforce the documented cap even in the middle of a page.
                break
            # NOTE(review): each note is re-fetched by id as before; the
            # listed objects may already carry body/author - confirm and
            # drop the extra request if so.
            note = issue.notes.get(listed_note.id)
            comments.append(
                {
                    "body": note.body,
                    "user": note.author["username"],
                }
            )
        page += 1

    return {
        "title": issue.title,
        "body": issue.description,
        "comments": str(comments),
    }
def create_pull_request(self, pr_query: str) -> str:
    """
    Makes a pull request from the bot's branch to the base branch

    Parameters:
        pr_query(str): a string which contains the PR title
        and the PR body. The title is the first line
        in the string, and the body are the rest of the string.
        For example, "Updated README\nmade changes to add info".
        A blank line between title and body is also accepted.
    Returns:
        str: A success or failure message
    """
    if self.gitlab_base_branch == self.gitlab_branch:
        return (
            "Cannot make a pull request because\n"
            "commits are already in the master branch"
        )
    try:
        title, _, body = pr_query.partition("\n")
        # Callers send either "title\nbody" or "title\n\nbody". A fixed
        # `len(title) + 2` offset cut the first body character in the
        # single-newline form, so drop at most one extra newline instead.
        if body.startswith("\n"):
            body = body[1:]
        pr = self.gitlab_repo_instance.mergerequests.create(
            {
                "source_branch": self.gitlab_branch,
                "target_branch": self.gitlab_base_branch,
                "title": title,
                "description": body,
                "labels": ["created-by-agent"],
            }
        )
        return f"Successfully created PR number {pr.iid}"
    except Exception as e:
        return "Unable to make pull request due to error:\n" + str(e)
def comment_on_issue(self, comment_query: str) -> str:
    """
    Adds a comment to a gitlab issue

    Parameters:
        comment_query(str): a string which contains the issue number,
        two newlines, and the comment.
        for example: "1\n\nWorking on it now"
        adds the comment "working on it now" to issue 1
    Returns:
        str: A success or failure message
    Raises:
        ValueError: if the text before the first blank line is not an
        integer issue number
    """
    # Split on the separator itself. Deriving the offset from
    # len(str(int(...))) mis-slices inputs such as "007\n\n..." whose raw
    # number text is longer than its canonical form.
    raw_number, _, comment = comment_query.partition("\n\n")
    issue_number = int(raw_number)
    try:
        issue = self.gitlab_repo_instance.issues.get(issue_number)
        issue.notes.create({"body": comment})
        return "Commented on issue " + str(issue_number)
    except Exception as e:
        return "Unable to make comment due to error:\n" + str(e)
def create_file(self, file_query: str) -> str:
    """
    Creates a new file on the gitlab repo

    Parameters:
        file_query(str): a string which contains the file path
        and the file contents. The file path is the first line
        in the string, and the contents are the rest of the string.
        For example, "hello_world.md\n# Hello World!"
    Returns:
        str: A success or failure message
    """
    if self.gitlab_branch == self.gitlab_base_branch:
        return (
            "You're attempting to commit directly "
            f"to the {self.gitlab_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )

    # Everything after the first newline is the content. The previous
    # `len(file_path) + 2` offset silently dropped the first character.
    file_path, _, file_contents = file_query.partition("\n")

    try:
        self.gitlab_repo_instance.files.get(file_path, self.gitlab_branch)
    except Exception:
        # A failed lookup is taken to mean the file does not exist yet.
        pass
    else:
        return f"File already exists at {file_path}. Use update_file instead"

    try:
        self.gitlab_repo_instance.files.create(
            {
                "branch": self.gitlab_branch,
                "commit_message": "Create " + file_path,
                "file_path": file_path,
                "content": file_contents,
            }
        )
        return "Created file " + file_path
    except Exception as e:
        # Report the failure as text, as documented and as the other
        # write operations of this wrapper do, instead of raising.
        return "Unable to make file due to error:\n" + str(e)
def read_file(self, file_path: str) -> str:
    """
    Return the text of a file on the configured branch.

    Parameters:
        file_path(str): the file path
    Returns:
        str: The file decoded as a UTF-8 string
    """
    project_file = self.gitlab_repo_instance.files.get(
        file_path, self.gitlab_branch
    )
    # The file object's decode() result is itself decoded as UTF-8.
    raw_bytes = project_file.decode()
    return raw_bytes.decode("utf-8")
def update_file(self, file_query: str) -> str:
    """
    Updates a file with new content.

    Parameters:
        file_query(str): Contains the file path and the file contents.
        The old file contents is wrapped in OLD <<<< and >>>> OLD
        The new file contents is wrapped in NEW <<<< and >>>> NEW
        For example:
        test/hello.txt
        OLD <<<<
        Hello Earth!
        >>>> OLD
        NEW <<<<
        Hello Mars!
        >>>> NEW
    Returns:
        A success or failure message
    """
    if self.gitlab_branch == self.gitlab_base_branch:
        return (
            "You're attempting to commit directly "
            f"to the {self.gitlab_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )
    try:
        file_path = file_query.split("\n")[0]
        old_file_contents = (
            file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip()
        )
        new_file_contents = (
            file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip()
        )
        if not old_file_contents:
            # str.replace("", new) would insert `new` between every
            # character of the file, so refuse an empty OLD block.
            return (
                "File content was not updated because no old content was "
                "provided. It may be helpful to use the read_file action "
                "to get the current file contents."
            )

        file_content = self.read_file(file_path)
        updated_file_content = file_content.replace(
            old_file_contents, new_file_contents
        )

        if file_content == updated_file_content:
            return (
                "File content was not updated because old content was not found. "
                "It may be helpful to use the read_file action to get "
                "the current file contents."
            )

        commit = {
            "branch": self.gitlab_branch,
            # This is an update; the message previously said "Create".
            "commit_message": "Update " + file_path,
            "actions": [
                {
                    "action": "update",
                    "file_path": file_path,
                    "content": updated_file_content,
                }
            ],
        }

        self.gitlab_repo_instance.commits.create(commit)
        return "Updated file " + file_path
    except Exception as e:
        return "Unable to update file due to error:\n" + str(e)
def delete_file(self, file_path: str) -> str:
    """
    Deletes a file from the repo

    Parameters:
        file_path(str): Where the file is
    Returns:
        str: Success or failure message
    """
    if self.gitlab_branch == self.gitlab_base_branch:
        # The previous message ran two words together ("directlyto").
        return (
            "You're attempting to commit directly "
            f"to the {self.gitlab_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )
    try:
        self.gitlab_repo_instance.files.delete(
            file_path, self.gitlab_branch, "Delete " + file_path
        )
        return "Deleted file " + file_path
    except Exception as e:
        return "Unable to delete file due to error:\n" + str(e)
def list_files_in_main_branch(self) -> str:
    """
    List the files of the repository's base branch.

    Returns:
        str: A plaintext report containing the list of files
        in the repository in the main branch, or a notice when no base
        branch is configured
    """
    base = self.gitlab_base_branch
    return (
        "No base branch set. Please set a base branch."
        if base is None
        else self._list_files(base)
    )
def list_files_in_bot_branch(self) -> str:
    """
    List the files of the repository's active branch.

    Returns:
        str: A plaintext report containing the list of files
        in the repository in the active branch, or a notice when no
        active branch is configured
    """
    active = self.gitlab_branch
    return (
        "No active branch set. Please set a branch."
        if active is None
        else self._list_files(active)
    )
def list_files_from_directory(self, path: str) -> str:
    """
    List the files below one directory of the active branch.

    Parameters:
        path(str): The directory to list
    Returns:
        str: A plaintext report containing the list of files
        in the repository in the active branch from the specified
        directory, or a notice when no active branch is configured
    """
    active = self.gitlab_branch
    if active is not None:
        return self._list_files(branch=active, path=path)
    return "No active branch set. Please set a branch."
def _list_files(self, branch: str, path: str = "") -> str:
    """
    Build the plaintext file report for ``branch`` below ``path``.

    Parameters:
        branch(str): The branch to list
        path(str): Directory to start from; empty for the repository root
    Returns:
        str: The file count and newline-separated paths, a "no files"
        notice, or ``"Error: ..."`` if anything raised while listing
    """
    try:
        found = self._get_repository_files(
            branch=branch,
            path=path,
        )
        if not found:
            return f"No files found in branch: `{branch}`"
        listing = "\n".join(found)
        return f"Found {len(found)} files in branch `{branch}`:\n{listing}"
    except Exception as err:
        return f"Error: {err}"
def _get_repository_files(self, branch: str, path: str = "") -> List[str]:
    """
    Recursively collect the paths of all files below ``path`` on ``branch``.

    Parameters:
        branch(str): The ref whose tree is listed
        path(str): Directory to start from; empty for the repository root
    Returns:
        List[str]: File paths, with directories expanded depth-first in
        listing order
    """
    # Request every page, as the branch listings in this class do; without
    # `all=True` only the first page of each directory is returned and
    # larger directories are silently truncated.
    repo_contents = self.gitlab_repo_instance.repository_tree(
        ref=branch, path=path, all=True
    )

    files: List[str] = []
    for content in repo_contents:
        if content["type"] == "tree":
            # A "tree" entry is a directory: descend into it.
            files.extend(self._get_repository_files(branch, content["path"]))
        else:
            files.append(content["path"])

    return files
def create_branch(self, proposed_branch_name: str) -> str:
    """
    Create a new branch from the current active branch and make it active.

    When the name is already taken, ``_v1``, ``_v2``, ... suffixes are
    tried in turn, for at most 100 creation attempts in total.

    Parameters:
        proposed_branch_name (str): The name of the new branch to be created
    Returns:
        str: A success or failure message
    Raises:
        Exception: if branch creation fails for any reason other than the
        name already existing
    """
    from gitlab import GitlabCreateError

    max_attempts = 100
    candidate = proposed_branch_name
    for attempt in range(1, max_attempts + 1):
        try:
            created = self.gitlab_repo_instance.branches.create(
                {
                    "branch": candidate,
                    "ref": self.gitlab_branch,
                }
            )
            self.gitlab_branch = created.name
            return (
                f"Branch '{created.name}' "
                "created successfully, and set as current active branch."
            )
        except GitlabCreateError as e:
            name_taken = (
                e.response_code == 400
                and "Branch already exists" in e.error_message
            )
            if not name_taken:
                # Handle any other exceptions
                print(f"Failed to create branch. Error: {e}")  # noqa: T201
                raise Exception(
                    "Unable to create branch name from proposed_branch_name: "
                    f"{proposed_branch_name}"
                )
            # Name collision: retry with the next numbered suffix.
            candidate = f"{proposed_branch_name}_v{attempt}"
    return (
        f"Unable to create branch. At least {max_attempts} branches exist "
        f"with named derived from "
        f"proposed_branch_name: `{proposed_branch_name}`"
    )
def list_branches_in_repo(self) -> str:
    """
    Report the names of the repository's branches.

    Returns:
        str: A plaintext report containing the number of branches
        and each branch name
    """
    names = []
    for item in self.gitlab_repo_instance.branches.list(all=True):
        names.append(item.name)

    if not names:
        return "No branches found in the repository"

    joined = "\n".join(names)
    return f"Found {str(len(names))} branches in the repository:\n{joined}"
def set_active_branch(self, branch_name: str) -> str:
    """Make ``branch_name`` the active branch if it exists in the repo.

    Equivalent to `git checkout branch_name` for this Agent.

    Returns a confirmation, or an error (as a string) listing the current
    branches if ``branch_name`` doesn't exist.
    """
    existing = []
    for item in self.gitlab_repo_instance.branches.list(all=True):
        existing.append(item.name)

    if branch_name not in existing:
        return (
            f"Error {branch_name} does not exist,"
            f"in repo with current branches: {str(existing)}"
        )

    self.gitlab_branch = branch_name
    return f"Switched to branch `{branch_name}`"
def run(self, mode: str, query: str) -> str:
    """
    Dispatch a tool invocation to the method named by ``mode``.

    Parameters:
        mode(str): Name of the operation; it equals the name of the
        method that implements it
        query(str): Raw tool input. Converted to ``int`` for ``get_issue``
        and ignored by modes that take no input
    Returns:
        str: The method's result (JSON-encoded for ``get_issue``)
    Raises:
        ValueError: if ``mode`` is not a supported operation
    """
    # Modes that receive the query string unchanged.
    query_modes = (
        "comment_on_issue",
        "create_file",
        "create_pull_request",
        "read_file",
        "update_file",
        "delete_file",
        "create_branch",
        "set_active_branch",
        "list_files_from_directory",
    )
    # Modes that ignore the query.
    no_input_modes = (
        "get_issues",
        "list_branches_in_repo",
        "list_files_in_main_branch",
        "list_files_in_bot_branch",
    )

    # Handlers are resolved by name only after the mode is validated, so
    # arbitrary attributes can never be invoked through `mode`.
    if mode == "get_issue":
        return json.dumps(self.get_issue(int(query)))
    if mode in query_modes:
        return getattr(self, mode)(query)
    if mode in no_input_modes:
        return getattr(self, mode)()
    # The previous message ran the words together ("Invalid modefoo").
    raise ValueError(f"Invalid mode: {mode}")

Some files were not shown because too many files have changed in this diff Show More