initial commit
This commit is contained in:
323
venv/Lib/site-packages/langchain_community/utilities/__init__.py
Normal file
323
venv/Lib/site-packages/langchain_community/utilities/__init__.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""**Utilities** are the integrations with third-party systems and packages.
|
||||
|
||||
Other LangChain classes use **Utilities** to interact with third-party systems
|
||||
and packages.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.utilities.alpha_vantage import (
|
||||
AlphaVantageAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.apify import (
|
||||
ApifyWrapper,
|
||||
)
|
||||
from langchain_community.utilities.arcee import (
|
||||
ArceeWrapper,
|
||||
)
|
||||
from langchain_community.utilities.arxiv import (
|
||||
ArxivAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.asknews import (
|
||||
AskNewsAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.awslambda import (
|
||||
LambdaWrapper,
|
||||
)
|
||||
from langchain_community.utilities.bibtex import (
|
||||
BibtexparserWrapper,
|
||||
)
|
||||
from langchain_community.utilities.bing_search import (
|
||||
BingSearchAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.brave_search import (
|
||||
BraveSearchWrapper,
|
||||
)
|
||||
from langchain_community.utilities.dataherald import DataheraldAPIWrapper
|
||||
from langchain_community.utilities.dria_index import (
|
||||
DriaAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.duckduckgo_search import (
|
||||
DuckDuckGoSearchAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.golden_query import (
|
||||
GoldenQueryAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.google_books import (
|
||||
GoogleBooksAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.google_finance import (
|
||||
GoogleFinanceAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.google_jobs import (
|
||||
GoogleJobsAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.google_lens import (
|
||||
GoogleLensAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.google_places_api import (
|
||||
GooglePlacesAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.google_scholar import (
|
||||
GoogleScholarAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.google_search import (
|
||||
GoogleSearchAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.google_serper import (
|
||||
GoogleSerperAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.google_trends import (
|
||||
GoogleTrendsAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.graphql import (
|
||||
GraphQLAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.infobip import (
|
||||
InfobipAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.jira import (
|
||||
JiraAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.max_compute import (
|
||||
MaxComputeAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.merriam_webster import (
|
||||
MerriamWebsterAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.metaphor_search import (
|
||||
MetaphorSearchAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.mojeek_search import (
|
||||
MojeekSearchAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.nasa import (
|
||||
NasaAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.nvidia_riva import (
|
||||
AudioStream,
|
||||
NVIDIARivaASR,
|
||||
NVIDIARivaStream,
|
||||
NVIDIARivaTTS,
|
||||
RivaASR,
|
||||
RivaTTS,
|
||||
)
|
||||
from langchain_community.utilities.openweathermap import (
|
||||
OpenWeatherMapAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.oracleai import (
|
||||
OracleSummary,
|
||||
)
|
||||
from langchain_community.utilities.outline import (
|
||||
OutlineAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.passio_nutrition_ai import (
|
||||
NutritionAIAPI,
|
||||
)
|
||||
from langchain_community.utilities.portkey import (
|
||||
Portkey,
|
||||
)
|
||||
from langchain_community.utilities.powerbi import (
|
||||
PowerBIDataset,
|
||||
)
|
||||
from langchain_community.utilities.pubmed import (
|
||||
PubMedAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.rememberizer import RememberizerAPIWrapper
|
||||
from langchain_community.utilities.requests import (
|
||||
Requests,
|
||||
RequestsWrapper,
|
||||
TextRequestsWrapper,
|
||||
)
|
||||
from langchain_community.utilities.scenexplain import (
|
||||
SceneXplainAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.searchapi import (
|
||||
SearchApiAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.searx_search import (
|
||||
SearxSearchWrapper,
|
||||
)
|
||||
from langchain_community.utilities.serpapi import (
|
||||
SerpAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.spark_sql import (
|
||||
SparkSQL,
|
||||
)
|
||||
from langchain_community.utilities.sql_database import (
|
||||
SQLDatabase,
|
||||
)
|
||||
from langchain_community.utilities.stackexchange import (
|
||||
StackExchangeAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.steam import (
|
||||
SteamWebAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.tensorflow_datasets import (
|
||||
TensorflowDatasets,
|
||||
)
|
||||
from langchain_community.utilities.twilio import (
|
||||
TwilioAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.wikipedia import (
|
||||
WikipediaAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.wolfram_alpha import (
|
||||
WolframAlphaAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.you import (
|
||||
YouSearchAPIWrapper,
|
||||
)
|
||||
from langchain_community.utilities.zapier import (
|
||||
ZapierNLAWrapper,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AlphaVantageAPIWrapper",
|
||||
"ApifyWrapper",
|
||||
"ArceeWrapper",
|
||||
"ArxivAPIWrapper",
|
||||
"AskNewsAPIWrapper",
|
||||
"AudioStream",
|
||||
"BibtexparserWrapper",
|
||||
"BingSearchAPIWrapper",
|
||||
"BraveSearchWrapper",
|
||||
"DataheraldAPIWrapper",
|
||||
"DriaAPIWrapper",
|
||||
"DuckDuckGoSearchAPIWrapper",
|
||||
"GoldenQueryAPIWrapper",
|
||||
"GoogleBooksAPIWrapper",
|
||||
"GoogleFinanceAPIWrapper",
|
||||
"GoogleJobsAPIWrapper",
|
||||
"GoogleLensAPIWrapper",
|
||||
"GooglePlacesAPIWrapper",
|
||||
"GoogleScholarAPIWrapper",
|
||||
"GoogleSearchAPIWrapper",
|
||||
"GoogleSerperAPIWrapper",
|
||||
"GoogleTrendsAPIWrapper",
|
||||
"GraphQLAPIWrapper",
|
||||
"InfobipAPIWrapper",
|
||||
"JiraAPIWrapper",
|
||||
"LambdaWrapper",
|
||||
"MaxComputeAPIWrapper",
|
||||
"MerriamWebsterAPIWrapper",
|
||||
"MetaphorSearchAPIWrapper",
|
||||
"MojeekSearchAPIWrapper",
|
||||
"NVIDIARivaASR",
|
||||
"NVIDIARivaStream",
|
||||
"NVIDIARivaTTS",
|
||||
"NasaAPIWrapper",
|
||||
"NutritionAIAPI",
|
||||
"OpenWeatherMapAPIWrapper",
|
||||
"OracleSummary",
|
||||
"OutlineAPIWrapper",
|
||||
"Portkey",
|
||||
"PowerBIDataset",
|
||||
"PubMedAPIWrapper",
|
||||
"RememberizerAPIWrapper",
|
||||
"Requests",
|
||||
"RequestsWrapper",
|
||||
"RivaASR",
|
||||
"RivaTTS",
|
||||
"SceneXplainAPIWrapper",
|
||||
"SearchApiAPIWrapper",
|
||||
"SQLDatabase",
|
||||
"SearxSearchWrapper",
|
||||
"SerpAPIWrapper",
|
||||
"SparkSQL",
|
||||
"StackExchangeAPIWrapper",
|
||||
"SteamWebAPIWrapper",
|
||||
"TensorflowDatasets",
|
||||
"TextRequestsWrapper",
|
||||
"TwilioAPIWrapper",
|
||||
"WikipediaAPIWrapper",
|
||||
"WolframAlphaAPIWrapper",
|
||||
"YouSearchAPIWrapper",
|
||||
"ZapierNLAWrapper",
|
||||
]
|
||||
|
||||
# Maps each lazily exported name to the submodule that defines it.
_module_lookup = {
    "AlphaVantageAPIWrapper": "langchain_community.utilities.alpha_vantage",
    "ApifyWrapper": "langchain_community.utilities.apify",
    "ArceeWrapper": "langchain_community.utilities.arcee",
    "ArxivAPIWrapper": "langchain_community.utilities.arxiv",
    "AskNewsAPIWrapper": "langchain_community.utilities.asknews",
    "AudioStream": "langchain_community.utilities.nvidia_riva",
    "BibtexparserWrapper": "langchain_community.utilities.bibtex",
    "BingSearchAPIWrapper": "langchain_community.utilities.bing_search",
    "BraveSearchWrapper": "langchain_community.utilities.brave_search",
    "DataheraldAPIWrapper": "langchain_community.utilities.dataherald",
    "DriaAPIWrapper": "langchain_community.utilities.dria_index",
    "DuckDuckGoSearchAPIWrapper": "langchain_community.utilities.duckduckgo_search",
    "GoldenQueryAPIWrapper": "langchain_community.utilities.golden_query",
    "GoogleBooksAPIWrapper": "langchain_community.utilities.google_books",
    "GoogleFinanceAPIWrapper": "langchain_community.utilities.google_finance",
    "GoogleJobsAPIWrapper": "langchain_community.utilities.google_jobs",
    "GoogleLensAPIWrapper": "langchain_community.utilities.google_lens",
    "GooglePlacesAPIWrapper": "langchain_community.utilities.google_places_api",
    "GoogleScholarAPIWrapper": "langchain_community.utilities.google_scholar",
    "GoogleSearchAPIWrapper": "langchain_community.utilities.google_search",
    "GoogleSerperAPIWrapper": "langchain_community.utilities.google_serper",
    "GoogleTrendsAPIWrapper": "langchain_community.utilities.google_trends",
    "GraphQLAPIWrapper": "langchain_community.utilities.graphql",
    "InfobipAPIWrapper": "langchain_community.utilities.infobip",
    "JiraAPIWrapper": "langchain_community.utilities.jira",
    "LambdaWrapper": "langchain_community.utilities.awslambda",
    "MaxComputeAPIWrapper": "langchain_community.utilities.max_compute",
    "MerriamWebsterAPIWrapper": "langchain_community.utilities.merriam_webster",
    "MetaphorSearchAPIWrapper": "langchain_community.utilities.metaphor_search",
    "MojeekSearchAPIWrapper": "langchain_community.utilities.mojeek_search",
    "NVIDIARivaASR": "langchain_community.utilities.nvidia_riva",
    "NVIDIARivaStream": "langchain_community.utilities.nvidia_riva",
    "NVIDIARivaTTS": "langchain_community.utilities.nvidia_riva",
    "NasaAPIWrapper": "langchain_community.utilities.nasa",
    "NutritionAIAPI": "langchain_community.utilities.passio_nutrition_ai",
    "OpenWeatherMapAPIWrapper": "langchain_community.utilities.openweathermap",
    "OracleSummary": "langchain_community.utilities.oracleai",
    "OutlineAPIWrapper": "langchain_community.utilities.outline",
    "Portkey": "langchain_community.utilities.portkey",
    "PowerBIDataset": "langchain_community.utilities.powerbi",
    "PubMedAPIWrapper": "langchain_community.utilities.pubmed",
    "RememberizerAPIWrapper": "langchain_community.utilities.rememberizer",
    "Requests": "langchain_community.utilities.requests",
    "RequestsWrapper": "langchain_community.utilities.requests",
    "RivaASR": "langchain_community.utilities.nvidia_riva",
    "RivaTTS": "langchain_community.utilities.nvidia_riva",
    "SQLDatabase": "langchain_community.utilities.sql_database",
    "SceneXplainAPIWrapper": "langchain_community.utilities.scenexplain",
    "SearchApiAPIWrapper": "langchain_community.utilities.searchapi",
    "SearxSearchWrapper": "langchain_community.utilities.searx_search",
    "SerpAPIWrapper": "langchain_community.utilities.serpapi",
    "SparkSQL": "langchain_community.utilities.spark_sql",
    "StackExchangeAPIWrapper": "langchain_community.utilities.stackexchange",
    "SteamWebAPIWrapper": "langchain_community.utilities.steam",
    "TensorflowDatasets": "langchain_community.utilities.tensorflow_datasets",
    "TextRequestsWrapper": "langchain_community.utilities.requests",
    "TwilioAPIWrapper": "langchain_community.utilities.twilio",
    "WikipediaAPIWrapper": "langchain_community.utilities.wikipedia",
    "WolframAlphaAPIWrapper": "langchain_community.utilities.wolfram_alpha",
    "YouSearchAPIWrapper": "langchain_community.utilities.you",
    "ZapierNLAWrapper": "langchain_community.utilities.zapier",
}

# Names that used to be exported here, mapped to the message raised when
# someone still asks for them.
REMOVED = {
    "PythonREPL": (
        "PythonREPL has been deprecated from langchain_community "
        "due to being flagged by security scanners. See: "
        "https://github.com/langchain-ai/langchain/issues/14345 "
        "If you need to use it, please use the version "
        "from langchain_experimental. "
        "from langchain_experimental.utilities.python import PythonREPL."
    )
}


def __getattr__(name: str) -> Any:
    """Resolve a package attribute on first access.

    Args:
        name: Attribute being looked up on this module.

    Returns:
        The object called ``name`` taken from the submodule listed for it in
        ``_module_lookup``; that submodule is imported at this point.

    Raises:
        AssertionError: If ``name`` is listed in ``REMOVED``; the stored
            message is used as the error text.
        AttributeError: If ``name`` is in neither table.
    """
    # Removed names are checked first so they always report the removal notice.
    removal_notice = REMOVED.get(name)
    if removal_notice is not None:
        raise AssertionError(removal_notice)

    module_path = _module_lookup.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__} has no attribute {name}")

    return getattr(importlib.import_module(module_path), name)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,176 @@
|
||||
"""Util that calls AlphaVantage for Currency Exchange Rate."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
|
||||
class AlphaVantageAPIWrapper(BaseModel):
    """Wrapper for the AlphaVantage query API.

    Besides currency exchange rates, the wrapper exposes symbol search, news
    sentiment, daily/weekly time series, global quote and top gainers/losers
    requests.

    Docs for using:

    1. Go to AlphaVantage and sign up for an API key
    2. Save your API KEY into ALPHAVANTAGE_API_KEY env variable
    """

    alphavantage_api_key: Optional[str] = None

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that api key exists in environment.

        The key is read through ``get_from_dict_or_env`` using the
        ``alphavantage_api_key`` entry of ``values`` and the
        ``ALPHAVANTAGE_API_KEY`` environment variable, then stored back
        into ``values``.
        """
        values["alphavantage_api_key"] = get_from_dict_or_env(
            values, "alphavantage_api_key", "ALPHAVANTAGE_API_KEY"
        )
        return values

    def _query(self, function: str, **params: Any) -> Dict[str, Any]:
        """Send one request to the AlphaVantage query endpoint.

        Args:
            function: Value sent as the ``function`` query parameter.
            **params: Extra query parameters, sent in the order given.

        Returns:
            The decoded JSON body of the response.

        Raises:
            requests.HTTPError: Propagated from ``raise_for_status``.
            ValueError: If the payload contains an ``"Error Message"`` key.
        """
        # NOTE(review): no timeout is passed to requests.get here (same as the
        # per-method calls this helper replaces) -- consider adding one.
        response = requests.get(
            "https://www.alphavantage.co/query/",
            params={
                "function": function,
                **params,
                # The key goes last, matching the original parameter order.
                "apikey": self.alphavantage_api_key,
            },
        )
        response.raise_for_status()
        data = response.json()

        if "Error Message" in data:
            raise ValueError(f"API Error: {data['Error Message']}")

        return data

    def search_symbols(self, keywords: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to search for symbols."""
        return self._query("SYMBOL_SEARCH", keywords=keywords)

    def _get_market_news_sentiment(self, symbol: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get market news sentiment for a
        given symbol."""
        return self._query("NEWS_SENTIMENT", symbol=symbol)

    def _get_time_series_daily(self, symbol: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get the daily time series."""
        return self._query("TIME_SERIES_DAILY", symbol=symbol)

    def _get_quote_endpoint(self, symbol: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get the
        latest price and volume information."""
        return self._query("GLOBAL_QUOTE", symbol=symbol)

    def _get_time_series_weekly(self, symbol: str) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API
        to get the Weekly Time Series."""
        return self._query("TIME_SERIES_WEEKLY", symbol=symbol)

    def _get_top_gainers_losers(self) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get the top gainers, losers,
        and most actively traded tickers in the US market."""
        return self._query("TOP_GAINERS_LOSERS")

    def _get_exchange_rate(
        self, from_currency: str, to_currency: str
    ) -> Dict[str, Any]:
        """Make a request to the AlphaVantage API to get the exchange rate."""
        return self._query(
            "CURRENCY_EXCHANGE_RATE",
            from_currency=from_currency,
            to_currency=to_currency,
        )

    @property
    def standard_currencies(self) -> List[str]:
        """Currency codes that ``run`` accepts as a target without swapping."""
        return ["USD", "EUR", "GBP", "JPY", "CHF", "CAD", "AUD", "NZD"]

    def run(self, from_currency: str, to_currency: str) -> str:
        """Get the current exchange rate for a specified currency pair.

        If ``to_currency`` is not one of ``standard_currencies`` the two
        arguments are swapped before the request is made.

        Returns:
            The ``"Realtime Currency Exchange Rate"`` entry of the response.

        Raises:
            KeyError: If the response has no such entry.
        """
        # NOTE(review): the annotation says ``str`` but the value comes straight
        # from the parsed JSON payload and is presumably a mapping -- confirm.
        if to_currency not in self.standard_currencies:
            from_currency, to_currency = to_currency, from_currency

        data = self._get_exchange_rate(from_currency, to_currency)
        return data["Realtime Currency Exchange Rate"]
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import Any, List
|
||||
|
||||
|
||||
def _get_anthropic_client() -> Any:
    """Create a client from the optional ``anthropic`` package.

    The import happens inside the function, so ``anthropic`` is only needed
    when this is called.

    Returns:
        The object returned by ``anthropic.Anthropic()``.

    Raises:
        ImportError: If ``anthropic`` cannot be imported; the original import
            failure is attached as the cause.
    """
    try:
        import anthropic
    except ImportError as err:
        raise ImportError(
            "Could not import anthropic python package. "
            "This is needed in order to accurately tokenize the text "
            "for anthropic models. Please install it with `pip install anthropic`."
        ) from err
    return anthropic.Anthropic()
|
||||
|
||||
|
||||
def get_num_tokens_anthropic(text: str) -> int:
    """Count the tokens in ``text`` using the anthropic client.

    Args:
        text: The string to measure.

    Returns:
        Whatever the client's ``count_tokens(text=...)`` call returns.
    """
    # NOTE(review): relies on the client object exposing ``count_tokens``;
    # confirm this against the installed anthropic SDK version.
    return _get_anthropic_client().count_tokens(text=text)
|
||||
|
||||
|
||||
def get_token_ids_anthropic(text: str) -> List[int]:
    """Encode ``text`` with the anthropic tokenizer and return the token ids.

    Args:
        text: The string to tokenize.

    Returns:
        The ``ids`` attribute of the encoding produced by the tokenizer.
    """
    # NOTE(review): relies on the client object exposing ``get_tokenizer``;
    # confirm this against the installed anthropic SDK version.
    tokenizer = _get_anthropic_client().get_tokenizer()
    return tokenizer.encode(text).ids
|
||||
227
venv/Lib/site-packages/langchain_community/utilities/apify.py
Normal file
227
venv/Lib/site-packages/langchain_community/utilities/apify.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_community.document_loaders import ApifyDatasetLoader
|
||||
|
||||
|
||||
@deprecated(
    since="0.3.18",
    message=(
        "This class is deprecated and will be removed in a future version. "
        "You can swap to using the `ApifyWrapper`"
        " implementation in `langchain_apify` package. "
        "See <https://github.com/apify/langchain-apify>"
    ),
    alternative_import="langchain_apify.ApifyWrapper",
)
class ApifyWrapper(BaseModel):
    """Wrapper around Apify.

    To use, you should have the ``apify-client`` python package installed,
    and the environment variable ``APIFY_API_TOKEN`` set with your API key, or pass
    `apify_api_token` as a named parameter to the constructor.
    """

    # Both clients are created by ``validate_environment``.
    apify_client: Any
    apify_client_async: Any
    apify_api_token: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate environment.

        Validate that an Apify API token is set and the apify-client
        Python package exists in the current environment. On success the
        sync and async clients are stored in ``values`` under
        ``apify_client`` and ``apify_client_async``.

        Raises:
            ImportError: If ``apify_client`` cannot be imported.
        """
        apify_api_token = get_from_dict_or_env(
            values, "apify_api_token", "APIFY_API_TOKEN"
        )

        # Only the import is guarded, so unrelated failures while building the
        # clients are not mistaken for a missing package.
        try:
            from apify_client import ApifyClient, ApifyClientAsync
        except ImportError as err:
            raise ImportError(
                "Could not import apify-client Python package. "
                "Please install it with `pip install apify-client`."
            ) from err

        client = ApifyClient(apify_api_token)
        # A default of None makes the header tagging genuinely optional: if the
        # attribute is missing the step is skipped and nothing is raised.
        if httpx_client := getattr(client.http_client, "httpx_client", None):
            httpx_client.headers["user-agent"] += "; Origin/langchain"

        async_client = ApifyClientAsync(apify_api_token)
        if httpx_async_client := getattr(
            async_client.http_client, "httpx_async_client", None
        ):
            httpx_async_client.headers["user-agent"] += "; Origin/langchain"

        values["apify_client"] = client
        values["apify_client_async"] = async_client

        return values

    @staticmethod
    def _loader_for_run(
        run: Dict, dataset_mapping_function: Callable[[Dict], Document]
    ) -> "ApifyDatasetLoader":
        """Build a loader for the ``defaultDatasetId`` of a finished run."""
        # Imported here, as the public methods did originally, not at module
        # level.
        from langchain_community.document_loaders import ApifyDatasetLoader

        return ApifyDatasetLoader(
            dataset_id=run["defaultDatasetId"],
            dataset_mapping_function=dataset_mapping_function,
        )

    def call_actor(
        self,
        actor_id: str,
        run_input: Dict,
        dataset_mapping_function: Callable[[Dict], Document],
        *,
        build: Optional[str] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
    ) -> "ApifyDatasetLoader":
        """Run an Actor on the Apify platform and wait for results to be ready.

        Args:
            actor_id (str): The ID or name of the Actor on the Apify platform.
            run_input (Dict): The input object of the Actor that you're trying to run.
            dataset_mapping_function (Callable): A function that takes a single
                dictionary (an Apify dataset item) and converts it to an
                instance of the Document class.
            build (str, optional): Optionally specifies the actor build to run.
                It can be either a build tag or build number.
            memory_mbytes (int, optional): Optional memory limit for the run,
                in megabytes.
            timeout_secs (int, optional): Optional timeout for the run, in seconds.

        Returns:
            ApifyDatasetLoader: A loader that will fetch the records from the
                Actor run's default dataset.
        """
        actor_call = self.apify_client.actor(actor_id).call(
            run_input=run_input,
            build=build,
            memory_mbytes=memory_mbytes,
            timeout_secs=timeout_secs,
        )

        return self._loader_for_run(actor_call, dataset_mapping_function)

    async def acall_actor(
        self,
        actor_id: str,
        run_input: Dict,
        dataset_mapping_function: Callable[[Dict], Document],
        *,
        build: Optional[str] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
    ) -> "ApifyDatasetLoader":
        """Run an Actor on the Apify platform and wait for results to be ready.

        Async counterpart of ``call_actor``; it awaits the async client.

        Args:
            actor_id (str): The ID or name of the Actor on the Apify platform.
            run_input (Dict): The input object of the Actor that you're trying to run.
            dataset_mapping_function (Callable): A function that takes a single
                dictionary (an Apify dataset item) and converts it to
                an instance of the Document class.
            build (str, optional): Optionally specifies the actor build to run.
                It can be either a build tag or build number.
            memory_mbytes (int, optional): Optional memory limit for the run,
                in megabytes.
            timeout_secs (int, optional): Optional timeout for the run, in seconds.

        Returns:
            ApifyDatasetLoader: A loader that will fetch the records from the
                Actor run's default dataset.
        """
        actor_call = await self.apify_client_async.actor(actor_id).call(
            run_input=run_input,
            build=build,
            memory_mbytes=memory_mbytes,
            timeout_secs=timeout_secs,
        )

        return self._loader_for_run(actor_call, dataset_mapping_function)

    def call_actor_task(
        self,
        task_id: str,
        task_input: Dict,
        dataset_mapping_function: Callable[[Dict], Document],
        *,
        build: Optional[str] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
    ) -> "ApifyDatasetLoader":
        """Run a saved Actor task on Apify and wait for results to be ready.

        Args:
            task_id (str): The ID or name of the task on the Apify platform.
            task_input (Dict): The input object of the task that you're trying to run.
                Overrides the task's saved input.
            dataset_mapping_function (Callable): A function that takes a single
                dictionary (an Apify dataset item) and converts it to an
                instance of the Document class.
            build (str, optional): Optionally specifies the actor build to run.
                It can be either a build tag or build number.
            memory_mbytes (int, optional): Optional memory limit for the run,
                in megabytes.
            timeout_secs (int, optional): Optional timeout for the run, in seconds.

        Returns:
            ApifyDatasetLoader: A loader that will fetch the records from the
                task run's default dataset.
        """
        task_call = self.apify_client.task(task_id).call(
            task_input=task_input,
            build=build,
            memory_mbytes=memory_mbytes,
            timeout_secs=timeout_secs,
        )

        return self._loader_for_run(task_call, dataset_mapping_function)

    async def acall_actor_task(
        self,
        task_id: str,
        task_input: Dict,
        dataset_mapping_function: Callable[[Dict], Document],
        *,
        build: Optional[str] = None,
        memory_mbytes: Optional[int] = None,
        timeout_secs: Optional[int] = None,
    ) -> "ApifyDatasetLoader":
        """Run a saved Actor task on Apify and wait for results to be ready.

        Async counterpart of ``call_actor_task``; it awaits the async client.

        Args:
            task_id (str): The ID or name of the task on the Apify platform.
            task_input (Dict): The input object of the task that you're trying to run.
                Overrides the task's saved input.
            dataset_mapping_function (Callable): A function that takes a single
                dictionary (an Apify dataset item) and converts it to an
                instance of the Document class.
            build (str, optional): Optionally specifies the actor build to run.
                It can be either a build tag or build number.
            memory_mbytes (int, optional): Optional memory limit for the run,
                in megabytes.
            timeout_secs (int, optional): Optional timeout for the run, in seconds.

        Returns:
            ApifyDatasetLoader: A loader that will fetch the records from the
                task run's default dataset.
        """
        task_call = await self.apify_client_async.task(task_id).call(
            task_input=task_input,
            build=build,
            memory_mbytes=memory_mbytes,
            timeout_secs=timeout_secs,
        )

        return self._loader_for_run(task_call, dataset_mapping_function)
|
||||
256
venv/Lib/site-packages/langchain_community/utilities/arcee.py
Normal file
256
venv/Lib/site-packages/langchain_community/utilities/arcee.py
Normal file
@@ -0,0 +1,256 @@
|
||||
# This module contains utility classes and functions for interacting with Arcee API.
|
||||
# For more information and updates, refer to the Arcee utils page:
|
||||
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
|
||||
|
||||
import requests
|
||||
from langchain_core.retrievers import Document
|
||||
from pydantic import BaseModel, SecretStr, model_validator
|
||||
|
||||
|
||||
class ArceeRoute(str, Enum):
    """Enumeration of the Arcee API route paths.

    Members are also ``str`` instances. ``model_training_status`` contains an
    ``{id_or_name}`` placeholder.
    """

    generate = "models/generate"
    retrieve = "models/retrieve"
    model_training_status = "models/status/{id_or_name}"
|
||||
|
||||
|
||||
class DALMFilterType(str, Enum):
    """Enumeration of the filter kinds accepted for a DALM retrieval.

    Members are also ``str`` instances whose value equals their name.
    """

    fuzzy_search = "fuzzy_search"
    strict_search = "strict_search"
|
||||
|
||||
|
||||
class DALMFilter(BaseModel):
    """A single filter applied to a DALM retrieval or generation request.

    Arguments:
        field_name: The field to filter on. Can be 'document' or 'name' to filter
            on your document's raw text or title. Any other field will be presumed
            to be a metadata field you included when uploading your context data
        filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
            'fuzzy_search' means a fuzzy search on the provided field is performed.
            The exact string doesn't need to exist in the document
            for this to find a match.
            Very useful for scanning a document for some keyword terms.
            'strict_search' means that the exact string must appear
            in the provided field.
            This is NOT an exact eq filter. ie a document with content
            "the happy dog crossed the street" will match on a strict_search of
            "dog" but won't match on "the dog".
            Python equivalent of `return search_string in full_string`.
        value: The actual value to search for in the context data/metadata
    """

    field_name: str
    filter_type: DALMFilterType
    value: str
    _is_metadata: bool = False

    @model_validator(mode="before")
    @classmethod
    def set_meta(cls, values: Dict) -> Any:
        """document and name are reserved arcee keys. Anything else is metadata"""
        reserved_keys = ["document", "name"]
        # NOTE(review): the flag is written under the key "_is_meta", which is
        # not the declared private attribute "_is_metadata" -- confirm which
        # name is intended before relying on either.
        is_meta = values.get("field_name") not in reserved_keys
        values["_is_meta"] = is_meta
        return values
|
||||
|
||||
|
||||
class ArceeDocumentSource(BaseModel):
    """Source payload of a document returned by Arcee."""

    # Text of the source document.
    document: str
    # Name of the source document.
    name: str
    # Identifier of the source document.
    id: str
|
||||
|
||||
|
||||
class ArceeDocument(BaseModel):
    """One scored retrieval result returned by Arcee."""

    # Name of the index the result belongs to.
    index: str
    # Identifier of the result itself (the source has its own ``id``).
    id: str
    # Score attached to the result.
    score: float
    # Nested source payload (text, name and id of the underlying document).
    source: ArceeDocumentSource
|
||||
|
||||
|
||||
class ArceeDocumentAdapter:
    """Converts Arcee documents into langchain documents."""

    @classmethod
    def adapt(cls, arcee_document: ArceeDocument) -> Document:
        """Build a langchain `Document` from an `ArceeDocument`.

        The source text becomes ``page_content``; the source name/id and the
        result's index/id/score are copied into ``metadata``.
        """
        source = arcee_document.source
        content = source.document
        metadata = {
            # fields taken from the nested source
            "name": source.name,
            "source_id": source.id,
            # fields taken from the arcee document itself
            "index": arcee_document.index,
            "id": arcee_document.id,
            "score": arcee_document.score,
        }
        return Document(page_content=content, metadata=metadata)
|
||||
|
||||
|
||||
class ArceeWrapper:
    """Wrapper for Arcee API.

    For more details, see: https://www.arcee.ai/
    """

    def __init__(
        self,
        arcee_api_key: Union[str, SecretStr],
        arcee_api_url: str,
        arcee_api_version: str,
        model_kwargs: Optional[Dict[str, Any]],
        model_name: str,
    ):
        """Initialize ArceeWrapper.

        The constructor issues a GET request to the model training status
        route to resolve ``model_id`` and ``model_training_status``.

        Arguments:
            arcee_api_key: API key for Arcee API.
            arcee_api_url: URL for Arcee API.
            arcee_api_version: Version of Arcee API.
            model_kwargs: Keyword arguments for Arcee API.
            model_name: Name of an Arcee model.

        Raises:
            ValueError: If the model training status request fails.
        """
        if isinstance(arcee_api_key, str):
            arcee_api_key_ = SecretStr(arcee_api_key)
        else:
            arcee_api_key_ = arcee_api_key
        self.arcee_api_key: SecretStr = arcee_api_key_
        self.model_kwargs = model_kwargs
        self.arcee_api_url = arcee_api_url
        self.arcee_api_version = arcee_api_version

        try:
            route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
            response = self._make_request("get", route)
            self.model_id = response.get("model_id")
            self.model_training_status = response.get("status")
        except Exception as e:
            raise ValueError(
                f"Error while validating model training status for '{model_name}': {e}"
            ) from e

    def validate_model_training_status(self) -> None:
        """Raise if the status fetched at construction is not "training_complete".

        Raises:
            Exception: If the model has not finished training.
        """
        if self.model_training_status != "training_complete":
            raise Exception(
                f"Model {self.model_id} is not ready. "
                "Please wait for training to complete."
            )

    def _make_request(
        self,
        method: Literal["post", "get"],
        route: Union[ArceeRoute, str],
        body: Optional[Mapping[str, Any]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> dict:
        """Make a request to the Arcee API

        Args:
            method: The HTTP method to use
            route: The route to call
            body: The body of the request
            params: The query params of the request
            headers: The headers of the request

        Returns:
            The decoded JSON body of the response.

        Raises:
            Exception: If the response status code is neither 200 nor 201.
        """
        headers = self._make_request_headers(headers=headers)
        url = self._make_request_url(route=route)

        # ``method`` names the ``requests`` function to call (``get``/``post``).
        req_type = getattr(requests, method)

        response = req_type(url, json=body, params=params, headers=headers)
        if response.status_code not in (200, 201):
            raise Exception(f"Failed to make request. Response: {response.text}")
        return response.json()

    def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
        """Return the request headers: caller headers plus the API headers.

        A new dict is returned; the ``headers`` argument is left untouched.
        ``X-Token`` and ``Content-Type`` always take precedence over
        caller-supplied values for the same keys.

        Raises:
            TypeError: If ``self.arcee_api_key`` is not a ``SecretStr``.
        """
        if not isinstance(self.arcee_api_key, SecretStr):
            raise TypeError(
                f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
            )
        return {
            **(headers or {}),
            "X-Token": self.arcee_api_key.get_secret_value(),
            "Content-Type": "application/json",
        }

    def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
        """Join the API base URL, the API version and ``route`` with slashes."""
        return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}"

    def _make_request_body_for_models(
        self, prompt: str, **kwargs: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        """Make the request body for generate/retrieve models endpoint

        Call-time ``kwargs`` override the wrapper-level ``model_kwargs``. Only
        ``size`` (default 3) and ``filters`` are read from the merged options.
        """
        _model_kwargs = self.model_kwargs or {}
        _params = {**_model_kwargs, **kwargs}

        # Validate each filter through ``DALMFilter`` and then dump it back to
        # plain JSON-compatible data: the body is sent with ``json=``, which
        # cannot serialize pydantic model instances.
        filters = [
            DALMFilter(**f).model_dump(mode="json")
            for f in (_params.get("filters") or [])
        ]
        return dict(
            model_id=self.model_id,
            query=prompt,
            size=_params.get("size", 3),
            filters=filters,
            id=self.model_id,
        )

    def generate(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Generate text from Arcee DALM.

        Args:
            prompt: Prompt to generate text from.
            size: The max number of context results to retrieve. Defaults to 3.
                (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            The ``"text"`` entry of the API response.
        """
        response = self._make_request(
            method="post",
            route=ArceeRoute.generate.value,
            body=self._make_request_body_for_models(
                prompt=prompt,
                **kwargs,
            ),
        )
        return response["text"]

    def retrieve(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve {size} contexts with your retriever for a given query

        Args:
            query: Query to submit to the model
            size: The max number of context results to retrieve. Defaults to 3.
                (Can be less if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            One langchain ``Document`` per entry in the response's ``"results"``.
        """
        response = self._make_request(
            method="post",
            route=ArceeRoute.retrieve.value,
            body=self._make_request_body_for_models(
                prompt=query,
                **kwargs,
            ),
        )
        return [
            ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
            for doc in response["results"]
        ]
|
||||
255
venv/Lib/site-packages/langchain_community/utilities/arxiv.py
Normal file
255
venv/Lib/site-packages/langchain_community/utilities/arxiv.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Util that calls Arxiv."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArxivAPIWrapper(BaseModel):
    """Wrapper around ArxivAPI.

    To use, you should have the ``arxiv`` python package installed.
    https://lukasschwab.me/arxiv.py/index.html
    This wrapper will use the Arxiv API to conduct searches and
    fetch document summaries. By default, it will return the document summaries
    of the top-k results.
    If the query is in the form of arxiv identifier
    (see https://info.arxiv.org/help/find/index.html), it will return the paper
    corresponding to the arxiv identifier.
    It limits the Document content by doc_content_chars_max.
    Set doc_content_chars_max=None if you don't want to limit the content size.

    Attributes:
        top_k_results: number of the top-scored document used for the arxiv tool
        ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
        continue_on_failure (bool): If True, continue loading other URLs on failure.
        load_max_docs: a limit to the number of loaded documents
        load_all_available_meta:
            if True: the `metadata` of the loaded Documents contains all available
            meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
            if False: the `metadata` contains only the published date, title,
            authors and summary.
        doc_content_chars_max: an optional cut limit for the length of a document's
            content

    Example:
        .. code-block:: python

            from langchain_community.utilities.arxiv import ArxivAPIWrapper
            arxiv = ArxivAPIWrapper(
                top_k_results = 3,
                ARXIV_MAX_QUERY_LENGTH = 300,
                load_max_docs = 3,
                load_all_available_meta = False,
                doc_content_chars_max = 40000
            )
            arxiv.run("tree of thought llm")
    """

    arxiv_search: Any  #: :meta private:
    arxiv_exceptions: Any  # :meta private:
    top_k_results: int = 3
    ARXIV_MAX_QUERY_LENGTH: int = 300
    continue_on_failure: bool = False
    load_max_docs: int = 100
    load_all_available_meta: bool = False
    doc_content_chars_max: Optional[int] = 4000

    def is_arxiv_identifier(self, query: str) -> bool:
        """Check if a query is an arxiv identifier.

        The (truncated) query is split on whitespace and every item must match
        the identifier pattern in full. A query with no items returns True.
        """
        arxiv_identifier_pattern = r"\d{2}(0[1-9]|1[0-2])\.\d{4,5}(v\d+|)|\d{7}.*"
        for query_item in query[: self.ARXIV_MAX_QUERY_LENGTH].split():
            match_result = re.match(arxiv_identifier_pattern, query_item)
            # A prefix-only match is not enough: the whole item must match.
            if match_result is None or match_result.group(0) != query_item:
                return False
        return True

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that the python package exists in environment.

        Raises:
            ImportError: If the ``arxiv`` package cannot be imported.
        """
        try:
            import arxiv

            values["arxiv_search"] = arxiv.Search
            values["arxiv_exceptions"] = (
                arxiv.ArxivError,
                arxiv.UnexpectedEmptyPageError,
                arxiv.HTTPError,
            )
            values["arxiv_result"] = arxiv.Result
        except ImportError as exc:
            raise ImportError(
                "Could not import arxiv python package. "
                "Please install it with `pip install arxiv`."
            ) from exc
        return values

    def _fetch_results(self, query: str) -> Any:
        """Helper function to fetch arxiv results based on query.

        Identifier queries are looked up by id; anything else is sent as a
        free-text search truncated to ``ARXIV_MAX_QUERY_LENGTH`` characters.
        """
        if self.is_arxiv_identifier(query):
            return self.arxiv_search(
                id_list=query.split(), max_results=self.top_k_results
            ).results()
        return self.arxiv_search(
            query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
        ).results()

    def get_summaries_as_docs(self, query: str) -> List[Document]:
        """
        Performs an arxiv search and returns list of
        documents, with summaries as the content.

        If an arxiv error occurs, a single document carrying the error text
        is returned instead. If nothing is found, the list is empty. Wrapper for
        https://lukasschwab.me/arxiv.py/index.html#Search

        Args:
            query: a plaintext search query
        """
        try:
            # Consume the results inside the ``try``: when the search yields
            # results lazily, request errors surface during iteration rather
            # than in the call that creates the iterator.
            docs = [
                Document(
                    page_content=result.summary,
                    metadata={
                        "Entry ID": result.entry_id,
                        "Published": result.updated.date(),
                        "Title": result.title,
                        "Authors": ", ".join(a.name for a in result.authors),
                    },
                )
                for result in self._fetch_results(query)
            ]
        except self.arxiv_exceptions as ex:
            logger.error(f"Arxiv exception: {ex}")
            return [Document(page_content=f"Arxiv exception: {ex}")]
        return docs

    def run(self, query: str) -> str:
        """
        Performs an arxiv search and returns a single string
        with the publish date, title, authors, and summary
        for each article separated by two newlines.

        If an error occurs or no documents found, error text
        is returned instead. Wrapper for
        https://lukasschwab.me/arxiv.py/index.html#Search

        Args:
            query: a plaintext search query
        """
        try:
            # Consume the results inside the ``try`` so errors raised while
            # iterating a lazy result stream are reported as text as well.
            docs = [
                f"Published: {result.updated.date()}\n"
                f"Title: {result.title}\n"
                f"Authors: {', '.join(a.name for a in result.authors)}\n"
                f"Summary: {result.summary}"
                for result in self._fetch_results(query)
            ]
        except self.arxiv_exceptions as ex:
            logger.error(f"Arxiv exception: {ex}")
            return f"Arxiv exception: {ex}"
        if docs:
            return "\n\n".join(docs)[: self.doc_content_chars_max]
        else:
            return "No good Arxiv Result was found"

    def load(self, query: str) -> List[Document]:
        """
        Run Arxiv search and get the article texts plus the article meta information.
        See https://lukasschwab.me/arxiv.py/index.html#Search

        Returns: a list of documents with the document.page_content in text format

        Performs an arxiv search, downloads the top k results as PDFs, loads
        them as Documents, and returns them in a List.

        Args:
            query: a plaintext search query
        """
        return list(self.lazy_load(query))

    def lazy_load(self, query: str) -> Iterator[Document]:
        """
        Run Arxiv search and get the article texts plus the article meta information.
        See https://lukasschwab.me/arxiv.py/index.html#Search

        Returns: documents with the document.page_content in text format

        Performs an arxiv search, downloads the top k results as PDFs, loads
        them as Documents, and returns them. Each downloaded PDF is deleted
        once its document has been handled, including when it is skipped or
        when the caller stops iterating early.

        Args:
            query: a plaintext search query
        """
        try:
            import fitz
        except ImportError:
            raise ImportError(
                "PyMuPDF package not found, please install it with "
                "`pip install pymupdf`"
            )

        try:
            # Remove the ":" and "-" from the query, as they can cause search problems
            query = query.replace(":", "").replace("-", "")
            # Materialize here so errors raised while iterating a lazy result
            # stream are handled by the ``except`` below.
            results = list(self._fetch_results(query))
        except self.arxiv_exceptions as ex:
            logger.debug("Error on arxiv: %s", ex)
            return

        for result in results:
            doc_file_name: Optional[str] = None
            try:
                try:
                    doc_file_name = result.download_pdf()
                    with fitz.open(doc_file_name) as doc_file:
                        text: str = "".join(page.get_text() for page in doc_file)
                except (FileNotFoundError, fitz.fitz.FileDataError) as f_ex:
                    logger.debug(f_ex)
                    continue
                except Exception as e:
                    if self.continue_on_failure:
                        logger.error(e)
                        continue
                    raise
                if self.load_all_available_meta:
                    extra_metadata = {
                        "entry_id": result.entry_id,
                        "published_first_time": str(result.published.date()),
                        "comment": result.comment,
                        "journal_ref": result.journal_ref,
                        "doi": result.doi,
                        "primary_category": result.primary_category,
                        "categories": result.categories,
                        "links": [link.href for link in result.links],
                    }
                else:
                    extra_metadata = {}
                metadata = {
                    "Published": str(result.updated.date()),
                    "Title": result.title,
                    "Authors": ", ".join(a.name for a in result.authors),
                    "Summary": result.summary,
                    **extra_metadata,
                }
                yield Document(
                    page_content=text[: self.doc_content_chars_max], metadata=metadata
                )
            finally:
                # Runs on success, on ``continue``, on error and when the
                # generator is closed early, so the PDF never outlives its use.
                if doc_file_name is not None and os.path.exists(doc_file_name):
                    os.remove(doc_file_name)
|
||||
115
venv/Lib/site-packages/langchain_community/utilities/asknews.py
Normal file
115
venv/Lib/site-packages/langchain_community/utilities/asknews.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Util that calls AskNews api."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
|
||||
class AskNewsAPIWrapper(BaseModel):
    """Wrapper for AskNews API."""

    asknews_sync: Any = None  #: :meta private:
    asknews_async: Any = None  #: :meta private:
    asknews_client_id: Optional[str] = None
    """Client ID for the AskNews API."""
    asknews_client_secret: Optional[str] = None
    """Client Secret for the AskNews API."""

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that api credentials and python package exists in environment."""

        client_id = get_from_dict_or_env(
            values, "asknews_client_id", "ASKNEWS_CLIENT_ID"
        )
        client_secret = get_from_dict_or_env(
            values, "asknews_client_secret", "ASKNEWS_CLIENT_SECRET"
        )

        try:
            import asknews_sdk

        except ImportError:
            raise ImportError(
                "AskNews python package not found. "
                "Please install it with `pip install asknews`."
            )

        def _sdk_kwargs() -> Dict[str, Any]:
            # A fresh dict (and scopes list) per client, as in separate calls.
            return {
                "client_id": client_id,
                "client_secret": client_secret,
                "scopes": ["news"],
            }

        # Both clients are built before anything is stored in ``values``.
        values.update(
            asknews_sync=asknews_sdk.AskNewsSDK(**_sdk_kwargs()),
            asknews_async=asknews_sdk.AsyncAskNewsSDK(**_sdk_kwargs()),
            asknews_client_id=client_id,
            asknews_client_secret=client_secret,
        )

        return values

    def search_news(
        self, query: str, max_results: int = 10, hours_back: int = 0
    ) -> str:
        """Search news in AskNews API synchronously.

        Look-backs of more than 48 hours use the historical keyword ("kw")
        search over an explicit time window; anything else uses the "nl"
        method without a window.
        """
        use_historical = hours_back > 48
        if use_historical:
            window_start = int(
                (datetime.now() - timedelta(hours=hours_back)).timestamp()
            )
            window_end = int(datetime.now().timestamp())
        else:
            window_start = window_end = None

        result = self.asknews_sync.news.search_news(
            query=query,
            n_articles=max_results,
            method="kw" if use_historical else "nl",
            historical=use_historical,
            start_timestamp=window_start,
            end_timestamp=window_end,
            return_type="string",
        )
        return result.as_string

    async def asearch_news(
        self, query: str, max_results: int = 10, hours_back: int = 0
    ) -> str:
        """Search news in AskNews API asynchronously.

        Uses the same method and time-window selection as ``search_news``.
        """
        use_historical = hours_back > 48
        if use_historical:
            window_start = int(
                (datetime.now() - timedelta(hours=hours_back)).timestamp()
            )
            window_end = int(datetime.now().timestamp())
        else:
            window_start = window_end = None

        result = await self.asknews_async.news.search_news(
            query=query,
            n_articles=max_results,
            method="kw" if use_historical else "nl",
            historical=use_historical,
            start_timestamp=window_start,
            end_timestamp=window_end,
            return_type="string",
        )
        return result.as_string
|
||||
171
venv/Lib/site-packages/langchain_community/utilities/astradb.py
Normal file
171
venv/Lib/site-packages/langchain_community/utilities/astradb.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from asyncio import InvalidStateError, Task
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Awaitable, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import (
|
||||
AstraDB,
|
||||
AsyncAstraDB,
|
||||
)
|
||||
|
||||
|
||||
class SetupMode(Enum):
    """Enumeration of the setup modes accepted by the AstraDB environment classes."""

    # Setup performed with the synchronous client.
    SYNC = 1
    # Setup performed with the asynchronous client.
    ASYNC = 2
    # No setup performed.
    OFF = 3
|
||||
|
||||
|
||||
class _AstraDBEnvironment:
    """Holds a synchronous and an asynchronous Astra DB client.

    The clients come either from explicit client objects or from a
    ``token``/``api_endpoint`` pair. Whichever of the two clients is missing
    is built from the settings of the one that is available.
    """

    def __init__(
        self,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
    ) -> None:
        """Resolve ``self.astra_db`` and ``self.async_astra_db``.

        Raises:
            ImportError: If ``astrapy.db`` cannot be imported.
            ValueError: If a client is combined with ``token``/``api_endpoint``,
                or if neither a client nor both credentials are supplied.
        """
        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace

        try:
            from astrapy.db import (
                AstraDB,
                AsyncAstraDB,
            )
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import a recent astrapy python package. "
                "Please install it with `pip install --upgrade astrapy`."
            )

        # Conflicting-arg checks:
        client_given = astra_db_client is not None or async_astra_db_client is not None
        credentials_given = token is not None or api_endpoint is not None
        if client_given and credentials_given:
            raise ValueError(
                "You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
                "AstraDBEnvironment if passing 'token' and 'api_endpoint'."
            )

        sync_db = astra_db_client
        async_db = async_astra_db_client
        if token and api_endpoint:
            sync_db = AstraDB(
                token=token,
                api_endpoint=api_endpoint,
                namespace=self.namespace,
            )
            async_db = AsyncAstraDB(
                token=token,
                api_endpoint=api_endpoint,
                namespace=self.namespace,
            )

        if sync_db:
            self.astra_db = sync_db
            if async_db:
                self.async_astra_db = async_db
            else:
                # Mirror the sync client's settings into an async client.
                self.async_astra_db = AsyncAstraDB(
                    token=self.astra_db.token,
                    api_endpoint=self.astra_db.base_url,
                    api_path=self.astra_db.api_path,
                    api_version=self.astra_db.api_version,
                    namespace=self.astra_db.namespace,
                )
        elif async_db:
            self.async_astra_db = async_db
            # Mirror the async client's settings into a sync client.
            self.astra_db = AstraDB(
                token=self.async_astra_db.token,
                api_endpoint=self.async_astra_db.base_url,
                api_path=self.async_astra_db.api_path,
                api_version=self.async_astra_db.api_version,
                namespace=self.async_astra_db.namespace,
            )
        else:
            raise ValueError(
                "Must provide 'astra_db_client' or 'async_astra_db_client' or "
                "'token' and 'api_endpoint'"
            )
|
||||
|
||||
|
||||
class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
    """Astra DB environment bound to one collection, with optional setup.

    Depending on ``setup_mode`` the collection is created (and optionally
    deleted first) synchronously in the constructor, asynchronously in a
    scheduled task, or not at all.
    """

    def __init__(
        self,
        collection_name: str,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
        embedding_dimension: Union[int, Awaitable[int], None] = None,
        metric: Optional[str] = None,
    ) -> None:
        """Create the collection handles and run the requested setup.

        Args:
            collection_name: Name of the collection to bind to.
            token: Passed to the base environment.
            api_endpoint: Passed to the base environment.
            astra_db_client: Passed to the base environment.
            async_astra_db_client: Passed to the base environment.
            namespace: Passed to the base environment.
            setup_mode: How (or whether) to create the collection.
            pre_delete_collection: Delete the collection before creating it.
            embedding_dimension: Dimension passed to ``create_collection``;
                may be an awaitable only when ``setup_mode`` is ASYNC.
            metric: Metric passed to ``create_collection``.

        Raises:
            ValueError: If ``setup_mode`` is SYNC and ``embedding_dimension``
                is awaitable. This is checked before anything is deleted.
        """
        super().__init__(
            token, api_endpoint, astra_db_client, async_astra_db_client, namespace
        )

        # Imported after the base initializer, which reports a missing astrapy
        # with an install hint.
        from astrapy.db import AstraDBCollection, AsyncAstraDBCollection

        self.collection_name = collection_name
        self.collection = AstraDBCollection(
            collection_name=collection_name,
            astra_db=self.astra_db,
        )

        self.async_collection = AsyncAstraDBCollection(
            collection_name=collection_name,
            astra_db=self.async_astra_db,
        )

        self.async_setup_db_task: Optional[Task] = None
        if setup_mode == SetupMode.ASYNC:
            async_astra_db = self.async_astra_db

            async def _setup_db() -> None:
                if pre_delete_collection:
                    await async_astra_db.delete_collection(collection_name)
                if inspect.isawaitable(embedding_dimension):
                    dimension: Optional[int] = await embedding_dimension
                else:
                    dimension = embedding_dimension
                await async_astra_db.create_collection(
                    collection_name, dimension=dimension, metric=metric
                )

            self.async_setup_db_task = asyncio.create_task(_setup_db())
        elif setup_mode == SetupMode.SYNC:
            # Validate first: an invalid argument must not cost the caller an
            # already-deleted collection.
            if inspect.isawaitable(embedding_dimension):
                raise ValueError(
                    "Cannot use an awaitable embedding_dimension with async_setup "
                    "set to False"
                )
            if pre_delete_collection:
                self.astra_db.delete_collection(collection_name)
            self.astra_db.create_collection(
                collection_name,
                dimension=embedding_dimension,
                metric=metric,
            )

    def ensure_db_setup(self) -> None:
        """Check, without waiting, that any asynchronous setup has completed.

        Raises:
            ValueError: If the asynchronous setup task has not finished yet.
        """
        if self.async_setup_db_task:
            try:
                # ``result()`` also re-raises an exception the task ended with.
                self.async_setup_db_task.result()
            except InvalidStateError:
                raise ValueError(
                    "Asynchronous setup of the DB not finished. "
                    "NB: AstraDB components sync methods shouldn't be called from the "
                    "event loop. Consider using their async equivalents."
                )

    async def aensure_db_setup(self) -> None:
        """Wait for the asynchronous setup task, if one was scheduled."""
        if self.async_setup_db_task:
            await self.async_setup_db_task
|
||||
@@ -0,0 +1,81 @@
|
||||
"""Util that calls Lambda."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
|
||||
class LambdaWrapper(BaseModel):
    """Wrapper for AWS Lambda SDK.
    To use, you should have the ``boto3`` package installed
    and a lambda functions built from the AWS Console or
    CLI. Set up your AWS credentials with ``aws configure``

    Example:
        .. code-block:: bash

            pip install boto3

            aws configure

    """

    lambda_client: Any = None  #: :meta private:
    """The configured boto3 client"""
    function_name: Optional[str] = None
    """The name of your lambda function"""
    awslambda_tool_name: Optional[str] = None
    """If passing to an agent as a tool, the tool name"""
    awslambda_tool_description: Optional[str] = None
    """If passing to an agent as a tool, the description"""

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that python package exists in environment.

        Stores a ``boto3`` "lambda" client under ``lambda_client``.

        Raises:
            ImportError: If ``boto3`` is not installed.
        """

        try:
            import boto3

        except ImportError:
            raise ImportError(
                "boto3 is not installed. Please install it with `pip install boto3`"
            )

        values["lambda_client"] = boto3.client("lambda")
        return values

    def run(self, query: str) -> str:
        """
        Invokes the lambda function and returns the
        result.

        Args:
            query: an input to passed to the lambda
                function as the ``body`` of a JSON
                object.

        Returns:
            ``"Result: <body>"`` on success, ``"Request failed."`` when the
            response body is empty, or ``"Failed to parse response from
            Lambda"`` when the response payload cannot be read or decoded.
        """
        res = self.lambda_client.invoke(
            FunctionName=self.function_name,
            InvocationType="RequestResponse",
            Payload=json.dumps({"body": query}),
        )

        try:
            payload_stream = res["Payload"]
            payload_string = payload_stream.read().decode("utf-8")
            answer = json.loads(payload_string)["body"]

        # KeyError: missing "Payload"/"body"; ValueError: invalid JSON or
        # UTF-8; TypeError: decoded payload is not a mapping.
        except (StopIteration, KeyError, TypeError, ValueError):
            return "Failed to parse response from Lambda"

        if answer is None or answer == "":
            # We don't want to return the assumption alone if answer is empty
            return "Request failed."
        else:
            return f"Result: {answer}"
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Util that calls bibtexparser."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Mapping
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Additional bibtex entry fields, in the order they are looked up when extra
# metadata is requested.
OPTIONAL_FIELDS = [
    "annotate", "booktitle", "editor", "howpublished", "journal",
    "keywords", "note", "organization", "publisher", "school",
    "series", "type", "doi", "issn", "isbn",
]
|
||||
|
||||
|
||||
class BibtexparserWrapper(BaseModel):
    """Wrapper around bibtexparser.

    To use, you should have the ``bibtexparser`` python package installed.
    https://bibtexparser.readthedocs.io/en/master/

    This wrapper will use bibtexparser to load a collection of references from
    a bibtex file and fetch document summaries.
    """

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Check that ``bibtexparser`` can be imported; ``values`` pass through."""
        try:
            import bibtexparser  # noqa
        except ImportError:
            raise ImportError(
                "Could not import bibtexparser python package. "
                "Please install it with `pip install bibtexparser`."
            )

        return values

    def load_bibtex_entries(self, path: str) -> List[Dict[str, Any]]:
        """Parse the bibtex file at ``path`` and return its entries."""
        import bibtexparser

        with open(path) as bib_file:
            database = bibtexparser.load(bib_file)
            return database.entries

    def get_metadata(
        self, entry: Mapping[str, Any], load_extra: bool = False
    ) -> Dict[str, Any]:
        """Build the metadata dict for one bibtex entry.

        The URL is the entry's ``url`` if present, otherwise a doi.org link
        built from ``doi``. With ``load_extra`` every name in
        ``OPTIONAL_FIELDS`` is copied as well. Keys whose value is ``None``
        are dropped from the result.
        """
        publication = entry.get("journal") or entry.get("booktitle")

        link = None
        if "url" in entry:
            link = entry["url"]
        elif "doi" in entry:
            link = f"https://doi.org/{entry['doi']}"

        collected = {
            "id": entry.get("ID"),
            "published_year": entry.get("year"),
            "title": entry.get("title"),
            "publication": publication,
            "authors": entry.get("author"),
            "abstract": entry.get("abstract"),
            "url": link,
        }
        if load_extra:
            collected.update({name: entry.get(name) for name in OPTIONAL_FIELDS})

        return {key: val for key, val in collected.items() if val is not None}
|
||||
@@ -0,0 +1,117 @@
|
||||
"""Util that calls Bing Search."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
# BING_SEARCH_ENDPOINT is the default endpoint for Bing Web Search API.
|
||||
# Currently there are two web-based Bing Search services available on Azure,
# i.e. Bing Web Search[1] and Bing Custom Search[2]. Both services provide a
# wide range of search results, but Bing Custom Search additionally requires
# you to provide a custom search instance, `customConfig`.
|
||||
# Both services are available for BingSearchAPIWrapper.
|
||||
# History of Azure Bing Search API:
|
||||
# Before shown in Azure Marketplace as a separate service, Bing Search APIs were
|
||||
# part of Azure Cognitive Services, the endpoint of which is unique, and the user
|
||||
# must specify the endpoint when making a request. After transitioning to Azure
|
||||
# Marketplace, the endpoint is standardized and the user does not need to specify
|
||||
# the endpoint[3].
|
||||
# Reference:
|
||||
# 1. https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/overview
|
||||
# 2. https://learn.microsoft.com/en-us/bing/search-apis/bing-custom-search/overview
|
||||
# 3. https://azure.microsoft.com/en-in/updates/bing-search-apis-will-transition-from-azure-cognitive-services-to-azure-marketplace-on-31-october-2023/
|
||||
DEFAULT_BING_SEARCH_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"


class BingSearchAPIWrapper(BaseModel):
    """Wrapper for Bing Web Search API."""

    bing_subscription_key: str
    bing_search_url: str
    k: int = 10
    search_kwargs: dict = Field(default_factory=dict)
    """Additional keyword arguments to pass to the search request."""

    model_config = ConfigDict(
        extra="forbid",
    )

    def _bing_search_results(self, search_term: str, count: int) -> List[dict]:
        """Send one search request and return the raw web-page results.

        Args:
            search_term: The text to search for.
            count: Value sent as the ``count`` query parameter.

        Returns:
            The ``webPages.value`` list of the JSON response, or an empty list
            when the response has no ``webPages`` section.
        """
        # Entries in search_kwargs come last so they can override the defaults.
        query_params = {
            "q": search_term,
            "count": count,
            "textDecorations": True,
            "textFormat": "HTML",
            **self.search_kwargs,
        }
        response = requests.get(
            self.bing_search_url,
            headers={"Ocp-Apim-Subscription-Key": self.bing_subscription_key},
            params=query_params,
        )
        response.raise_for_status()
        payload = response.json()
        if "webPages" not in payload:
            return []
        return payload["webPages"]["value"]

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that api key and endpoint exists in environment."""
        values["bing_subscription_key"] = get_from_dict_or_env(
            values, "bing_subscription_key", "BING_SUBSCRIPTION_KEY"
        )
        # The endpoint falls back to the module default when neither the
        # argument nor the environment variable is set.
        values["bing_search_url"] = get_from_dict_or_env(
            values,
            "bing_search_url",
            "BING_SEARCH_URL",
            default=DEFAULT_BING_SEARCH_ENDPOINT,
        )
        return values

    def run(self, query: str) -> str:
        """Run query through BingSearch and parse result.

        Args:
            query: The query to search for.

        Returns:
            The snippets of up to ``k`` requested results joined by single
            spaces, or a fixed "not found" sentence when there are none.
        """
        hits = self._bing_search_results(query, count=self.k)
        if len(hits) == 0:
            return "No good Bing Search Result was found"
        return " ".join([hit["snippet"] for hit in hits])

    def results(self, query: str, num_results: int) -> List[Dict]:
        """Run query through BingSearch and return metadata.

        Args:
            query: The query to search for.
            num_results: The number of results to return.

        Returns:
            A list of dictionaries with the following keys:
                snippet - The description of the result.
                title - The title of the result.
                link - The link to the result.
        """
        hits = self._bing_search_results(query, count=num_results)
        if len(hits) == 0:
            return [{"Result": "No good Bing Search Result was found"}]
        return [
            {
                "snippet": hit["snippet"],
                "title": hit["name"],
                "link": hit["url"],
            }
            for hit in hits
        ]
|
||||
@@ -0,0 +1,83 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.utils import secret_from_env
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
class BraveSearchWrapper(BaseModel):
    """Wrapper around the Brave search engine."""

    api_key: SecretStr = Field(
        default_factory=secret_from_env(["BRAVE_SEARCH_API_KEY"])
    )
    """The API key to use for the Brave search engine."""
    search_kwargs: dict = Field(default_factory=dict)
    """Additional keyword arguments to pass to the search request."""
    base_url: str = "https://api.search.brave.com/res/v1/web/search"
    """The base URL for the Brave search engine."""

    def run(self, query: str) -> str:
        """Query the Brave search engine and return the results as a JSON string.

        Args:
            query: The query to search for.

        Returns: The results as a JSON string. Each element has the keys
            ``title``, ``link`` and ``snippet``.
        """
        web_search_results = self._search_request(query=query)
        final_results = [
            {
                "title": item.get("title"),
                "link": item.get("url"),
                "snippet": self._join_snippets(item),
            }
            for item in web_search_results
        ]
        return json.dumps(final_results)

    def download_documents(self, query: str) -> List[Document]:
        """Query the Brave search engine and return the results as a list of Documents.

        Args:
            query: The query to search for.

        Returns: The results as a list of Documents. The joined snippets form
            the page content; title and link go into the metadata.
        """
        results = self._search_request(query)
        return [
            Document(
                page_content=self._join_snippets(item),
                metadata={"title": item.get("title"), "link": item.get("url")},
            )
            for item in results
        ]

    @staticmethod
    def _join_snippets(item: dict) -> str:
        """Join an item's description and extra snippets into one string.

        Empty or missing parts are dropped. ``extra_snippets`` may be absent
        or explicitly null; both are treated as "no extra snippets".
        """
        extra_snippets = item.get("extra_snippets") or []
        return " ".join(filter(None, [item.get("description"), *extra_snippets]))

    def _search_request(self, query: str) -> List[dict]:
        """Send the search request and return the web results of the response.

        Args:
            query: The query to search for.

        Returns:
            The ``web.results`` list of the JSON response, or an empty list
            when it is missing.

        Raises:
            ValueError: If the request URL could not be prepared.
            Exception: If the response status is not OK.
        """
        headers = {
            "X-Subscription-Token": self.api_key.get_secret_value(),
            "Accept": "application/json",
        }
        req = requests.PreparedRequest()
        # The query and extra_snippets flag are merged last, so they take
        # precedence over anything supplied in search_kwargs.
        params = {**self.search_kwargs, **{"q": query, "extra_snippets": True}}
        req.prepare_url(self.base_url, params)
        if req.url is None:
            raise ValueError("prepared url is None, this should not happen")

        response = requests.get(req.url, headers=headers)
        if not response.ok:
            raise Exception(f"HTTP error {response.status_code}")

        return response.json().get("web", {}).get("results", [])
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cassandra.cluster import ResponseFuture, Session
|
||||
|
||||
|
||||
async def wrapped_response_future(
    func: Callable[..., ResponseFuture], *args: Any, **kwargs: Any
) -> Any:
    """Wrap a Cassandra response future in an asyncio future.

    Args:
        func: The Cassandra function to call.
        *args: The arguments to pass to the Cassandra function.
        **kwargs: The keyword arguments to pass to the Cassandra function.

    Returns:
        The result of the Cassandra function, i.e. ``response_future.result()``.

    Raises:
        BaseException: Whatever the response future reports through its error
            callback.
    """
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # supported way to obtain it.
    loop = asyncio.get_running_loop()
    asyncio_future = loop.create_future()
    response_future = func(*args, **kwargs)

    def settle(setter: Callable[[Any], None], value: Any) -> None:
        # The awaiting task may have been cancelled before the driver answers;
        # completing an already-done future would raise InvalidStateError.
        if not asyncio_future.done():
            setter(value)

    # The handlers hand over to the loop with call_soon_threadsafe because
    # they are presumably invoked from a driver thread, not the loop thread.
    def success_handler(_: Any) -> None:
        loop.call_soon_threadsafe(
            settle, asyncio_future.set_result, response_future.result()
        )

    def error_handler(exc: BaseException) -> None:
        loop.call_soon_threadsafe(settle, asyncio_future.set_exception, exc)

    response_future.add_callbacks(success_handler, error_handler)
    return await asyncio_future
|
||||
|
||||
|
||||
async def aexecute_cql(session: Session, query: str, **kwargs: Any) -> Any:
    """Run a CQL query without blocking the event loop.

    Args:
        session: The Cassandra session to use.
        query: The CQL query to execute.
        kwargs: Additional keyword arguments forwarded to the session's
            ``execute_async`` method.

    Returns:
        The result of the query.
    """
    pending = wrapped_response_future(session.execute_async, query, **kwargs)
    return await pending
|
||||
|
||||
|
||||
class SetupMode(Enum):
    """Selector with three setup modes: ``SYNC``, ``ASYNC`` and ``OFF``."""

    # NOTE(review): presumably SYNC/ASYNC choose blocking vs. asyncio-based
    # setup and OFF skips it; confirm against the callers.
    SYNC = 1
    ASYNC = 2
    OFF = 3
|
||||
@@ -0,0 +1,662 @@
|
||||
"""Apache Cassandra database wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cassandra.cluster import ResultSet, Session
|
||||
|
||||
# Keyspaces that are left out when no explicit keyspace list is requested.
IGNORED_KEYSPACES = [
    "system",
    "system_auth",
    "system_distributed",
    "system_schema",
    "system_traces",
    "system_views",
    "datastax_sla",
    "data_endpoint_auth",
]


class CassandraDatabase:
    """Apache Cassandra® database wrapper."""

    def __init__(
        self,
        session: Optional[Session] = None,
        exclude_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        cassio_init_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Create the wrapper around a resolved Cassandra session.

        Args:
            session: Session to use directly, if given.
            exclude_tables: Table names to leave out of resolved schemas.
            include_tables: If non-empty, only these table names are kept.
            cassio_init_kwargs: Keyword arguments for ``cassio.init`` when no
                session is available otherwise.

        Raises:
            ValueError: If no session can be resolved.
        """
        _session = self._resolve_session(session, cassio_init_kwargs)
        if not _session:
            raise ValueError("Session not provided and cannot be resolved")
        self._session = _session

        self._exclude_keyspaces = IGNORED_KEYSPACES
        self._exclude_tables = exclude_tables or []
        self._include_tables = include_tables or []

    def run(
        self,
        query: str,
        fetch: str = "all",
        **kwargs: Any,
    ) -> Union[list, Dict[str, Any], ResultSet]:
        """Execute a CQL query and return the results.

        Args:
            query: The CQL query; it must pass ``_validate_cql``.
            fetch: ``"all"`` for a list of rows, ``"one"`` for a single row as
                a dictionary, ``"cursor"`` for the raw result set.
            **kwargs: Forwarded to the session's ``execute`` method.

        Raises:
            ValueError: If ``fetch`` is not one of the supported values.
        """
        if fetch == "all":
            return self.fetch_all(query, **kwargs)
        elif fetch == "one":
            return self.fetch_one(query, **kwargs)
        elif fetch == "cursor":
            return self._fetch(query, **kwargs)
        else:
            raise ValueError("Fetch parameter must be either 'one', 'all', or 'cursor'")

    def _fetch(self, query: str, **kwargs: Any) -> ResultSet:
        """Validate ``query`` as a SELECT and execute it on the session."""
        clean_query = self._validate_cql(query, "SELECT")
        return self._session.execute(clean_query, **kwargs)

    def fetch_all(self, query: str, **kwargs: Any) -> list:
        """Execute ``query`` and return all rows as a list."""
        return list(self._fetch(query, **kwargs))

    def fetch_one(self, query: str, **kwargs: Any) -> Dict[str, Any]:
        """Execute ``query`` and return the first row as a dictionary.

        Returns:
            The first row, or an empty dictionary when the result is empty.
        """
        result = self._fetch(query, **kwargs)
        if not result:
            return {}
        row = result.one()
        # Rows that already are dictionaries have no _asdict(); return them
        # as they are instead of failing.
        return row if isinstance(row, dict) else row._asdict()

    @staticmethod
    def _row_to_dict(row: Any) -> Optional[Dict[str, Any]]:
        """Return ``row`` as a dictionary, or ``None`` for unknown row shapes.

        Dictionaries are returned unchanged; objects offering ``_asdict()``
        (named tuples) are converted.
        """
        if isinstance(row, dict):
            return row
        as_dict = getattr(row, "_asdict", None)
        if callable(as_dict):
            return dict(as_dict())
        return None

    @classmethod
    def _rows_as_dicts(cls, rows: Any) -> List[Dict[str, Any]]:
        """Convert ``rows`` to dictionaries, dropping rows of unknown shape."""
        converted = (cls._row_to_dict(row) for row in rows)
        return [row for row in converted if row is not None]

    def get_keyspace_tables(self, keyspace: str) -> List[Table]:
        """Get the Table objects for the specified keyspace."""
        return self._resolve_schema([keyspace]).get(keyspace, [])

    # This is a more basic string building function that doesn't use a query builder
    # or prepared statements
    # TODO: Refactor to use prepared statements
    def get_table_data(
        self, keyspace: str, table: str, predicate: str, limit: int
    ) -> str:
        """Get data from the specified table in the specified keyspace.

        Args:
            keyspace: Keyspace of the table.
            table: Name of the table.
            predicate: Raw CQL placed after ``WHERE``; skipped when empty.
            limit: Value placed after ``LIMIT``; skipped when falsy.

        Returns:
            One line per row, each the ``str()`` of the row.
        """
        # NOTE(review): keyspace, table and predicate are interpolated as-is;
        # only pass trusted values until the TODO above is addressed.
        query = f"SELECT * FROM {keyspace}.{table}"

        if predicate:
            query += f" WHERE {predicate}"
        if limit:
            query += f" LIMIT {limit}"

        query += ";"

        result = self.fetch_all(query)
        return "\n".join(str(row) for row in result)

    def get_context(self) -> Dict[str, Any]:
        """Return db context that you may want in agent prompt."""
        keyspaces = self._fetch_keyspaces()
        return {"keyspaces": ", ".join(keyspaces)}

    def format_keyspace_to_markdown(
        self, keyspace: str, tables: Optional[List[Table]] = None
    ) -> str:
        """
        Generates a markdown representation of the schema for a specific keyspace
        by iterating over all tables within that keyspace and calling their
        as_markdown method.

        Args:
            keyspace: The name of the keyspace to generate markdown documentation for.
            tables: list of tables in the keyspace; it will be resolved if not provided.

        Returns:
            A string containing the markdown representation of the specified
            keyspace schema, or an empty string when the keyspace has no tables.
        """
        if not tables:
            tables = self.get_keyspace_tables(keyspace)
        if not tables:
            return ""

        sections = "".join(
            table.as_markdown(include_keyspace=False, header_level=3) + "\n\n"
            for table in tables
        )
        return f"## Keyspace: {keyspace}\n\n{sections}"

    def format_schema_to_markdown(self) -> str:
        """
        Generates a markdown representation of the schema for all keyspaces and tables
        within the CassandraDatabase instance. This method utilizes the
        format_keyspace_to_markdown method to create markdown sections for each
        keyspace, assembling them into a comprehensive schema document.

        Returns:
            A markdown string that documents the schema of all resolved keyspaces and
            their tables within this CassandraDatabase instance. This includes keyspace
            names, table names, comments, columns, partition keys, clustering keys,
            and indexes for each table.
        """
        schema = self._resolve_schema()
        sections = "".join(
            f"{self.format_keyspace_to_markdown(keyspace, tables)}\n\n"
            for keyspace, tables in schema.items()
        )
        return "# Cassandra Database Schema\n\n" + sections

    def _validate_cql(self, cql: str, type: str = "SELECT") -> str:
        """
        Validates a CQL query string for basic formatting and safety checks.
        Ensures that `cql` starts with the specified type (e.g., SELECT) and does
        not contain content that could indicate CQL injection vulnerabilities.

        Args:
            cql: The CQL query string to be validated.
            type: The expected starting keyword of the query, used to verify
                that the query begins with the correct operation type
                (e.g., "SELECT", "UPDATE"). Defaults to "SELECT".

        Returns:
            The trimmed and validated CQL query string without a trailing semicolon.

        Raises:
            ValueError: If the value of `type` is not supported
            DatabaseError: If `cql` is considered unsafe
        """
        SUPPORTED_TYPES = ["SELECT"]
        if type and type.upper() not in SUPPORTED_TYPES:
            raise ValueError(
                f"Unsupported CQL type: {type}. Supported types:\n{SUPPORTED_TYPES}"
            )

        # Basic sanity checks
        cql_trimmed = cql.strip()
        if not cql_trimmed.upper().startswith(type.upper()):
            raise DatabaseError(f"CQL must start with {type.upper()}.")

        # Allow a trailing semicolon, but remove (it is optional with the Python driver)
        cql_trimmed = cql_trimmed.rstrip(";")

        # Consider content within matching quotes to be "safe":
        # drop single-quoted strings first, then double-quoted ones.
        cql_sanitized = re.sub(r"'.*?'", "", cql_trimmed)
        cql_sanitized = re.sub(r'".*?"', "", cql_sanitized)

        # Find unsafe content in the remaining CQL
        if ";" in cql_sanitized:
            raise DatabaseError(
                "Potentially unsafe CQL, as it contains a ; at a\n"
                "place other than the end or within quotation marks."
            )

        # The trimmed query, before modifications
        return cql_trimmed

    def _fetch_keyspaces(self, keyspaces: Optional[List[str]] = None) -> List[str]:
        """
        Fetches a list of keyspace names from the Cassandra database. The list can be
        filtered by a provided list of keyspace names or by excluding predefined
        keyspaces.

        Args:
            keyspaces: A list of keyspace names to specifically include.
                If provided and not empty, the method returns only the keyspaces
                present in this list.
                If not provided or empty, the method returns all keyspaces except those
                specified in the _exclude_keyspaces attribute.

        Returns:
            A list of keyspace names according to the filtering criteria.
        """
        all_keyspaces = self.fetch_all(
            "SELECT keyspace_name FROM system_schema.keyspaces"
        )

        filtered_keyspaces = []
        # Rows are normalised first so that both dictionary rows and
        # named-tuple rows are understood; other shapes are skipped.
        for row in self._rows_as_dicts(all_keyspaces):
            keyspace_name = row["keyspace_name"]
            if keyspaces:
                if keyspace_name in keyspaces:
                    filtered_keyspaces.append(keyspace_name)
            elif keyspace_name not in self._exclude_keyspaces:
                filtered_keyspaces.append(keyspace_name)

        return filtered_keyspaces

    def _format_keyspace_query(self, query: str, keyspaces: List[str]) -> str:
        """Append a ``WHERE keyspace_name IN (...)`` clause to ``query``."""
        keyspace_in_clause = ", ".join([f"'{ks}'" for ks in keyspaces])
        return f"""{query} WHERE keyspace_name IN ({keyspace_in_clause})"""

    def _fetch_tables_data(self, keyspaces: List[str]) -> list:
        """Fetches tables schema data, filtered by a list of keyspaces.

        Args:
            keyspaces: A list of keyspace names from which to fetch tables schema data.

        Returns:
            Rows of table details (keyspace name, table name, and comment).
        """
        tables_query = self._format_keyspace_query(
            "SELECT keyspace_name, table_name, comment FROM system_schema.tables",
            keyspaces,
        )
        return self.fetch_all(tables_query)

    def _fetch_columns_data(self, keyspaces: List[str]) -> list:
        """Fetches columns schema data, filtered by a list of keyspaces.

        Args:
            keyspaces: A list of keyspace names from which to fetch tables schema data.

        Returns:
            Rows of column details (keyspace name, table name, column name,
            type, kind, clustering order, and position).
        """
        columns_query = self._format_keyspace_query(
            """
            SELECT keyspace_name, table_name, column_name, type, kind,
            clustering_order, position
            FROM system_schema.columns
            """,
            keyspaces,
        )
        return self.fetch_all(columns_query)

    def _fetch_indexes_data(self, keyspaces: List[str]) -> list:
        """Fetches indexes schema data, filtered by a list of keyspaces.

        Args:
            keyspaces: A list of keyspace names from which to fetch tables schema data.

        Returns:
            Rows of index details (keyspace name, table name, index name, kind,
            and options).
        """
        indexes_query = self._format_keyspace_query(
            """
            SELECT keyspace_name, table_name, index_name,
            kind, options
            FROM system_schema.indexes
            """,
            keyspaces,
        )
        return self.fetch_all(indexes_query)

    def _resolve_schema(
        self, keyspaces: Optional[List[str]] = None
    ) -> Dict[str, List[Table]]:
        """
        Efficiently fetches and organizes Cassandra table schema information,
        such as comments, columns, and indexes, into a dictionary mapping keyspace
        names to lists of Table objects.

        Args:
            keyspaces: An optional list of keyspace names from which to fetch tables
                schema data.

        Returns:
            A dictionary with keyspace names as keys and lists of Table objects as
            values, where each Table object is populated with schema details
            appropriate for its keyspace and table name.
        """
        if not keyspaces:
            keyspaces = self._fetch_keyspaces()
        if not keyspaces:
            # Nothing to look up; an empty IN () clause would not be valid CQL.
            return {}

        tables_data = self._rows_as_dicts(self._fetch_tables_data(keyspaces))
        columns_data = self._rows_as_dicts(self._fetch_columns_data(keyspaces))
        indexes_data = self._rows_as_dicts(self._fetch_indexes_data(keyspaces))

        # Group columns and indexes per (keyspace, table) once instead of
        # rescanning every row for every table.
        columns_by_table: Dict[Tuple[str, str], List[Dict[str, Any]]] = {}
        for column in columns_data:
            key = (column["keyspace_name"], column["table_name"])
            columns_by_table.setdefault(key, []).append(column)

        indexes_by_table: Dict[Tuple[str, str], List[Dict[str, Any]]] = {}
        for index in indexes_data:
            key = (index["keyspace_name"], index["table_name"])
            indexes_by_table.setdefault(key, []).append(index)

        keyspace_dict: Dict[str, List[Table]] = {}
        for table_data in tables_data:
            keyspace = table_data["keyspace_name"]
            table_name = table_data["table_name"]

            if self._include_tables and table_name not in self._include_tables:
                continue

            if self._exclude_tables and table_name in self._exclude_tables:
                continue

            table_columns = columns_by_table.get((keyspace, table_name), [])

            # Key columns are ordered by their position in the key, the same
            # way Table._resolve_columns orders them.
            partition_rows = sorted(
                (c for c in table_columns if c["kind"] == "partition_key"),
                key=lambda c: c["position"],
            )
            clustering_rows = sorted(
                (c for c in table_columns if c["kind"] == "clustering"),
                key=lambda c: c["position"],
            )

            # Table.indexes holds strings only, so non-string options are
            # rendered with str(), as Table._resolve_indexes does.
            table_indexes = [
                (
                    i["index_name"],
                    i["kind"],
                    i["options"] if isinstance(i["options"], str) else str(i["options"]),
                )
                for i in indexes_by_table.get((keyspace, table_name), [])
            ]

            table_obj = Table(
                keyspace=keyspace,
                table_name=table_name,
                comment=table_data["comment"],
                columns=[(c["column_name"], c["type"]) for c in table_columns],
                partition=[c["column_name"] for c in partition_rows],
                clustering=[
                    (c["column_name"], c["clustering_order"]) for c in clustering_rows
                ],
                indexes=table_indexes,
            )

            keyspace_dict.setdefault(keyspace, []).append(table_obj)

        return keyspace_dict

    @staticmethod
    def _resolve_session(
        session: Optional[Session] = None,
        cassio_init_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Optional[Session]:
        """
        Attempts to resolve and return a Session object for use in database operations.

        This function follows a specific order of precedence to determine the
        appropriate session to use:
        1. `session` parameter if given,
        2. Existing `cassio` session,
        3. A new `cassio` session derived from `cassio_init_kwargs`,
        4. `None`

        Args:
            session: An optional session to use directly.
            cassio_init_kwargs: An optional dictionary of keyword arguments to `cassio`.

        Returns:
            The resolved session object if successful, or `None` if the session
            cannot be resolved.

        Raises:
            ValueError: If `cassio` is not installed, or if `cassio_init_kwargs` is
                provided but is not a dictionary of keyword arguments.
        """

        # Prefer given session
        if session:
            return session

        # If a session is not provided, create one using cassio if available
        # dynamically import cassio to avoid circular imports
        try:
            import cassio.config
        except ImportError as exc:
            raise ValueError(
                "cassio package not found, please install with `pip install cassio`"
            ) from exc

        # Use pre-existing session on cassio
        s = cassio.config.resolve_session()
        if s:
            return s

        # Try to init and return cassio session
        if cassio_init_kwargs:
            if isinstance(cassio_init_kwargs, dict):
                cassio.init(**cassio_init_kwargs)
                s = cassio.config.check_resolve_session()
                return s
            else:
                raise ValueError("cassio_init_kwargs must be a keyword dictionary")

        # return None if we're not able to resolve
        return None
|
||||
|
||||
|
||||
class DatabaseError(Exception):
    """Exception raised for errors in the database schema.

    The text given to the constructor is stored on ``message`` and is also
    passed to ``Exception`` as its only argument.

    Attributes:
        message -- explanation of the error
    """

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
|
||||
|
||||
|
||||
class Table(BaseModel):
    """Schema description of a single Cassandra table.

    Instances are frozen and must carry at least one column and one partition
    key (see ``check_required_fields``).
    """

    keyspace: str
    """The keyspace in which the table exists."""

    table_name: str
    """The name of the table."""

    comment: Optional[str] = None
    """The comment associated with the table."""

    # (column name, type) pairs.
    columns: List[Tuple[str, str]] = Field(default_factory=list)
    # Names of the partition key columns.
    partition: List[str] = Field(default_factory=list)
    # (column name, clustering order) pairs.
    clustering: List[Tuple[str, str]] = Field(default_factory=list)
    # (index name, kind, options rendered as a string) triples.
    indexes: List[Tuple[str, str, str]] = Field(default_factory=list)

    model_config = ConfigDict(
        frozen=True,
    )

    @model_validator(mode="after")
    def check_required_fields(self) -> Self:
        """Reject tables without columns or without partition keys.

        Raises:
            ValueError: If ``columns`` or ``partition`` is empty.
        """
        if not self.columns:
            raise ValueError("non-empty column list must be provided")
        if not self.partition:
            raise ValueError("non-empty partition list must be provided")
        return self

    @classmethod
    def from_database(
        cls, keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Table:
        """Build a Table by querying ``db`` for the table's schema details.

        Args:
            keyspace: Keyspace of the table.
            table_name: Name of the table.
            db: Database wrapper used to run the schema queries.
        """
        columns, partition, clustering = cls._resolve_columns(keyspace, table_name, db)
        return cls(
            keyspace=keyspace,
            table_name=table_name,
            comment=cls._resolve_comment(keyspace, table_name, db),
            columns=columns,
            partition=partition,
            clustering=clustering,
            indexes=cls._resolve_indexes(keyspace, table_name, db),
        )

    def as_markdown(
        self, include_keyspace: bool = True, header_level: Optional[int] = None
    ) -> str:
        """
        Generates a Markdown representation of the Cassandra table schema, allowing for
        customizable header levels for the table name section.

        Args:
            include_keyspace: If True, includes the keyspace in the output.
                Defaults to True.
            header_level: Specifies the markdown header level for the table name.
                If None, the table name is included without a header.
                Defaults to None (no header level).

        Returns:
            A string in Markdown format detailing the table name
            (with optional header level), keyspace (optional), comment, columns,
            partition keys, clustering keys (with optional clustering order),
            and indexes.
        """
        output = ""
        if header_level is not None:
            output += f"{'#' * header_level} "
        output += f"Table Name: {self.table_name}\n"

        if include_keyspace:
            output += f"- Keyspace: {self.keyspace}\n"
        if self.comment:
            output += f"- Comment: {self.comment}\n"

        output += "- Columns\n"
        for column, column_type in self.columns:
            output += f" - {column} ({column_type})\n"

        output += f"- Partition Keys: ({', '.join(self.partition)})\n"

        # A clustering order of "none" is left out of the listing.
        cluster_list = [
            column
            if clustering_order.lower() == "none"
            else f"{column} {clustering_order}"
            for column, clustering_order in self.clustering
        ]
        # Always terminate this line, also when there are no clustering keys;
        # otherwise the next list item would be glued onto it.
        output += f"- Clustering Keys: ({', '.join(cluster_list)})\n"

        if self.indexes:
            output += "- Indexes\n"
            for name, kind, options in self.indexes:
                output += f" - {name} : kind={kind}, options={options}\n"

        return output

    @staticmethod
    def _row_to_dict(row: Any) -> Optional[Dict[str, Any]]:
        """Return ``row`` as a dictionary, or ``None`` for unknown row shapes.

        Dictionaries are returned unchanged; objects offering ``_asdict()``
        (named tuples) are converted.
        """
        if isinstance(row, dict):
            return row
        as_dict = getattr(row, "_asdict", None)
        if callable(as_dict):
            return dict(as_dict())
        return None

    @staticmethod
    def _resolve_comment(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Optional[str]:
        """Look up the table comment; ``None`` when there is none.

        Raises:
            ValueError: If ``db.run`` does not return a dictionary.
        """
        result = db.run(
            f"""SELECT comment
                FROM system_schema.tables
                WHERE keyspace_name = '{keyspace}'
                AND table_name = '{table_name}';""",
            fetch="one",
        )

        if not isinstance(result, dict):
            raise ValueError(
                f"Unexpected result type from db.run:\n{type(result).__name__}"
            )
        # Empty comments are reported as None as well.
        return result.get("comment") or None

    @staticmethod
    def _resolve_columns(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Tuple[List[Tuple[str, str]], List[str], List[Tuple[str, str]]]:
        """Look up the columns, partition keys and clustering keys of a table.

        Returns:
            A tuple of (columns, partition keys, clustering keys). Partition and
            clustering keys are ordered by their ``position`` value.

        Raises:
            TypeError: If ``db.run`` does not return a sequence.
        """
        columns = []
        partition_info = []
        cluster_info = []
        results = db.run(
            f"""SELECT column_name, type, kind, clustering_order, position
                FROM system_schema.columns
                WHERE keyspace_name = '{keyspace}'
                AND table_name = '{table_name}';"""
        )
        # Type check to ensure 'results' is a sequence of rows.
        if not isinstance(results, Sequence):
            raise TypeError("Expected a sequence of dictionaries from 'run' method.")

        for raw_row in results:
            # Accept dictionary rows and named-tuple rows alike; anything
            # else is skipped.
            row = Table._row_to_dict(raw_row)
            if row is None:
                continue

            columns.append((row["column_name"], row["type"]))
            if row["kind"] == "partition_key":
                partition_info.append((row["column_name"], row["position"]))
            elif row["kind"] == "clustering":
                cluster_info.append(
                    (
                        row["column_name"],
                        row["clustering_order"],
                        row["position"],
                    )
                )

        partition = [
            column_name for column_name, _ in sorted(partition_info, key=lambda x: x[1])
        ]

        cluster = [
            (column_name, clustering_order)
            for column_name, clustering_order, _ in sorted(
                cluster_info, key=lambda x: x[2]
            )
        ]

        return columns, partition, cluster

    @staticmethod
    def _resolve_indexes(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> List[Tuple[str, str, str]]:
        """Look up the indexes of a table.

        Returns:
            (index name, kind, options) triples; options are always strings.

        Raises:
            TypeError: If ``db.run`` does not return a sequence.
        """
        indexes = []
        results = db.run(
            f"""SELECT index_name, kind, options
                FROM system_schema.indexes
                WHERE keyspace_name = '{keyspace}'
                AND table_name = '{table_name}';"""
        )

        # Type check to ensure 'results' is a sequence of rows.
        if not isinstance(results, Sequence):
            raise TypeError("Expected a sequence of dictionaries from 'run' method.")

        for raw_row in results:
            row = Table._row_to_dict(raw_row)
            if row is None:
                continue  # Skip rows of unknown shape.

            # The indexes field only holds strings, so anything else is
            # rendered with str().
            index_options = row["options"]
            if not isinstance(index_options, str):
                index_options = str(index_options)

            indexes.append((row["index_name"], row["kind"], index_options))

        return indexes
|
||||
626
venv/Lib/site-packages/langchain_community/utilities/clickup.py
Normal file
626
venv/Lib/site-packages/langchain_community/utilities/clickup.py
Normal file
@@ -0,0 +1,626 @@
|
||||
"""Util that calls clickup."""
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union
|
||||
|
||||
import requests
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
# Base URL that every ClickUp endpoint path built in this module is appended to.
DEFAULT_URL = "https://api.clickup.com/api/v2"
|
||||
|
||||
|
||||
@dataclass
class Component:
    """Common base for the dataclasses that model ClickUp payloads.

    Subclasses are expected to override ``from_data``; the base
    implementation always raises.
    """

    @classmethod
    def from_data(cls, data: Dict[str, Any]) -> "Component":
        """Build an instance from a raw API dictionary.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
class Task(Component):
    """A ClickUp task flattened into simple fields.

    The nested ``status``, ``creator``, ``priority`` and ``project`` objects
    of the raw payload are reduced to the single values read in ``from_data``.
    """

    id: int
    name: str
    text_content: str
    description: str
    status: str
    creator_id: int
    creator_username: str
    creator_email: str
    assignees: List[Dict[str, Any]]
    watchers: List[Dict[str, Any]]
    priority: Optional[str]
    due_date: Optional[str]
    start_date: Optional[str]
    points: int
    team_id: int
    project_id: int

    @classmethod
    def from_data(cls, data: Dict[str, Any]) -> "Task":
        """Build a ``Task`` from a raw task dictionary.

        Args:
            data: Task payload; every key read below must be present.

        Returns:
            The populated ``Task``.

        Raises:
            KeyError: If a required key is missing from ``data``.
        """
        # "priority" may be None; otherwise its inner "priority" value is kept.
        raw_priority = data["priority"]
        if raw_priority is None:
            priority = None
        else:
            priority = raw_priority["priority"]

        return cls(
            id=data["id"],
            name=data["name"],
            text_content=data["text_content"],
            description=data["description"],
            status=data["status"]["status"],
            creator_id=data["creator"]["id"],
            creator_username=data["creator"]["username"],
            creator_email=data["creator"]["email"],
            assignees=data["assignees"],
            watchers=data["watchers"],
            priority=priority,
            due_date=data["due_date"],
            start_date=data["start_date"],
            points=data["points"],
            team_id=data["team_id"],
            project_id=data["project"]["id"],
        )
|
||||
|
||||
|
||||
@dataclass
class CUList(Component):
    """A ClickUp list.

    Only ``folder_id`` and ``name`` are required; every other field
    defaults to ``None``.
    """

    folder_id: float
    name: str
    content: Optional[str] = None
    due_date: Optional[int] = None
    due_date_time: Optional[bool] = None
    priority: Optional[int] = None
    assignee: Optional[int] = None
    status: Optional[str] = None

    @classmethod
    def from_data(cls, data: dict) -> "CUList":
        """Build a ``CUList`` from a raw list dictionary.

        Args:
            data: List payload; ``folder_id`` and ``name`` must be present,
                the remaining keys are optional.

        Raises:
            KeyError: If ``folder_id`` or ``name`` is missing.
        """
        optional_keys = (
            "content",
            "due_date",
            "due_date_time",
            "priority",
            "assignee",
            "status",
        )
        return cls(
            folder_id=data["folder_id"],
            name=data["name"],
            **{key: data.get(key) for key in optional_keys},
        )
|
||||
|
||||
|
||||
@dataclass
class Member(Component):
    """A ClickUp team member."""

    id: int
    username: str
    email: str
    initials: str

    @classmethod
    def from_data(cls, data: Dict) -> "Member":
        """Build a ``Member`` from a dictionary that nests the user under ``"user"``.

        Raises:
            KeyError: If ``"user"`` or one of its required keys is missing.
        """
        user = data["user"]
        return cls(
            id=user["id"],
            username=user["username"],
            email=user["email"],
            initials=user["initials"],
        )
|
||||
|
||||
|
||||
@dataclass
class Team(Component):
    """A ClickUp team together with its members."""

    id: int
    name: str
    members: List[Member]

    @classmethod
    def from_data(cls, data: Dict) -> "Team":
        """Build a ``Team`` from a raw team dictionary.

        Each entry of ``data["members"]`` is converted with
        ``Member.from_data``.

        Raises:
            KeyError: If ``members``, ``id`` or ``name`` is missing.
        """
        members = list(map(Member.from_data, data["members"]))
        return cls(id=data["id"], name=data["name"], members=members)
|
||||
|
||||
|
||||
@dataclass
class Space(Component):
    """A ClickUp space and the features that are switched on for it."""

    id: int
    name: str
    private: bool
    enabled_features: Dict[str, Any]

    @classmethod
    def from_data(cls, data: Dict[str, Any]) -> "Space":
        """Build a ``Space`` from a spaces-listing response.

        Only the first entry of ``data["spaces"]`` is used. Features whose
        ``"enabled"`` value is falsy are dropped.

        Raises:
            KeyError: If a required key is missing.
            IndexError: If ``data["spaces"]`` is empty.
        """
        first_space = data["spaces"][0]

        enabled_features = {}
        for feature, settings in first_space["features"].items():
            if settings["enabled"]:
                enabled_features[feature] = settings

        return cls(
            id=first_space["id"],
            name=first_space["name"],
            private=first_space["private"],
            enabled_features=enabled_features,
        )
|
||||
|
||||
|
||||
def parse_dict_through_component(
    data: dict, component: Type[Component], fault_tolerant: bool = False
) -> Dict:
    """Round-trip a dictionary through a component dataclass.

    The dictionary is converted with ``component.from_data`` and then turned
    back into a plain dictionary, which keeps only the fields the component
    declares.

    Args:
        data: Raw dictionary to normalise.
        component: Component class whose ``from_data`` does the parsing.
        fault_tolerant: When true, a parsing failure emits a warning and
            returns ``data`` unchanged instead of raising.

    Returns:
        The normalised dictionary, or ``data`` itself after a tolerated
        failure.
    """
    try:
        return asdict(component.from_data(data))
    except Exception as e:
        if not fault_tolerant:
            raise e
        warnings.warn(
            "Error encountered while trying to parse\n"
            f"{str(data)}: {str(e)}\n Falling back to returning input data."
        )
        return data
|
||||
|
||||
|
||||
def extract_dict_elements_from_component_fields(
    data: dict, component: Type[Component]
) -> dict:
    """Keep only the entries of ``data`` named after a field of ``component``.

    Args:
        data: The dictionary to filter.
        component: Dataclass whose field names select the keys to keep.

    Returns:
        A new ``dict`` with the matching keys, in the component's field order.
    """
    return {
        attribute.name: data[attribute.name]
        for attribute in fields(component)
        if attribute.name in data
    }
|
||||
|
||||
|
||||
def load_query(
    query: str, fault_tolerant: bool = False
) -> Tuple[Optional[Dict], Optional[str]]:
    """Parse a JSON string and return the parsed object.

    Args:
        query: The JSON string to parse.
        fault_tolerant: When true, a parse failure is reported through the
            second element of the returned tuple instead of being raised.

    Returns:
        ``(parsed, None)`` on success, or ``(None, error_message)`` when
        parsing fails and ``fault_tolerant`` is true.

    Raises:
        json.JSONDecodeError: If ``query`` is not valid JSON and
            ``fault_tolerant`` is false.
    """
    try:
        return json.loads(query), None
    except json.JSONDecodeError as e:
        if not fault_tolerant:
            raise
        # Single-line message: the previous text contained a stray quote and
        # a line break in the middle of the sentence.
        return (
            None,
            f"Input must be a valid JSON. Got the following error: {e}. "
            "Please reformat and try again.",
        )
|
||||
|
||||
|
||||
def fetch_first_id(data: dict, key: str) -> Optional[int]:
    """Return the ``"id"`` of the first entry stored under ``key``.

    Returns ``None`` when ``key`` is absent or its list is empty. A warning is
    emitted when more than one entry is available.
    """
    entries = data[key] if key in data else ()
    if len(entries) == 0:
        return None
    if len(entries) > 1:
        warnings.warn(f"Found multiple {key}: {data[key]}. Defaulting to first.")
    return entries[0]["id"]
|
||||
|
||||
|
||||
def fetch_data(url: str, access_token: str, query: Optional[dict] = None) -> dict:
    """GET ``url`` and return the decoded JSON body.

    Args:
        url: Endpoint to request.
        access_token: Value sent verbatim in the ``Authorization`` header.
        query: Optional query-string parameters.

    Raises:
        requests.HTTPError: If the response status indicates an error.
    """
    response = requests.get(
        url, headers={"Authorization": access_token}, params=query
    )
    response.raise_for_status()
    payload = response.json()
    return payload
|
||||
|
||||
|
||||
def fetch_team_id(access_token: str) -> Optional[int]:
    """Return the id of the first team listed for the token, or ``None``."""
    teams_payload = fetch_data(f"{DEFAULT_URL}/team", access_token)
    return fetch_first_id(teams_payload, "teams")
|
||||
|
||||
|
||||
def fetch_space_id(team_id: int, access_token: str) -> Optional[int]:
    """Return the id of the team's first non-archived space, or ``None``."""
    spaces_payload = fetch_data(
        f"{DEFAULT_URL}/team/{team_id}/space",
        access_token,
        query={"archived": "false"},
    )
    return fetch_first_id(spaces_payload, "spaces")
|
||||
|
||||
|
||||
def fetch_folder_id(space_id: int, access_token: str) -> Optional[int]:
    """Return the id of the space's first non-archived folder, or ``None``."""
    folders_payload = fetch_data(
        f"{DEFAULT_URL}/space/{space_id}/folder",
        access_token,
        query={"archived": "false"},
    )
    return fetch_first_id(folders_payload, "folders")
|
||||
|
||||
|
||||
def fetch_list_id(space_id: int, folder_id: int, access_token: str) -> Optional[int]:
    """Return the id of the first non-archived list, or ``None``.

    Lists are looked up inside ``folder_id`` when it is truthy, otherwise
    directly under ``space_id``.
    """
    parent = f"folder/{folder_id}" if folder_id else f"space/{space_id}"
    data = fetch_data(
        f"{DEFAULT_URL}/{parent}/list", access_token, query={"archived": "false"}
    )

    # A folder lookup may answer with a top-level "id"; otherwise the id is
    # taken from the first entry of "lists".
    if folder_id and "id" in data:
        return data["id"]
    return fetch_first_id(data, "lists")
|
||||
|
||||
|
||||
class ClickupAPIWrapper(BaseModel):
    """Wrapper for Clickup API.

    During validation the access token is resolved from the constructor
    arguments or the ``CLICKUP_ACCESS_TOKEN`` environment variable, and the
    first team, space, folder and list reachable with that token are looked
    up and stored as the defaults used by the methods below.
    """

    # Token sent verbatim in the ``Authorization`` header.
    access_token: Optional[str] = None
    # Defaults discovered in ``validate_environment``.
    team_id: Optional[str] = None
    space_id: Optional[str] = None
    folder_id: Optional[str] = None
    list_id: Optional[str] = None

    model_config = ConfigDict(
        extra="forbid",
    )

    @classmethod
    def get_access_code_url(
        cls, oauth_client_id: str, redirect_uri: str = "https://google.com"
    ) -> str:
        """Get the URL to get an access code."""
        url = f"https://app.clickup.com/api?client_id={oauth_client_id}"
        return f"{url}&redirect_uri={redirect_uri}"

    @classmethod
    def get_access_token(
        cls, oauth_client_id: str, oauth_client_secret: str, code: str
    ) -> Optional[str]:
        """Exchange an OAuth access code for an access token.

        Returns:
            The access token, or ``None`` when the response carries none (the
            response is printed in that case).
        """
        url = f"{DEFAULT_URL}/oauth/token"
        params = {
            "client_id": oauth_client_id,
            "client_secret": oauth_client_secret,
            "code": code,
        }

        response = requests.post(url, params=params)
        data = response.json()

        if "access_token" not in data:
            print(f"Error: {data}")  # noqa: T201
            if "ECODE" in data and data["ECODE"] == "OAUTH_014":
                url = ClickupAPIWrapper.get_access_code_url(oauth_client_id)
                print(  # noqa: T201
                    "You already used this code once. Generate a new one.",
                    f"Our best guess for the url to get a new code is:\n{url}",
                )
            return None

        return data["access_token"]

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Resolve the access token and discover the default ids.

        Note that ``team_id``, ``space_id``, ``folder_id`` and ``list_id`` are
        always overwritten with the values fetched from the API.
        """
        values["access_token"] = get_from_dict_or_env(
            values, "access_token", "CLICKUP_ACCESS_TOKEN"
        )
        values["team_id"] = fetch_team_id(values["access_token"])
        values["space_id"] = fetch_space_id(values["team_id"], values["access_token"])
        values["folder_id"] = fetch_folder_id(
            values["space_id"], values["access_token"]
        )
        values["list_id"] = fetch_list_id(
            values["space_id"], values["folder_id"], values["access_token"]
        )

        return values

    def attempt_parse_teams(self, input_dict: dict) -> Dict[str, List[dict]]:
        """Parse appropriate content from the list of teams.

        Teams that fail to parse are skipped with a warning.
        """
        parsed_teams: Dict[str, List[dict]] = {"teams": []}
        for team in input_dict["teams"]:
            try:
                team = parse_dict_through_component(team, Team, fault_tolerant=False)
                parsed_teams["teams"].append(team)
            except Exception as e:
                warnings.warn(f"Error parsing a team {e}")

        return parsed_teams

    def get_headers(
        self,
    ) -> Mapping[str, Union[str, bytes]]:
        """Get the headers for the request.

        Raises:
            TypeError: If ``access_token`` is not a string.
        """
        if not isinstance(self.access_token, str):
            raise TypeError(f"Access Token: {self.access_token}, must be str.")

        return {
            "Authorization": str(self.access_token),
            "Content-Type": "application/json",
        }

    def get_default_params(self) -> Dict:
        """Return the query parameters shared by the listing requests."""
        return {"archived": "false"}

    def get_authorized_teams(self) -> Dict[Any, Any]:
        """Get all teams for the user."""
        url = f"{DEFAULT_URL}/team"
        response = requests.get(url, headers=self.get_headers())
        return self.attempt_parse_teams(response.json())

    def get_folders(self) -> Dict:
        """Get all the folders for the team.

        NOTE(review): the request goes to ``/team/{team_id}/space``, the same
        endpoint ``get_spaces`` uses; confirm whether a folder endpoint was
        intended. The raw response object is returned under ``"response"``.
        """
        url = f"{DEFAULT_URL}/team/" + str(self.team_id) + "/space"
        params = self.get_default_params()
        response = requests.get(url, headers=self.get_headers(), params=params)
        return {"response": response}

    def get_task(self, query: str, fault_tolerant: bool = True) -> Dict:
        """Retrieve a specific task.

        Args:
            query: JSON string containing ``task_id``.
            fault_tolerant: Passed to the ``Task`` parsing step; when true the
                raw response is returned if it cannot be parsed.

        Returns:
            The parsed task, or ``{"Error": ...}`` when ``query`` is not
            valid JSON.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        url = f"{DEFAULT_URL}/task/{query_dict['task_id']}"
        params = {
            "custom_task_ids": "true",
            "team_id": self.team_id,
            "include_subtasks": "true",
        }
        response = requests.get(url, headers=self.get_headers(), params=params)
        data = response.json()
        return parse_dict_through_component(data, Task, fault_tolerant=fault_tolerant)

    def get_lists(self) -> Dict:
        """Get all available lists in the default folder.

        The raw response object is returned under ``"response"``.
        """
        url = f"{DEFAULT_URL}/folder/{self.folder_id}/list"
        params = self.get_default_params()
        response = requests.get(url, headers=self.get_headers(), params=params)
        return {"response": response}

    def query_tasks(self, query: str) -> Dict:
        """Query the tasks of the list named by ``list_id`` in ``query``.

        The raw response object is returned under ``"response"``.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        url = f"{DEFAULT_URL}/list/{query_dict['list_id']}/task"
        params = self.get_default_params()
        response = requests.get(url, headers=self.get_headers(), params=params)
        return {"response": response}

    def get_spaces(self) -> Dict:
        """Get all spaces for the team.

        The response is parsed through ``Space``; if that fails the raw
        response dictionary is returned.
        """
        url = f"{DEFAULT_URL}/team/{self.team_id}/space"
        response = requests.get(
            url, headers=self.get_headers(), params=self.get_default_params()
        )
        data = response.json()
        return parse_dict_through_component(data, Space, fault_tolerant=True)

    def get_task_attribute(self, query: str) -> Dict:
        """Retrieve a single attribute of a specified task.

        Args:
            query: JSON string containing ``task_id`` and ``attribute_name``.

        Returns:
            ``{attribute_name: value}``, or ``{"Error": ...}`` when the query
            is invalid or the task has no such attribute.
        """
        task = self.get_task(query, fault_tolerant=True)
        params, error = load_query(query, fault_tolerant=True)
        if not isinstance(params, dict):
            return {"Error": error}

        attribute_name = params["attribute_name"]
        if attribute_name not in task:
            return {
                "Error": (
                    f"attribute_name = {attribute_name} was not\n"
                    f"found in task keys {task.keys()}. "
                    "Please call again with one of the key names."
                )
            }

        return {attribute_name: task[attribute_name]}

    def update_task(self, query: str) -> Dict:
        """Update an attribute of a specified task.

        Args:
            query: JSON string containing ``task_id``, ``attribute_name`` and
                ``value``.

        Returns:
            The raw response object under ``"response"``, or
            ``{"Error": ...}`` when ``query`` is not valid JSON.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        url = f"{DEFAULT_URL}/task/{query_dict['task_id']}"
        params = {
            "custom_task_ids": "true",
            "team_id": self.team_id,
            "include_subtasks": "true",
        }
        headers = self.get_headers()
        payload = {query_dict["attribute_name"]: query_dict["value"]}

        response = requests.put(url, headers=headers, params=params, json=payload)
        return {"response": response}

    def update_task_assignees(self, query: str) -> Dict:
        """Add or remove assignees of a specified task.

        Args:
            query: JSON string containing ``task_id``, ``operation``
                (``"add"`` or ``"rem"``) and ``users`` (a list of integer ids).

        Returns:
            The raw response object under ``"response"``, or
            ``{"Error": ...}`` for invalid JSON or non-integer user ids.

        Raises:
            ValueError: If ``operation`` is neither ``"add"`` nor ``"rem"``.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        for user in query_dict["users"]:
            if not isinstance(user, int):
                return {
                    "Error": (
                        "All users must be integers, not strings! "
                        f"Got user {user} of type {type(user)}"
                    )
                }

        url = f"{DEFAULT_URL}/task/{query_dict['task_id']}"
        headers = self.get_headers()

        if query_dict["operation"] == "add":
            assigne_payload = {"add": query_dict["users"], "rem": []}
        elif query_dict["operation"] == "rem":
            assigne_payload = {"add": [], "rem": query_dict["users"]}
        else:
            # One message string; the stray comma previously turned the
            # exception text into a two-element tuple.
            raise ValueError(
                f"Invalid operation ({query_dict['operation']}). "
                "Valid options ['add', 'rem']."
            )

        params = {
            "custom_task_ids": "true",
            "team_id": self.team_id,
            "include_subtasks": "true",
        }

        payload = {"assignees": assigne_payload}
        response = requests.put(url, headers=headers, params=params, json=payload)
        return {"response": response}

    def create_task(self, query: str) -> Dict:
        """Create a new task in the default list.

        Only the keys of ``query`` that match a ``Task`` field are sent.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        url = f"{DEFAULT_URL}/list/{self.list_id}/task"
        params = {"custom_task_ids": "true", "team_id": self.team_id}

        payload = extract_dict_elements_from_component_fields(query_dict, Task)
        headers = self.get_headers()

        response = requests.post(url, json=payload, headers=headers, params=params)
        data: Dict = response.json()
        return parse_dict_through_component(data, Task, fault_tolerant=True)

    def create_list(self, query: str) -> Dict:
        """Create a new list and make it the default list.

        The list is created inside the default folder when one exists and
        directly under the default space otherwise. Only the keys of
        ``query`` that match a ``CUList`` field are sent.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        # Folder lists and folderless lists live under different endpoints.
        if self.folder_id:
            url = f"{DEFAULT_URL}/folder/{self.folder_id}/list"
        else:
            url = f"{DEFAULT_URL}/space/{self.space_id}/list"

        payload = extract_dict_elements_from_component_fields(query_dict, CUList)
        headers = self.get_headers()

        response = requests.post(url, json=payload, headers=headers)
        data = response.json()
        parsed_list = parse_dict_through_component(data, CUList, fault_tolerant=True)

        # ``CUList`` has no ``id`` field, so the new id is read from the raw
        # response rather than from the parsed dictionary.
        if isinstance(data, dict) and "id" in data:
            self.list_id = data["id"]
        return parsed_list

    def create_folder(self, query: str) -> Dict:
        """Create a new folder in the default space and make it the default folder.

        Args:
            query: JSON string containing ``name``.
        """
        query_dict, error = load_query(query, fault_tolerant=True)
        if query_dict is None:
            return {"Error": error}

        url = f"{DEFAULT_URL}/space/{self.space_id}/folder"
        payload = {
            "name": query_dict["name"],
        }
        headers = self.get_headers()

        response = requests.post(url, json=payload, headers=headers)
        data = response.json()

        # The response describes a folder, so its id belongs in ``folder_id``
        # (it used to be written to ``list_id``).
        if "id" in data:
            self.folder_id = data["id"]
        return data

    def run(self, mode: str, query: str) -> str:
        """Run the API.

        Args:
            mode: Name of the operation to perform.
            query: JSON string passed to operations that take one.

        Returns:
            The operation's result serialised with ``json.dumps``; if that
            fails, its ``str()`` form. Unknown modes yield a ``ModeError``
            entry.
        """
        with_query = {
            "get_task": self.get_task,
            "get_task_attribute": self.get_task_attribute,
            "create_task": self.create_task,
            "create_list": self.create_list,
            "create_folder": self.create_folder,
            "update_task": self.update_task,
            "update_task_assignees": self.update_task_assignees,
        }
        without_query = {
            "get_teams": self.get_authorized_teams,
            "get_lists": self.get_lists,
            "get_folders": self.get_folders,
            "get_spaces": self.get_spaces,
        }

        if mode in with_query:
            output = with_query[mode](query)
        elif mode in without_query:
            output = without_query[mode]()
        else:
            output = {"ModeError": f"Got unexpected mode {mode}."}

        # Some operations return raw response objects that are not JSON
        # serialisable; those fall back to their string form.
        try:
            return json.dumps(output)
        except Exception:
            return str(output)
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Utility that calls OpenAI's Dall-E Image Generator."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Mapping, Optional, Tuple, Union
|
||||
|
||||
from langchain_core.utils import (
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
secret_from_env,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_community.utils.openai import is_openai_v1
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DallEAPIWrapper(BaseModel):
    """Wrapper for OpenAI's DALL-E Image Generator.

    https://platform.openai.com/docs/guides/images/generations?context=node

    Usage instructions:

    1. `pip install openai`
    2. save your OPENAI_API_KEY in an environment variable
    """

    client: Any = None  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model_name: str = Field(default="dall-e-2", alias="model")
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    openai_api_key: Optional[SecretStr] = Field(
        alias="api_key",
        default_factory=secret_from_env(
            "OPENAI_API_KEY",
            default=None,
        ),
    )
    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
    openai_api_base: Optional[str] = Field(
        alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
    )
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    openai_organization: Optional[str] = Field(
        alias="organization",
        default_factory=from_env(
            ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
        ),
    )
    """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
    # to support explicit proxy for OpenAI
    openai_proxy: str = Field(default_factory=from_env("OPENAI_PROXY", default=""))
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    n: int = 1
    """Number of images to generate"""
    size: str = "1024x1024"
    """Size of image to generate"""
    separator: str = "\n"
    """Separator to use when multiple URLs are returned."""
    quality: Optional[str] = None
    """Quality of the image that will be generated"""
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client."""

    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Move unrecognised constructor arguments into ``model_kwargs``.

        Raises:
            ValueError: If an argument is also present in ``model_kwargs``, or
                if ``model_kwargs`` contains a name that is a declared field.
        """
        known_fields = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})

        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name in known_fields:
                continue
            logger.warning(
                f"WARNING! {field_name} is not default parameter.\n"
                f"{field_name} was transferred to model_kwargs.\n"
                f"Please confirm that {field_name} is what you intended."
            )
            extra[field_name] = values.pop(field_name)

        # Declared fields must be passed directly, never through model_kwargs.
        invalid_model_kwargs = known_fields.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Check that ``openai`` is importable and create any missing clients.

        Raises:
            ImportError: If the ``openai`` package cannot be imported.
        """
        try:
            import openai

        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )

        if not is_openai_v1():
            # Pre-v1 SDK: the module-level Image resource acts as the client.
            if not self.client:
                self.client = openai.Image
            return self

        api_key = (
            self.openai_api_key.get_secret_value() if self.openai_api_key else None
        )
        client_params = {
            "api_key": api_key,
            "organization": self.openai_organization,
            "base_url": self.openai_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
            "http_client": self.http_client,
        }

        # Clients supplied by the caller are left untouched.
        if not self.client:
            self.client = openai.OpenAI(**client_params).images
        if not self.async_client:
            self.async_client = openai.AsyncOpenAI(**client_params).images
        return self

    def run(self, query: str) -> str:
        """Generate images for ``query`` and return their URLs.

        Returns:
            The URLs joined with ``separator``, or the text
            ``"No image was generated"`` when the joined string is empty.
        """
        if not is_openai_v1():
            response = self.client.create(
                prompt=query, n=self.n, size=self.size, model=self.model_name
            )
            urls = [item["url"] for item in response["data"]]
        else:
            kwargs = {
                "prompt": query,
                "n": self.n,
                "size": self.size,
                "model": self.model_name,
            }
            # quality is only forwarded when it was set explicitly.
            if self.quality is not None:
                kwargs["quality"] = self.quality
            response = self.client.generate(**kwargs)
            urls = [item.url for item in response.data]

        image_urls = self.separator.join(urls)
        return image_urls or "No image was generated"
|
||||
@@ -0,0 +1,195 @@
|
||||
import base64
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class DataForSeoAPIWrapper(BaseModel):
|
||||
"""Wrapper around the DataForSeo API."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
default_params: dict = Field(
|
||||
default={
|
||||
"location_name": "United States",
|
||||
"language_code": "en",
|
||||
"depth": 10,
|
||||
"se_name": "google",
|
||||
"se_type": "organic",
|
||||
}
|
||||
)
|
||||
"""Default parameters to use for the DataForSEO SERP API."""
|
||||
params: dict = Field(default={})
|
||||
"""Additional parameters to pass to the DataForSEO SERP API."""
|
||||
api_login: Optional[str] = None
|
||||
"""The API login to use for the DataForSEO SERP API."""
|
||||
api_password: Optional[str] = None
|
||||
"""The API password to use for the DataForSEO SERP API."""
|
||||
json_result_types: Optional[list] = None
|
||||
"""The JSON result types."""
|
||||
json_result_fields: Optional[list] = None
|
||||
"""The JSON result fields."""
|
||||
top_count: Optional[int] = None
|
||||
"""The number of top results to return."""
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
"""The aiohttp session to use for the DataForSEO SERP API."""
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
    """Fill in the API login and password from ``values`` or the environment.

    The environment variables consulted are ``DATAFORSEO_LOGIN`` and
    ``DATAFORSEO_PASSWORD``.
    """
    # Both lookups are completed before ``values`` is modified.
    credentials = {
        "api_login": get_from_dict_or_env(values, "api_login", "DATAFORSEO_LOGIN"),
        "api_password": get_from_dict_or_env(
            values, "api_password", "DATAFORSEO_PASSWORD"
        ),
    }
    values.update(credentials)
    return values
|
||||
|
||||
async def arun(self, url: str) -> str:
    """Asynchronously query the DataForSEO SERP API and format the result."""
    payload = await self._aresponse_json(url)
    return self._process_response(payload)
|
||||
|
||||
def run(self, url: str) -> str:
    """Synchronously query the DataForSEO SERP API and format the result."""
    payload = self._response_json(url)
    return self._process_response(payload)
|
||||
|
||||
def results(self, url: str) -> list:
    """Synchronously query the API and return the filtered result items."""
    return self._filter_results(self._response_json(url))
|
||||
|
||||
async def aresults(self, url: str) -> list:
    """Asynchronously query the API and return the filtered result items."""
    return self._filter_results(await self._aresponse_json(url))
|
||||
|
||||
def _prepare_request(self, keyword: str) -> dict:
|
||||
"""Prepare the request details for the DataForSEO SERP API."""
|
||||
if self.api_login is None or self.api_password is None:
|
||||
raise ValueError("api_login or api_password is not provided")
|
||||
cred = base64.b64encode(
|
||||
f"{self.api_login}:{self.api_password}".encode("utf-8")
|
||||
).decode("utf-8")
|
||||
headers = {"Authorization": f"Basic {cred}", "Content-Type": "application/json"}
|
||||
obj = {"keyword": quote(keyword)}
|
||||
obj = {**obj, **self.default_params, **self.params}
|
||||
data = [obj]
|
||||
_url = (
|
||||
f"https://api.dataforseo.com/v3/serp/{obj['se_name']}"
|
||||
f"/{obj['se_type']}/live/advanced"
|
||||
)
|
||||
return {
|
||||
"url": _url,
|
||||
"headers": headers,
|
||||
"data": data,
|
||||
}
|
||||
|
||||
def _check_response(self, response: dict) -> dict:
|
||||
"""Check the response from the DataForSEO SERP API for errors."""
|
||||
if response.get("status_code") != 20000:
|
||||
raise ValueError(
|
||||
f"Got error from DataForSEO SERP API: {response.get('status_message')}"
|
||||
)
|
||||
return response
|
||||
|
||||
def _response_json(self, url: str) -> dict:
    """POST the prepared request with ``requests`` and return the checked JSON.

    Raises:
        requests.HTTPError: If the HTTP status indicates an error.
        ValueError: If the API-level status check fails.
    """
    req = self._prepare_request(url)
    response = requests.post(req["url"], headers=req["headers"], json=req["data"])
    response.raise_for_status()
    payload = response.json()
    return self._check_response(payload)
|
||||
|
||||
async def _aresponse_json(self, url: str) -> dict:
    """POST the prepared request with aiohttp and return the checked JSON.

    Uses ``self.aiosession`` when it is set; otherwise a temporary
    ``aiohttp.ClientSession`` is opened for this single call.

    Args:
        url: Value passed to ``_prepare_request`` to build the request.

    Returns:
        The decoded JSON body after it passed ``_check_response``.
    """
    details = self._prepare_request(url)

    async def _post(session):
        # Same POST for both session sources; only the session's owner differs.
        async with session.post(
            details["url"],
            headers=details["headers"],
            json=details["data"],
        ) as response:
            return await response.json()

    if self.aiosession:
        payload = await _post(self.aiosession)
    else:
        async with aiohttp.ClientSession() as session:
            payload = await _post(session)
    return self._check_response(payload)
|
||||
|
||||
def _filter_results(self, res: dict) -> list:
    """Collect the items of a DataForSEO response, filtered and capped.

    Walks every ``tasks[*].result[*].items[*]`` entry of ``res``. An item is
    kept when ``json_result_types`` is unset/empty or contains the item's
    ``type``. Kept items are passed through ``_cleanup_unnecessary_items``
    (which edits them in place) and dropped if that leaves them empty.

    Args:
        res: Decoded JSON response.

    Returns:
        The kept items, never more than ``top_count`` of them when
        ``top_count`` is set.
    """
    wanted_types = self.json_result_types or []
    limit = self.top_count
    output: list = []
    # ``or []`` instead of a ``.get`` default: a key that is present but holds
    # ``None`` must not break the iteration.
    for task in res.get("tasks") or []:
        for result in task.get("result") or []:
            for item in result.get("items") or []:
                # Return instead of ``break``: a break only leaves the
                # innermost loop, so every later result would still add one
                # more item and the output would exceed the cap.
                if limit is not None and len(output) >= limit:
                    return output
                if wanted_types and item.get("type", "") not in wanted_types:
                    continue
                self._cleanup_unnecessary_items(item)
                if item:
                    output.append(item)
    return output
|
||||
|
||||
def _cleanup_unnecessary_items(self, d: dict) -> dict:
    """Strip unwanted keys from ``d`` in place and return it.

    When ``json_result_fields`` is non-empty, every non-dict value whose key
    is not listed is removed; nested dicts are cleaned recursively and removed
    when that leaves them empty. Independently of that setting, the keys
    ``xpath``, ``position`` and ``rectangle`` are removed, and every remaining
    nested dict is cleaned recursively.

    Args:
        d: Dictionary to clean; it is mutated.

    Returns:
        The same ``d`` object.
    """
    allowed = self.json_result_fields if self.json_result_fields is not None else []
    if allowed:
        # Iterate over a snapshot because entries are deleted along the way.
        for key, value in list(d.items()):
            if not isinstance(value, dict):
                if key not in allowed:
                    del d[key]
                continue
            self._cleanup_unnecessary_items(value)
            if not value:
                del d[key]

    for dropped_key in ("xpath", "position", "rectangle"):
        d.pop(dropped_key, None)

    for value in list(d.values()):
        if isinstance(value, dict):
            self._cleanup_unnecessary_items(value)
    return d
|
||||
|
||||
def _process_response(self, res: dict) -> str:
    """Pick a single answer out of a DataForSEO SERP response.

    Results are scanned in order. For each result, the first of
    ``answer_box``, ``knowledge_graph``, ``featured_snippet``, ``shopping``
    and ``organic`` that appears in the result's ``item_types`` selects which
    item is read and which of its fields holds the answer. The first
    non-empty answer found is returned.

    Args:
        res: Decoded JSON response.

    Returns:
        The selected field value, or ``"No good search result found"`` when no
        result yields a usable value (including when ``tasks``, ``result``,
        ``item_types`` or ``items`` are missing or ``None``).
    """
    # (item type, key that holds its answer), highest priority first.
    answer_fields = (
        ("answer_box", "text"),
        ("knowledge_graph", "description"),
        ("featured_snippet", "description"),
        ("shopping", "price"),
        ("organic", "description"),
    )
    for task in res.get("tasks") or []:
        for result in task.get("result") or []:
            item_types = result.get("item_types") or []
            items = result.get("items") or []
            chosen = next(
                (pair for pair in answer_fields if pair[0] in item_types), None
            )
            if chosen is None:
                continue
            item_type, field = chosen
            # ``item_types`` may name a type with no matching entry in
            # ``items``; the default keeps ``next`` from raising StopIteration.
            match = next(
                (item for item in items if item.get("type") == item_type), None
            )
            answer = match.get(field) if match is not None else None
            if answer:
                return answer
    return "No good search result found"
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Util that calls Dataherald."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
|
||||
class DataheraldAPIWrapper(BaseModel):
    """Wrapper for the Dataherald SQL-generation client.

    Setup:

    1. Sign up at dataherald
    2. Create an API key
    3. Store the key in the DATAHERALD_API_KEY environment variable
    4. pip install dataherald

    """

    dataherald_client: Any = None  #: :meta private:
    db_connection_id: str
    dataherald_api_key: Optional[str] = None

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Resolve the API key and build the Dataherald client.

        The key is read from ``values`` or the ``DATAHERALD_API_KEY``
        environment variable; the created client is stored under
        ``dataherald_client``.

        Raises:
            ImportError: If the ``dataherald`` package cannot be imported.
        """
        api_key = get_from_dict_or_env(
            values, "dataherald_api_key", "DATAHERALD_API_KEY"
        )
        values["dataherald_api_key"] = api_key

        try:
            import dataherald
        except ImportError:
            raise ImportError(
                "dataherald is not installed. "
                "Please install it with `pip install dataherald`"
            )

        values["dataherald_client"] = dataherald.Dataherald(api_key=api_key)
        return values

    def run(self, prompt: str) -> str:
        """Ask Dataherald to generate SQL for ``prompt`` and format the reply.

        Returns:
            ``"Answer: <sql>"`` when the generation carries SQL, otherwise
            ``"No answer"``.
        """
        from dataherald.types.sql_generation_create_params import Prompt

        generation = self.dataherald_client.sql_generations.create(
            prompt=Prompt(text=prompt, db_connection_id=self.db_connection_id)
        )

        try:
            sql = generation.sql
            # An empty result is reported as "No answer" rather than echoed.
            return f"Answer: {sql}" if sql else "No answer"
        except StopIteration:
            return "Dataherald wasn't able to answer it"
|
||||
@@ -0,0 +1,95 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DriaAPIWrapper:
    """Wrapper around Dria API.

    This wrapper facilitates interactions with Dria's vector search
    and retrieval services, including creating knowledge bases, inserting data,
    and fetching search results.

    Attributes:
        api_key: Your API key for accessing Dria.
        contract_id: The contract ID of the knowledge base to interact with.
        top_n: Number of top results to fetch for a search.
    """

    def __init__(
        self, api_key: str, contract_id: Optional[str] = None, top_n: int = 10
    ):
        """Create the Dria client and select ``contract_id`` when one is given.

        Args:
            api_key: API key passed to the ``Dria`` client.
            contract_id: Knowledge base to select on the client, if any.
            top_n: Result count forwarded to searches and vector queries.
        """
        try:
            from dria import Dria, Models
        except ImportError:
            logger.error(
                """Dria is not installed. Please install Dria to use this wrapper.

                You can install Dria using the following command:
                pip install dria
                """
            )
            # Best-effort by design: the import failure is only logged, and
            # the instance is returned without any attribute set, so later
            # method calls on it fail with AttributeError.
            return

        self.api_key = api_key
        self.models = Models
        self.contract_id = contract_id
        self.top_n = top_n
        self.dria_client = Dria(api_key=self.api_key)
        if self.contract_id:
            self.dria_client.set_contract(self.contract_id)

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str,
        embedding: str,
    ) -> str:
        """Create a new knowledge base and remember its contract ID.

        Returns:
            The contract ID returned by the client; it is also stored in
            ``self.contract_id``.
        """
        contract_id = self.dria_client.create(
            name=name, embedding=embedding, category=category, description=description
        )
        logger.info("Knowledge base created with ID: %s", contract_id)
        self.contract_id = contract_id
        return contract_id

    def insert_data(self, data: List[Dict[str, Any]]) -> str:
        """Insert ``data`` through the client's ``insert_text`` and return its reply."""
        response = self.dria_client.insert_text(data)
        logger.info("Data inserted: %s", response)
        return response

    def search(self, query: str) -> List[Dict[str, Any]]:
        """Run a text search for ``query``, asking for ``top_n`` results."""
        results = self.dria_client.search(query, top_n=self.top_n)
        logger.info("Search results: %s", results)
        return results

    def query_with_vector(self, vector: List[float]) -> List[Dict[str, Any]]:
        """Run a vector query with ``vector``, asking for ``top_n`` results."""
        vector_query_results = self.dria_client.query(vector, top_n=self.top_n)
        logger.info("Vector query results: %s", vector_query_results)
        return vector_query_results

    def run(self, query: Union[str, List[float]]) -> Optional[List[Dict[str, Any]]]:
        """Dispatch to a text search or a vector query depending on ``query``.

        Args:
            query: A string for text-based search, or a list of numbers for a
                vector-based query. Integers are accepted alongside floats
                (``[1, 0.5]`` is a valid vector); booleans are not.

        Returns:
            The search or query results from Dria, or ``None`` (after logging
            an error) when ``query`` is neither a string nor a numeric list.
        """
        if isinstance(query, str):
            return self.search(query)
        if isinstance(query, list) and all(
            # bool is a subclass of int, so it has to be excluded explicitly.
            isinstance(item, (int, float)) and not isinstance(item, bool)
            for item in query
        ):
            return self.query_with_vector(query)
        logger.error(
            """Invalid query type. Please provide a string for text search or a
                list of floats for vector query."""
        )
        return None
|
||||
@@ -0,0 +1,178 @@
|
||||
"""Util that calls DuckDuckGo Search.
|
||||
|
||||
No setup required. Free.
|
||||
https://pypi.org/project/duckduckgo-search/
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
|
||||
class DuckDuckGoSearchAPIWrapper(BaseModel):
    """Wrapper for DuckDuckGo Search API.

    Free and does not require any setup.
    """

    region: Optional[str] = "wt-wt"
    """
    See https://pypi.org/project/duckduckgo-search/#regions
    """
    safesearch: str = "moderate"
    """
    Options: strict, moderate, off
    """
    time: Optional[str] = "y"
    """
    Options: d, w, m, y
    """
    max_results: int = 5
    backend: str = "auto"
    """
    Options: auto, html, lite
    """
    source: str = "text"
    """
    Options: text, news, images
    """

    model_config = ConfigDict(
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that python package exists in environment.

        Raises:
            ImportError: If the ``ddgs`` package cannot be imported.
        """
        try:
            from ddgs import DDGS  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Could not import ddgs python package. "
                "Please install it with `pip install -U ddgs`."
            ) from exc
        return values

    def _ddgs_text(
        self, query: str, max_results: Optional[int] = None
    ) -> List[Dict[str, str]]:
        """Run query through DuckDuckGo text search and return results.

        ``max_results`` falls back to ``self.max_results`` when it is ``None``
        or 0. An empty list is returned when the search yields nothing.
        """
        from ddgs import DDGS

        with DDGS() as ddgs:
            ddgs_gen = ddgs.text(
                query,
                region=self.region,
                safesearch=self.safesearch,
                timelimit=self.time,
                max_results=max_results or self.max_results,
                backend=self.backend,
            )
            if ddgs_gen:
                return list(ddgs_gen)
        return []

    def _ddgs_news(
        self, query: str, max_results: Optional[int] = None
    ) -> List[Dict[str, str]]:
        """Run query through DuckDuckGo news search and return results.

        ``max_results`` falls back to ``self.max_results`` when it is ``None``
        or 0. An empty list is returned when the search yields nothing.
        """
        from ddgs import DDGS

        with DDGS() as ddgs:
            ddgs_gen = ddgs.news(
                query,
                region=self.region,
                safesearch=self.safesearch,
                timelimit=self.time,
                max_results=max_results or self.max_results,
            )
            if ddgs_gen:
                return list(ddgs_gen)
        return []

    def _ddgs_images(
        self, query: str, max_results: Optional[int] = None
    ) -> List[Dict[str, str]]:
        """Run query through DuckDuckGo image search and return results.

        ``max_results`` falls back to ``self.max_results`` when it is ``None``
        or 0. An empty list is returned when the search yields nothing.
        """
        from ddgs import DDGS

        with DDGS() as ddgs:
            ddgs_gen = ddgs.images(
                query,
                region=self.region,
                safesearch=self.safesearch,
                max_results=max_results or self.max_results,
            )
            if ddgs_gen:
                return list(ddgs_gen)
        return []

    def run(self, query: str) -> str:
        """Run query through DuckDuckGo and return concatenated results.

        The search backend is chosen by ``self.source``; an unknown source
        yields no results.

        Returns:
            The results' ``body`` texts joined by single spaces, or a fixed
            "not found" message when there are no results. A result without
            a ``body`` key contributes its ``title`` instead.
        """
        if self.source == "text":
            results = self._ddgs_text(query)
        elif self.source == "news":
            results = self._ddgs_news(query)
        elif self.source == "images":
            results = self._ddgs_images(query)
        else:
            results = []

        if not results:
            return "No good DuckDuckGo Search Result was found"
        # ``results`` below reads no "body" from image hits, so indexing
        # r["body"] unconditionally raised KeyError for source == "images".
        return " ".join(
            r["body"] if "body" in r else r.get("title", "") for r in results
        )

    def results(
        self, query: str, max_results: int, source: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """Run query through DuckDuckGo and return metadata.

        Args:
            query: The query to search for.
            max_results: The number of results to return.
            source: The source to look from; defaults to ``self.source``.

        Returns:
            A list of dictionaries. For text results the keys are:
                snippet - The description of the result.
                title - The title of the result.
                link - The link to the result.
            News results add ``date`` and ``source``; image results carry
            ``title``, ``thumbnail``, ``image``, ``url``, ``height``,
            ``width`` and ``source``. An unknown source gives an empty list.
        """
        source = source or self.source
        if source == "text":
            results = [
                {"snippet": r["body"], "title": r["title"], "link": r["href"]}
                for r in self._ddgs_text(query, max_results=max_results)
            ]
        elif source == "news":
            results = [
                {
                    "snippet": r["body"],
                    "title": r["title"],
                    "link": r["url"],
                    "date": r["date"],
                    "source": r["source"],
                }
                for r in self._ddgs_news(query, max_results=max_results)
            ]
        elif source == "images":
            results = [
                {
                    "title": r["title"],
                    "thumbnail": r["thumbnail"],
                    "image": r["image"],
                    "url": r["url"],
                    "height": r["height"],
                    "width": r["width"],
                    "source": r["source"],
                }
                for r in self._ddgs_images(query, max_results=max_results)
            ]
        else:
            results = []

        # NOTE(review): every branch above assigns a list, so this check never
        # fires and an empty search returns []. Kept unchanged because callers
        # may rely on receiving an empty list.
        if results is None:
            results = [{"Result": "No good DuckDuckGo Search Result was found"}]

        return results
|
||||
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Util that calls several of the Financial Datasets stock market REST APIs.
|
||||
Docs: https://docs.financialdatasets.ai/
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel
|
||||
|
||||
FINANCIAL_DATASETS_BASE_URL = "https://api.financialdatasets.ai/"
|
||||
|
||||
|
||||
class FinancialDatasetsAPIWrapper(BaseModel):
    """Wrapper for financial datasets API."""

    financial_datasets_api_key: Optional[str] = None

    def __init__(self, **data: Any):
        """Initialise the model, then resolve the API key.

        The key is taken from ``data`` or, failing that, from the
        ``FINANCIAL_DATASETS_API_KEY`` environment variable.
        """
        super().__init__(**data)
        self.financial_datasets_api_key = get_from_dict_or_env(
            data, "financial_datasets_api_key", "FINANCIAL_DATASETS_API_KEY"
        )

    @property
    def _api_key(self) -> str:
        """Return the configured API key.

        Raises:
            ValueError: If no API key is set.
        """
        if self.financial_datasets_api_key is None:
            raise ValueError(
                "API key is required for the FinancialDatasetsAPIWrapper. "
                "Please provide the API key by either:\n"
                "1. Manually specifying it when initializing the wrapper: "
                "FinancialDatasetsAPIWrapper(financial_datasets_api_key='your_api_key')\n"
                "2. Setting it as an environment variable: FINANCIAL_DATASETS_API_KEY"
            )
        return self.financial_datasets_api_key

    def _get_statements(
        self,
        endpoint: str,
        result_key: str,
        ticker: str,
        period: str,
        limit: Optional[int],
    ) -> Optional[List[dict]]:
        """GET ``financials/<endpoint>/`` and return the ``result_key`` entry.

        Shared by the three public getters, which differ only in the endpoint
        and in the response key they read.

        :param endpoint: path segment below ``financials/``
        :param result_key: key of the JSON response that holds the statements
        :param ticker: the stock ticker
        :param period: annual, quarterly or ttm
        :param limit: the number of results to return; a falsy value means 10
        :return: the value stored under ``result_key``, or None if it is absent
        """
        url = f"{FINANCIAL_DATASETS_BASE_URL}financials/{endpoint}/"
        # Passed via ``params`` so requests URL-encodes the values instead of
        # them being spliced raw into the query string.
        query = {
            "ticker": ticker,
            "period": period,
            "limit": limit if limit else 10,
        }

        # Add the api key to the headers
        headers = {"X-API-KEY": self._api_key}

        # Execute the request
        response = requests.get(url, params=query, headers=headers)
        data = response.json()

        return data.get(result_key, None)

    def get_income_statements(
        self,
        ticker: str,
        period: str,
        limit: Optional[int],
    ) -> Optional[List[dict]]:
        """
        Get the income statements for a stock `ticker` over a `period` of time.

        :param ticker: the stock ticker
        :param period: the period of time to get the income statements for.
            Possible values are: annual, quarterly, ttm.
        :param limit: the number of results to return, default is 10
        :return: a list of income statements, or None if the response has none
        """
        return self._get_statements(
            "income-statements", "income_statements", ticker, period, limit
        )

    def get_balance_sheets(
        self,
        ticker: str,
        period: str,
        limit: Optional[int],
    ) -> Optional[List[dict]]:
        """
        Get the balance sheets for a stock `ticker` over a `period` of time.

        :param ticker: the stock ticker
        :param period: the period of time to get the balance sheets for.
            Possible values are: annual, quarterly, ttm.
        :param limit: the number of results to return, default is 10
        :return: a list of balance sheets, or None if the response has none
        """
        return self._get_statements(
            "balance-sheets", "balance_sheets", ticker, period, limit
        )

    def get_cash_flow_statements(
        self,
        ticker: str,
        period: str,
        limit: Optional[int],
    ) -> Optional[List[dict]]:
        """
        Get the cash flow statements for a stock `ticker` over a `period` of time.

        :param ticker: the stock ticker
        :param period: the period of time to get the cash flow statements for.
            Possible values are: annual, quarterly, ttm.
        :param limit: the number of results to return, default is 10
        :return: a list of cash flow statements, or None if the response has none
        """
        return self._get_statements(
            "cash-flow-statements", "cash_flow_statements", ticker, period, limit
        )

    def run(self, mode: str, ticker: str, **kwargs: Any) -> str:
        """Run one of the getters selected by ``mode`` and return JSON text.

        :param mode: get_income_statements, get_balance_sheets or
            get_cash_flow_statements
        :param ticker: the stock ticker
        :param kwargs: optional ``period`` (default "annual") and ``limit``
            (default 10)
        :return: the getter's result serialised with ``json.dumps``
        :raises ValueError: if ``mode`` is not one of the supported names
        """
        supported_modes = (
            "get_income_statements",
            "get_balance_sheets",
            "get_cash_flow_statements",
        )
        if mode not in supported_modes:
            raise ValueError(f"Invalid mode {mode} for financial datasets API.")

        period = kwargs.get("period", "annual")
        limit = kwargs.get("limit", 10)
        # ``mode`` was checked against the whitelist above, so it names one of
        # the three getter methods.
        return json.dumps(getattr(self, mode)(ticker, period, limit))
|
||||
901
venv/Lib/site-packages/langchain_community/utilities/github.py
Normal file
901
venv/Lib/site-packages/langchain_community/utilities/github.py
Normal file
@@ -0,0 +1,901 @@
|
||||
"""Util that calls GitHub."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from github.Issue import Issue
|
||||
from github.PullRequest import PullRequest
|
||||
|
||||
|
||||
def _import_tiktoken() -> Any:
    """Import and return the ``tiktoken`` module.

    Returns:
        The imported ``tiktoken`` module.

    Raises:
        ImportError: If ``tiktoken`` is not installed; the original import
            error is attached as the cause.
    """
    try:
        import tiktoken
    except ImportError as exc:
        raise ImportError(
            "tiktoken is not installed. Please install it with `pip install tiktoken`"
        ) from exc
    return tiktoken
|
||||
|
||||
|
||||
class GitHubAPIWrapper(BaseModel):
|
||||
"""Wrapper for GitHub API."""
|
||||
|
||||
github: Any = None #: :meta private:
|
||||
github_repo_instance: Any = None #: :meta private:
|
||||
github_repository: Optional[str] = None
|
||||
github_app_id: Optional[str] = None
|
||||
github_app_private_key: Optional[str] = None
|
||||
active_branch: Optional[str] = None
|
||||
github_base_branch: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
    """Validate that api key and python package exists in environment.

    Resolves the repository name, app id and private key from ``values`` or
    the ``GITHUB_REPOSITORY`` / ``GITHUB_APP_ID`` / ``GITHUB_APP_PRIVATE_KEY``
    environment variables, authenticates as the GitHub app, takes its first
    installation and stores the client, the repository object and both branch
    names back into ``values``.

    Raises:
        ImportError: If PyGithub cannot be imported.
        ValueError: If the app has no installation, or indexing the
            installation listing raises ``ValueError``.
    """
    github_repository = get_from_dict_or_env(
        values, "github_repository", "GITHUB_REPOSITORY"
    )

    github_app_id = get_from_dict_or_env(values, "github_app_id", "GITHUB_APP_ID")

    github_app_private_key = get_from_dict_or_env(
        values, "github_app_private_key", "GITHUB_APP_PRIVATE_KEY"
    )

    try:
        from github import Auth, GithubIntegration

    except ImportError as exc:
        raise ImportError(
            "PyGithub is not installed. "
            "Please install it with `pip install PyGithub`"
        ) from exc

    try:
        # interpret the key as a file path
        # fallback to interpreting as the key itself
        with open(github_app_private_key, "r") as f:
            private_key = f.read()
    except Exception:
        private_key = github_app_private_key

    auth = Auth.AppAuth(
        github_app_id,
        private_key,
    )
    gi = GithubIntegration(auth=auth)

    not_installed_message = (
        f"Please make sure to install the created github app with id "
        f"{github_app_id} on the repo: {github_repository}. "
        "More instructions can be found at "
        "https://docs.github.com/en/apps/using-"
        "github-apps/installing-your-own-github-app"
    )
    installation = gi.get_installations()
    if not installation:
        raise ValueError(not_installed_message)
    try:
        installation = installation[0]
    except IndexError:
        # NOTE(review): presumably the listing is lazy and can be truthy even
        # when it holds no installation, in which case the emptiness only
        # shows up here as IndexError - confirm against PyGithub.
        raise ValueError(not_installed_message) from None
    except ValueError as e:
        raise ValueError(
            f"Please make sure to give correct github parameters Error message: {e}"
        )
    # create a GitHub instance:
    g = installation.get_github_for_installation()
    repo = g.get_repo(github_repository)

    # Both branch settings default to the repository's default branch.
    github_base_branch = get_from_dict_or_env(
        values,
        "github_base_branch",
        "GITHUB_BASE_BRANCH",
        default=repo.default_branch,
    )

    active_branch = get_from_dict_or_env(
        values,
        "active_branch",
        "ACTIVE_BRANCH",
        default=repo.default_branch,
    )

    values["github"] = g
    values["github_repo_instance"] = repo
    values["github_repository"] = github_repository
    values["github_app_id"] = github_app_id
    values["github_app_private_key"] = github_app_private_key
    values["active_branch"] = active_branch
    values["github_base_branch"] = github_base_branch

    return values
|
||||
|
||||
def parse_issues(self, issues: List[Issue]) -> List[dict]:
    """
    Summarises each Issue as a plain dictionary
    Parameters:
        issues(List[Issue]): A list of Github Issue objects
    Returns:
        List[dict]: One dictionary per issue holding its ``title`` and
        ``number``, plus ``opened_by`` (the author's login) when the issue
        has a user whose login is not None
    """
    summaries = []
    for issue in issues:
        summary = {"title": issue.title, "number": issue.number}
        author = issue.user.login if issue.user else None
        if author is not None:
            summary["opened_by"] = author
        summaries.append(summary)
    return summaries
|
||||
|
||||
def parse_pull_requests(self, pull_requests: List[PullRequest]) -> List[dict]:
    """
    Summarises each PullRequest as a plain dictionary
    Parameters:
        pull_requests(List[PullRequest]): A list of Github PullRequest objects
    Returns:
        List[dict]: One dictionary per pull request holding its ``title`` and
        ``number``, and its ``commits`` and ``comments`` values as strings
    """
    return [
        {
            "title": pr.title,
            "number": pr.number,
            "commits": str(pr.commits),
            "comments": str(pr.comments),
        }
        for pr in pull_requests
    ]
|
||||
|
||||
def get_issues(self) -> str:
    """
    Fetches all open issues from the repo excluding pull requests

    Returns:
        str: A plaintext report containing the number of issues
        and the parsed summary of each issue, or a fixed message
        when there are no open issues.
    """
    open_items = self.github_repo_instance.get_issues(state="open")
    # The issues listing also carries pull requests; keep the real issues only.
    real_issues = [item for item in open_items if not item.pull_request]
    if not real_issues:
        return "No open issues available"
    parsed = self.parse_issues(real_issues)
    return f"Found {len(parsed)} issues:\n{parsed}"
|
||||
|
||||
def list_open_pull_requests(self) -> str:
    """
    Fetches all open PRs from the repo

    Returns:
        str: A plaintext report containing the number of PRs
        and the parsed summary of each PR, or a fixed message
        when there are no open PRs.
    """
    open_prs = self.github_repo_instance.get_pulls(state="open")
    if open_prs.totalCount <= 0:
        return "No open pull requests available"
    parsed = self.parse_pull_requests(open_prs)
    return f"Found {len(parsed)} pull requests:\n{parsed}"
|
||||
|
||||
def list_files_in_main_branch(self) -> str:
    """
    Fetches all files in the main branch of the repo.

    The whole tree, sub-directories included, is read from
    ``github_base_branch``.

    Returns:
        str: A plaintext report containing the paths and names of the files,
        a fixed message when there are none, or the text of any exception
        raised while listing.
    """
    base_ref = self.github_base_branch

    def _walk(path: str) -> List[str]:
        # Recurse with the base branch as ref. The shared ``_list_files``
        # helper reads ``active_branch``, which would mix in files from
        # another branch whenever the two differ.
        found: List[str] = []
        for content in self.github_repo_instance.get_contents(path, ref=base_ref):
            if content.type == "dir":
                found.extend(_walk(content.path))
            else:
                found.append(content.path)
        return found

    try:
        files = _walk("")
        if files:
            files_str = "\n".join(files)
            return f"Found {len(files)} files in the main branch:\n{files_str}"
        else:
            return "No files found in the main branch"
    except Exception as e:
        return str(e)
|
||||
|
||||
def set_active_branch(self, branch_name: str) -> str:
    """Make ``branch_name`` the active branch, like `git checkout branch_name`.

    The name is checked against the repository's current branches first.

    Returns a confirmation message, or an error message (as a string) listing
    the existing branches if the branch doesn't exist.
    """
    existing = [branch.name for branch in self.github_repo_instance.get_branches()]
    if branch_name not in existing:
        return (
            f"Error {branch_name} does not exist,"
            f"in repo with current branches: {str(existing)}"
        )
    self.active_branch = branch_name
    return f"Switched to branch `{branch_name}`"
|
||||
|
||||
def list_branches_in_repo(self) -> str:
    """
    Fetches a list of all branches in the repository.

    Returns:
        str: A plaintext report containing the names of the branches,
        a fixed message when there are none, or the text of any
        exception raised while listing.
    """
    try:
        names = [branch.name for branch in self.github_repo_instance.get_branches()]
        if not names:
            return "No branches found in the repository"
        joined = "\n".join(names)
        return f"Found {len(names)} branches in the repository:\n{joined}"
    except Exception as e:
        return str(e)
|
||||
|
||||
def create_branch(self, proposed_branch_name: str) -> str:
    """
    Create a new branch off the repository's default branch and make it
    the active bot branch.
    Equivalent to `git switch -c proposed_branch_name`
    If the proposed name is already taken, the suffixes _v1, _v2, ...
    are tried in turn, for at most 1000 attempts in total.

    Returns:
        str: A plaintext success message, or a failure message once all
        attempts hit an existing branch.

    Raises:
        Exception: If creating the ref fails for any reason other than the
        name already existing.
    """
    from github import GithubException

    repo = self.github_repo_instance
    default_head = repo.get_branch(repo.default_branch)
    candidate = proposed_branch_name
    for attempt in range(1000):
        try:
            repo.create_git_ref(
                ref=f"refs/heads/{candidate}", sha=default_head.commit.sha
            )
            self.active_branch = candidate
            return (
                f"Branch '{candidate}' "
                "created successfully, and set as current active branch."
            )
        except GithubException as e:
            name_taken = (
                e.status == 422 and "Reference already exists" in e.data["message"]
            )
            if not name_taken:
                # Handle any other exceptions
                print(f"Failed to create branch. Error: {e}")  # noqa: T201
                raise Exception(
                    "Unable to create branch name from proposed_branch_name: "
                    f"{proposed_branch_name}"
                )
            # Name collision: retry with the next numbered suffix.
            candidate = f"{proposed_branch_name}_v{attempt + 1}"
    return (
        "Unable to create branch. "
        "At least 1000 branches exist with named derived from "
        f"proposed_branch_name: `{proposed_branch_name}`"
    )
|
||||
|
||||
def list_files_in_bot_branch(self) -> str:
    """
    Fetches all files in the active branch of the repo,
    the branch the bot uses to make changes.

    Returns:
        str: A plaintext list containing the filepaths in the branch,
        a fixed message when there are none, or ``Error: ...`` with the
        text of any exception raised while listing.
    """
    try:
        paths: List[str] = []
        root_entries = self.github_repo_instance.get_contents(
            "", ref=self.active_branch
        )
        for entry in root_entries:
            if entry.type != "dir":
                paths.append(entry.path)
            else:
                paths.extend(self._list_files(entry.path))

        if not paths:
            return f"No files found in branch: `{self.active_branch}`"
        listing = "\n".join(paths)
        return (
            f"Found {len(paths)} files in branch `{self.active_branch}`:\n"
            f"{listing}"
        )
    except Exception as e:
        return f"Error: {e}"
|
||||
|
||||
def get_files_from_directory(self, directory_path: str) -> str:
    """
    Recursively fetches files from a directory in the repo.

    Parameters:
        directory_path (str): Path to the directory

    Returns:
        str: The string form of the list of file paths, or an error
        message carrying the status code when a GithubException is raised.
    """
    from github import GithubException

    try:
        listing = self._list_files(directory_path)
        return str(listing)
    except GithubException as e:
        return f"Error: status code {e.status}, {e.message}"
|
||||
|
||||
def _list_files(self, directory_path: str) -> List[str]:
    """Return the paths of all files below ``directory_path``, recursively.

    Contents are read from ``self.active_branch``; directories are descended
    into in the order they are listed.
    """
    entries = self.github_repo_instance.get_contents(
        directory_path, ref=self.active_branch
    )
    collected: List[str] = []
    for entry in entries:
        if entry.type == "dir":
            collected += self._list_files(entry.path)
        else:
            collected.append(entry.path)
    return collected
|
||||
|
||||
def get_issue(self, issue_number: int) -> Dict[str, Any]:
    """
    Fetches a specific issue and its first 10 comments
    Parameters:
        issue_number(int): The number for the github issue
    Returns:
        `dict` containing the issue's number, title, body, comments as a
        string (at most 10 of them), and the username of the user who opened
        the issue (the string "None" when it is not available)
    """
    max_comments = 10
    issue = self.github_repo_instance.get_issue(number=issue_number)
    # Request the comment listing once and page through that same object.
    paged_comments = issue.get_comments()
    comments: List[dict] = []
    page = 0
    while len(comments) < max_comments:
        comments_page = paged_comments.get_page(page)
        if len(comments_page) == 0:
            break
        # Take only what is still missing so a full page cannot push the
        # result past the documented 10 comments.
        for comment in comments_page[: max_comments - len(comments)]:
            comments.append({"body": comment.body, "user": comment.user.login})
        page += 1

    opened_by = None
    if issue.user and issue.user.login:
        opened_by = issue.user.login

    return {
        "number": issue_number,
        "title": issue.title,
        "body": issue.body,
        "comments": str(comments),
        "opened_by": str(opened_by),
    }
|
||||
|
||||
def list_pull_request_files(self, pr_number: int) -> List[Dict[str, Any]]:
    """Fetches the full text of the files in a PR, within a ~3k token budget.

    A file is only included when its content, name and a small fixed
    overhead still fit in the remaining budget (cl100k_base tokens); files
    that do not fit, or that cannot be downloaded, are skipped.

    # TODO: Enhancement to summarize files with ctags if they're getting long.

    Parameters:
        pr_number(int): The number of the pull request on Github

    Returns:
        List[Dict[str, Any]]: one dict per included file with the keys
        `filename`, `contents`, `additions` and `deletions`
    """
    tiktoken = _import_tiktoken()
    MAX_TOKENS_FOR_FILES = 3_000
    # Without a timeout a stalled connection would block the caller forever.
    REQUEST_TIMEOUT_SECONDS = 30
    # Build the encoder once instead of once per file.
    encoding = tiktoken.get_encoding("cl100k_base")

    pr_files: List[Dict[str, Any]] = []
    pr = self.github_repo_instance.get_pull(number=int(pr_number))
    changed_files = pr.get_files()
    total_tokens = 0
    page = 0
    while True:
        files_page = changed_files.get_page(page)
        if len(files_page) == 0:
            break
        for file in files_page:
            try:
                file_metadata_response = requests.get(
                    file.contents_url, timeout=REQUEST_TIMEOUT_SECONDS
                )
                if file_metadata_response.status_code == 200:
                    download_url = json.loads(file_metadata_response.text)[
                        "download_url"
                    ]
                else:
                    print(f"Failed to download file: {file.contents_url}, skipping")  # noqa: T201
                    continue

                file_content_response = requests.get(
                    download_url, timeout=REQUEST_TIMEOUT_SECONDS
                )
                if file_content_response.status_code == 200:
                    # Save the content as a UTF-8 string
                    file_content = file_content_response.text
                else:
                    print(  # noqa: T201
                        "Failed downloading file content "
                        f"(Error {file_content_response.status_code}). Skipping"
                    )
                    continue

                file_tokens = len(
                    encoding.encode(
                        file_content + file.filename + "file_name file_contents"
                    )
                )
                if (total_tokens + file_tokens) < MAX_TOKENS_FOR_FILES:
                    pr_files.append(
                        {
                            "filename": file.filename,
                            "contents": file_content,
                            "additions": file.additions,
                            "deletions": file.deletions,
                        }
                    )
                    total_tokens += file_tokens
            except Exception as e:
                # Best effort: one unreadable file must not abort the listing.
                print(f"Error when reading files from a PR on github. {e}")  # noqa: T201
        page += 1
    return pr_files
|
||||
|
||||
def get_pull_request(self, pr_number: int) -> Dict[str, Any]:
    """
    Fetches a specific pull request with up to 10 of its comments and up to
    10 of its commit messages, limited by a budget of 2,000 tokens
    (cl100k_base).

    The title, number and body are each added only if they fit in the
    remaining budget. Comments and commits are then collected in order until
    either 10 items were gathered or the next item would exceed the budget.

    Parameters:
        pr_number(int): The number for the Github pull

    Returns:
        `dict` containing the pull's title, number and body (when they fit),
        plus its comments and commits as strings
    """
    max_tokens = 2_000
    max_items = 10
    pull = self.github_repo_instance.get_pull(number=pr_number)
    # Build the encoder once; every token count below reuses it.
    encoding = _import_tiktoken().get_encoding("cl100k_base")
    total_tokens = 0

    def get_tokens(text: str) -> int:
        return len(encoding.encode(text))

    def add_to_dict(data_dict: Dict[str, Any], key: str, value: str) -> None:
        nonlocal total_tokens  # the budget is shared with the outer scope
        tokens = get_tokens(value)
        if total_tokens + tokens <= max_tokens:
            data_dict[key] = value
            total_tokens += tokens

    def collect(paginated: Any, render: Any) -> List[str]:
        """Gather rendered items page by page until a limit is reached."""
        nonlocal total_tokens
        collected: List[str] = []
        page = 0
        while len(collected) < max_items:
            batch = paginated.get_page(page)
            if len(batch) == 0:
                break
            for entry in batch:
                text = render(entry)
                tokens = get_tokens(text)
                if len(collected) >= max_items or total_tokens + tokens > max_tokens:
                    # Stop paging altogether once a limit is hit.
                    return collected
                collected.append(text)
                total_tokens += tokens
            page += 1
        return collected

    response_dict: Dict[str, str] = {}
    add_to_dict(response_dict, "title", pull.title)
    add_to_dict(response_dict, "number", str(pr_number))
    add_to_dict(response_dict, "body", pull.body if pull.body else "")

    comments = collect(
        pull.get_issue_comments(),
        lambda comment: str({"body": comment.body, "user": comment.user.login}),
    )
    # The items were already charged against the budget while collecting;
    # charging the joined string again would drop lists that do fit.
    response_dict["comments"] = str(comments)

    commits = collect(
        pull.get_commits(),
        lambda commit: str({"message": commit.commit.message}),
    )
    response_dict["commits"] = str(commits)
    return response_dict
|
||||
|
||||
def create_pull_request(self, pr_query: str) -> str:
    """
    Makes a pull request from the bot's branch to the base branch

    Parameters:
        pr_query(str): a string which contains the PR title
        and the PR body. The title is the first line
        in the string, and the body are the rest of the string.
        One or two newlines may separate the title from the body.
        For example, "Updated README\nmade changes to add info"

    Returns:
        str: A success or failure message
    """
    if self.github_base_branch == self.active_branch:
        return """Cannot make a pull request because
            commits are already in the main or master branch."""
    else:
        try:
            title, _, remainder = pr_query.partition("\n")
            # Drop at most one further newline so that both "title\nbody"
            # and "title\n\nbody" yield the complete body text.
            body = remainder[1:] if remainder.startswith("\n") else remainder
            pr = self.github_repo_instance.create_pull(
                title=title,
                body=body,
                head=self.active_branch,
                base=self.github_base_branch,
            )
            return f"Successfully created PR number {pr.number}"
        except Exception as e:
            return "Unable to make pull request due to error:\n" + str(e)
|
||||
|
||||
def comment_on_issue(self, comment_query: str) -> str:
    """
    Adds a comment to a github issue

    Parameters:
        comment_query(str): a string which contains the issue number,
        two newlines, and the comment.
        for example: "1\n\nWorking on it now"
        adds the comment "working on it now" to issue 1

    Returns:
        str: A success or failure message

    Raises:
        ValueError: if the text before the first blank line is not an integer
    """
    # Split on the separator itself rather than recomputing an offset from
    # the parsed number, which breaks for inputs such as "007" or " 1".
    number_text, _, comment = comment_query.partition("\n\n")
    issue_number = int(number_text)
    try:
        issue = self.github_repo_instance.get_issue(number=issue_number)
        issue.create_comment(comment)
        return "Commented on issue " + str(issue_number)
    except Exception as e:
        return "Unable to make comment due to error:\n" + str(e)
|
||||
|
||||
def create_file(self, file_query: str) -> str:
    """
    Creates a new file on the Github repo

    Parameters:
        file_query(str): a string which contains the file path
        and the file contents. The file path is the first line
        in the string, and the contents are the rest of the string.
        One or two newlines may separate the path from the contents.
        For example, "hello_world.md\n# Hello World!"

    Returns:
        str: A success or failure message
    """
    if self.active_branch == self.github_base_branch:
        return (
            "You're attempting to commit directly to the "
            f"{self.github_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )

    file_path, _, remainder = file_query.partition("\n")
    # Drop at most one further newline so that "path\ncontents" keeps its
    # first character and "path\n\ncontents" still works.
    file_contents = remainder[1:] if remainder.startswith("\n") else remainder

    try:
        try:
            file = self.github_repo_instance.get_contents(
                file_path, ref=self.active_branch
            )
            if file:
                return (
                    f"File already exists at `{file_path}` "
                    f"on branch `{self.active_branch}`. You must use "
                    "`update_file` to modify it."
                )
        except Exception:
            # expected behavior, file shouldn't exist yet
            pass

        self.github_repo_instance.create_file(
            path=file_path,
            message="Create " + file_path,
            content=file_contents,
            branch=self.active_branch,
        )
        return "Created file " + file_path
    except Exception as e:
        return "Unable to make file due to error:\n" + str(e)
|
||||
|
||||
def read_file(self, file_path: str) -> str:
    """
    Read a file from this agent's branch, defined by self.active_branch,
    which supports PR branches.

    Parameters:
        file_path(str): the file path

    Returns:
        str: The file decoded as a UTF-8 string, or an error message if the
        file could not be fetched or decoded
    """
    try:
        file = self.github_repo_instance.get_contents(
            file_path, ref=self.active_branch
        )
        return file.decoded_content.decode("utf-8")
    except Exception as e:
        return (
            f"File not found `{file_path}` on branch "
            f"`{self.active_branch}`. Error: {str(e)}"
        )
|
||||
|
||||
def update_file(self, file_query: str) -> str:
    """
    Updates a file with new content.

    Parameters:
        file_query(str): Contains the file path and the file contents.
            The old file contents is wrapped in OLD <<<< and >>>> OLD
            The new file contents is wrapped in NEW <<<< and >>>> NEW
            For example:
            /test/hello.txt
            OLD <<<<
            Hello Earth!
            >>>> OLD
            NEW <<<<
            Hello Mars!
            >>>> NEW

    Returns:
        A success or failure message
    """
    if self.active_branch == self.github_base_branch:
        return (
            "You're attempting to commit directly "
            f"to the {self.github_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )
    try:
        file_path: str = file_query.split("\n")[0]
        old_file_contents = (
            file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip()
        )
        new_file_contents = (
            file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip()
        )

        # str.replace with an empty needle inserts the replacement between
        # every character, which would corrupt the whole file.
        if not old_file_contents:
            return (
                "File content was not updated because the OLD content is empty. "
                "It may be helpful to use the read_file action to get "
                "the current file contents."
            )

        file_content = self.read_file(file_path)
        updated_file_content = file_content.replace(
            old_file_contents, new_file_contents
        )

        if file_content == updated_file_content:
            return (
                "File content was not updated because old content was not found. "
                "It may be helpful to use the read_file action to get "
                "the current file contents."
            )

        self.github_repo_instance.update_file(
            path=file_path,
            message="Update " + str(file_path),
            content=updated_file_content,
            branch=self.active_branch,
            sha=self.github_repo_instance.get_contents(
                file_path, ref=self.active_branch
            ).sha,
        )
        return "Updated file " + str(file_path)
    except Exception as e:
        return "Unable to update file due to error:\n" + str(e)
|
||||
|
||||
def delete_file(self, file_path: str) -> str:
    """
    Deletes a file from the repo

    Parameters:
        file_path(str): Where the file is

    Returns:
        str: Success or failure message
    """
    if self.active_branch == self.github_base_branch:
        return (
            "You're attempting to commit directly "
            f"to the {self.github_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )
    try:
        self.github_repo_instance.delete_file(
            path=file_path,
            message="Delete " + file_path,
            branch=self.active_branch,
            # The current blob sha of the file is passed along with the delete.
            sha=self.github_repo_instance.get_contents(
                file_path, ref=self.active_branch
            ).sha,
        )
        return "Deleted file " + file_path
    except Exception as e:
        return "Unable to delete file due to error:\n" + str(e)
|
||||
|
||||
def search_issues_and_prs(self, query: str) -> str:
    """
    Runs an issue/pull-request search scoped to this repository.

    Parameters:
        query(str): The search query

    Returns:
        str: A header line followed by one line per hit (title, number and
        state), for at most the first 5 hits
    """
    hits = self.github.search_issues(query, repo=self.github_repository)
    shown = min(5, hits.totalCount)
    lines = [f"Top {shown} results:"]
    lines.extend(
        f"Title: {hit.title}, Number: {hit.number}, State: {hit.state}"
        for hit in hits[:shown]
    )
    return "\n".join(lines)
|
||||
|
||||
def search_code(self, query: str) -> str:
    """
    Runs a code search scoped to this repository and returns the full
    contents of the matching files as read from the active branch.
    # Todo: limit total tokens returned...

    Parameters:
        query(str): The search query

    Returns:
        str: A string containing, at most, the top 5 search results
    """
    matches = self.github.search_code(query=query, repo=self.github_repository)
    if matches.totalCount == 0:
        return "0 results found."

    limit = min(5, matches.totalCount)
    sections = [f"Showing top {limit} of {matches.totalCount} results:"]
    for position, match in enumerate(matches):
        if position >= limit:
            break
        # Load the matched file from the active branch and decode its bytes
        blob = self.github_repo_instance.get_contents(
            match.path, ref=self.active_branch
        )
        file_content = blob.decoded_content.decode()
        sections.append(
            f"Filepath: `{match.path}`\nFile contents: {file_content}\n<END OF FILE>"
        )
    return "\n".join(sections)
|
||||
|
||||
def create_review_request(self, reviewer_username: str) -> str:
    """
    Requests a review on the first open pull request whose head branch is
    the current active_branch.

    Parameters:
        reviewer_username(str): The username of the person who is being requested

    Returns:
        str: A confirmation message, a notice that no matching open pull
        request exists, or a failure message
    """
    open_pulls = self.github_repo_instance.get_pulls(state="open", sort="created")

    # Look for the PR opened from the active branch
    pr = None
    for candidate in open_pulls:
        if candidate.head.ref == self.active_branch:
            pr = candidate
            break

    if pr is None:
        return (
            "No open pull request found for the "
            f"current branch `{self.active_branch}`"
        )

    try:
        pr.create_review_request(reviewers=[reviewer_username])
        return (
            f"Review request created for user {reviewer_username} "
            f"on PR #{pr.number}"
        )
    except Exception as e:
        return f"Failed to create a review request with error {e}"
|
||||
|
||||
def get_latest_release(self) -> str:
    """
    Looks up the repository's latest release.

    Returns:
        str: The title, tag name and body of the latest release
    """
    latest = self.github_repo_instance.get_latest_release()
    return f"Latest title: {latest.title} tag: {latest.tag_name} body: {latest.body}"
|
||||
|
||||
def get_releases(self) -> str:
    """
    Lists releases of the repository.

    Returns:
        str: A header line followed by one line per release (title, tag and
        body), for at most the first 5 releases
    """
    releases = self.github_repo_instance.get_releases()
    shown = min(5, releases.totalCount)
    lines = [f"Top {shown} results:"]
    lines.extend(
        f"Title: {item.title}, Tag: {item.tag_name}, Body: {item.body}"
        for item in releases[:shown]
    )
    return "\n".join(lines)
|
||||
|
||||
def get_release(self, tag_name: str) -> str:
    """
    Looks up a single release of the repository.

    Parameters:
        tag_name(str): The tag name of the release

    Returns:
        str: The title, tag name and body of the release
    """
    found = self.github_repo_instance.get_release(tag_name)
    title, tag, body = found.title, found.tag_name, found.body
    return f"Release: {title} tag: {tag} body: {body}"
|
||||
|
||||
def run(self, mode: str, query: str) -> str:
    """
    Dispatches a tool call to the method named by `mode`.

    Parameters:
        mode(str): Name of the operation; it matches the name of the
            method that implements it
        query(str): Raw tool input. It is converted to an int for the
            issue / pull request lookups, passed through unchanged for
            operations that take an argument, and ignored otherwise.

    Returns:
        str: The result of the operation. The issue / pull request lookups
        are serialized with `json.dumps`.

    Raises:
        ValueError: if `mode` is not a known operation
    """
    # Operations that take an integer and return JSON-serializable data.
    json_int_modes = (
        "get_issue",
        "get_pull_request",
        "list_pull_request_files",
    )
    # Operations that take no argument.
    no_arg_modes = (
        "get_issues",
        "list_open_pull_requests",
        "list_files_in_main_branch",
        "list_files_in_bot_branch",
        "list_branches_in_repo",
        "get_latest_release",
        "get_releases",
    )
    # Operations that receive the raw query string.
    query_modes = (
        "comment_on_issue",
        "create_file",
        "create_pull_request",
        "read_file",
        "update_file",
        "delete_file",
        "set_active_branch",
        "create_branch",
        "get_files_from_directory",
        "search_issues_and_prs",
        "search_code",
        "create_review_request",
        "get_release",
    )

    if mode in json_int_modes:
        return json.dumps(getattr(self, mode)(int(query)))
    if mode in no_arg_modes:
        return getattr(self, mode)()
    if mode in query_modes:
        return getattr(self, mode)(query)
    raise ValueError("Invalid mode: " + mode)
|
||||
518
venv/Lib/site-packages/langchain_community/utilities/gitlab.py
Normal file
518
venv/Lib/site-packages/langchain_community/utilities/gitlab.py
Normal file
@@ -0,0 +1,518 @@
|
||||
"""Util that calls gitlab."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gitlab.v4.objects import Issue
|
||||
|
||||
|
||||
class GitLabAPIWrapper(BaseModel):
|
||||
"""Wrapper for GitLab API."""
|
||||
|
||||
gitlab: Any = None #: :meta private:
|
||||
gitlab_repo_instance: Any = None #: :meta private:
|
||||
gitlab_url: Optional[str] = None
|
||||
"""The url of the GitLab instance."""
|
||||
gitlab_repository: Optional[str] = None
|
||||
"""The name of the GitLab repository, in the form {username}/{repo-name}."""
|
||||
gitlab_personal_access_token: Optional[str] = None
|
||||
"""Personal access token for the GitLab service, used for authentication."""
|
||||
gitlab_branch: Optional[str] = None
|
||||
"""The specific branch in the GitLab repository where the bot will make
|
||||
its commits. Defaults to 'main'.
|
||||
"""
|
||||
gitlab_base_branch: Optional[str] = None
|
||||
"""The base branch in the GitLab repository, used for comparisons.
|
||||
Usually 'main' or 'master'. Defaults to 'main'.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
    """Resolve the settings, check python-gitlab is installed and connect.

    Each setting is taken from `values` or, failing that, from its
    environment variable. The authenticated client and the project object
    are stored back into `values` next to the resolved settings.
    """
    # Resolved in this order so a missing repository is reported before a
    # missing token.
    settings = {
        "gitlab_url": get_from_dict_or_env(
            values, "gitlab_url", "GITLAB_URL", default="https://gitlab.com"
        ),
        "gitlab_repository": get_from_dict_or_env(
            values, "gitlab_repository", "GITLAB_REPOSITORY"
        ),
        "gitlab_personal_access_token": get_from_dict_or_env(
            values, "gitlab_personal_access_token", "GITLAB_PERSONAL_ACCESS_TOKEN"
        ),
        "gitlab_branch": get_from_dict_or_env(
            values, "gitlab_branch", "GITLAB_BRANCH", default="main"
        ),
        "gitlab_base_branch": get_from_dict_or_env(
            values, "gitlab_base_branch", "GITLAB_BASE_BRANCH", default="main"
        ),
    }

    try:
        import gitlab

    except ImportError:
        raise ImportError(
            "python-gitlab is not installed. "
            "Please install it with `pip install python-gitlab`"
        )

    client = gitlab.Gitlab(
        url=settings["gitlab_url"],
        private_token=settings["gitlab_personal_access_token"],
        keep_base_url=True,
    )
    client.auth()

    values["gitlab"] = client
    values["gitlab_repo_instance"] = client.projects.get(
        settings["gitlab_repository"]
    )
    values.update(settings)

    return values
|
||||
|
||||
def parse_issues(self, issues: List[Issue]) -> List[dict]:
    """
    Reduces each Issue to a small dictionary holding its title and number

    Parameters:
        issues(List[Issue]): A list of gitlab Issue objects

    Returns:
        List[dict]: One `{"title": ..., "number": ...}` dictionary per
        issue, where the number is the issue's `iid`
    """
    return [{"title": entry.title, "number": entry.iid} for entry in issues]
|
||||
|
||||
def get_issues(self) -> str:
    """
    Fetches all open issues from the repo

    Returns:
        str: A plaintext report containing the number of issues
        and each issue's title and number.
    """
    # list() is paginated and returns only the first page by default;
    # all=True asks for every page so the report really covers all issues.
    issues = self.gitlab_repo_instance.issues.list(state="opened", all=True)
    if len(issues) > 0:
        parsed_issues = self.parse_issues(issues)
        parsed_issues_str = (
            "Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues)
        )
        return parsed_issues_str
    else:
        return "No open issues available"
|
||||
|
||||
def get_issue(self, issue_number: int) -> Dict[str, Any]:
    """
    Fetches a specific issue and its first 10 comments

    Parameters:
        issue_number(int): The number for the gitlab issue

    Returns:
        `dict` containing the issue's title, body, and comments as a string
    """
    max_comments = 10
    issue = self.gitlab_repo_instance.issues.get(issue_number)

    comments: List[dict] = []
    page = 1  # GitLab page numbers start at 1
    while len(comments) < max_comments:
        comments_page = issue.notes.list(page=page)
        if len(comments_page) == 0:
            break
        # Only take what is still missing so the documented limit holds.
        remaining = max_comments - len(comments)
        for comment in comments_page[:remaining]:
            # The listed notes are used directly; refetching each note by
            # id would add one request per comment.
            comments.append(
                {
                    "body": comment.body,
                    "user": comment.author["username"],
                }
            )
        page += 1

    return {
        "title": issue.title,
        "body": issue.description,
        "comments": str(comments),
    }
|
||||
|
||||
def create_pull_request(self, pr_query: str) -> str:
    """
    Makes a pull request from the bot's branch to the base branch

    Parameters:
        pr_query(str): a string which contains the PR title
        and the PR body. The title is the first line
        in the string, and the body are the rest of the string.
        One or two newlines may separate the title from the body.
        For example, "Updated README\nmade changes to add info"

    Returns:
        str: A success or failure message
    """
    if self.gitlab_base_branch == self.gitlab_branch:
        return """Cannot make a pull request because
            commits are already in the master branch"""
    else:
        try:
            title, _, remainder = pr_query.partition("\n")
            # Drop at most one further newline so that both "title\nbody"
            # and "title\n\nbody" yield the complete body text.
            body = remainder[1:] if remainder.startswith("\n") else remainder
            pr = self.gitlab_repo_instance.mergerequests.create(
                {
                    "source_branch": self.gitlab_branch,
                    "target_branch": self.gitlab_base_branch,
                    "title": title,
                    "description": body,
                    "labels": ["created-by-agent"],
                }
            )
            return f"Successfully created PR number {pr.iid}"
        except Exception as e:
            return "Unable to make pull request due to error:\n" + str(e)
|
||||
|
||||
def comment_on_issue(self, comment_query: str) -> str:
    """
    Adds a comment to a gitlab issue

    Parameters:
        comment_query(str): a string which contains the issue number,
        two newlines, and the comment.
        for example: "1\n\nWorking on it now"
        adds the comment "working on it now" to issue 1

    Returns:
        str: A success or failure message

    Raises:
        ValueError: if the text before the first blank line is not an integer
    """
    # Split on the separator itself rather than recomputing an offset from
    # the parsed number, which breaks for inputs such as "007" or " 1".
    number_text, _, comment = comment_query.partition("\n\n")
    issue_number = int(number_text)
    try:
        issue = self.gitlab_repo_instance.issues.get(issue_number)
        issue.notes.create({"body": comment})
        return "Commented on issue " + str(issue_number)
    except Exception as e:
        return "Unable to make comment due to error:\n" + str(e)
|
||||
|
||||
def create_file(self, file_query: str) -> str:
    """
    Creates a new file on the gitlab repo

    Parameters:
        file_query(str): a string which contains the file path
        and the file contents. The file path is the first line
        in the string, and the contents are the rest of the string.
        One or two newlines may separate the path from the contents.
        For example, "hello_world.md\n# Hello World!"

    Returns:
        str: A success or failure message
    """
    if self.gitlab_branch == self.gitlab_base_branch:
        return (
            "You're attempting to commit directly "
            f"to the {self.gitlab_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )

    file_path, _, remainder = file_query.partition("\n")
    # Drop at most one further newline so that "path\ncontents" keeps its
    # first character and "path\n\ncontents" still works.
    file_contents = remainder[1:] if remainder.startswith("\n") else remainder

    try:
        self.gitlab_repo_instance.files.get(file_path, self.gitlab_branch)
        return f"File already exists at {file_path}. Use update_file instead"
    except Exception:
        # A failed lookup is taken to mean the file does not exist yet.
        pass

    # Errors raised while creating the file propagate to the caller.
    self.gitlab_repo_instance.files.create(
        {
            "branch": self.gitlab_branch,
            "commit_message": "Create " + file_path,
            "file_path": file_path,
            "content": file_contents,
        }
    )
    return "Created file " + file_path
|
||||
|
||||
def read_file(self, file_path: str) -> str:
    """
    Loads a file from the active branch of the gitlab repo

    Parameters:
        file_path(str): the file path

    Returns:
        str: The file content decoded as UTF-8 text
    """
    remote_file = self.gitlab_repo_instance.files.get(file_path, self.gitlab_branch)
    # The file object's decode() result is itself decoded as UTF-8 text.
    raw_content = remote_file.decode()
    return raw_content.decode("utf-8")
|
||||
|
||||
def update_file(self, file_query: str) -> str:
    """
    Updates a file with new content.

    Parameters:
        file_query(str): Contains the file path and the file contents.
            The old file contents is wrapped in OLD <<<< and >>>> OLD
            The new file contents is wrapped in NEW <<<< and >>>> NEW
            For example:
            test/hello.txt
            OLD <<<<
            Hello Earth!
            >>>> OLD
            NEW <<<<
            Hello Mars!
            >>>> NEW

    Returns:
        A success or failure message
    """
    if self.gitlab_branch == self.gitlab_base_branch:
        return (
            "You're attempting to commit directly "
            f"to the {self.gitlab_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )
    try:
        file_path = file_query.split("\n")[0]
        old_file_contents = (
            file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip()
        )
        new_file_contents = (
            file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip()
        )

        # str.replace with an empty needle inserts the replacement between
        # every character, which would corrupt the whole file.
        if not old_file_contents:
            return (
                "File content was not updated because the OLD content is empty. "
                "It may be helpful to use the read_file action to get "
                "the current file contents."
            )

        file_content = self.read_file(file_path)
        updated_file_content = file_content.replace(
            old_file_contents, new_file_contents
        )

        if file_content == updated_file_content:
            return (
                "File content was not updated because old content was not found. "
                "It may be helpful to use the read_file action to get "
                "the current file contents."
            )

        commit = {
            "branch": self.gitlab_branch,
            # This commit modifies an existing file, so label it as an update.
            "commit_message": "Update " + file_path,
            "actions": [
                {
                    "action": "update",
                    "file_path": file_path,
                    "content": updated_file_content,
                }
            ],
        }

        self.gitlab_repo_instance.commits.create(commit)
        return "Updated file " + file_path
    except Exception as e:
        return "Unable to update file due to error:\n" + str(e)
|
||||
|
||||
def delete_file(self, file_path: str) -> str:
    """
    Deletes a file from the repo

    Parameters:
        file_path(str): Where the file is

    Returns:
        str: Success or failure message
    """
    if self.gitlab_branch == self.gitlab_base_branch:
        return (
            "You're attempting to commit directly "
            f"to the {self.gitlab_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )
    try:
        # Positional arguments: path, branch, commit message.
        self.gitlab_repo_instance.files.delete(
            file_path, self.gitlab_branch, "Delete " + file_path
        )
        return "Deleted file " + file_path
    except Exception as e:
        return "Unable to delete file due to error:\n" + str(e)
|
||||
|
||||
def list_files_in_main_branch(self) -> str:
    """
    Lists the files found on the base branch of the repository

    Returns:
        str: A plaintext report containing the list of files
        in the repository in the base branch, or a notice that no base
        branch is set
    """
    base_branch = self.gitlab_base_branch
    if base_branch is not None:
        return self._list_files(base_branch)
    return "No base branch set. Please set a base branch."
|
||||
|
||||
def list_files_in_bot_branch(self) -> str:
    """
    Lists the files found on the active branch of the repository

    Returns:
        str: A plaintext report containing the list of files
        in the repository in the active branch, or a notice that no active
        branch is set
    """
    active_branch = self.gitlab_branch
    if active_branch is not None:
        return self._list_files(active_branch)
    return "No active branch set. Please set a branch."
|
||||
|
||||
def list_files_from_directory(self, path: str) -> str:
    """
    Lists the files found below one directory on the active branch of
    the repository

    Parameters:
        path(str): The directory to list

    Returns:
        str: A plaintext report containing the list of files
        in the repository in the active branch from the specified directory,
        or a notice that no active branch is set
    """
    active_branch = self.gitlab_branch
    if active_branch is not None:
        return self._list_files(branch=active_branch, path=path)
    return "No active branch set. Please set a branch."
|
||||
|
||||
def _list_files(self, branch: str, path: str = "") -> str:
|
||||
try:
|
||||
files = self._get_repository_files(
|
||||
branch=branch,
|
||||
path=path,
|
||||
)
|
||||
if files:
|
||||
files_str = "\n".join(files)
|
||||
return f"Found {len(files)} files in branch `{branch}`:\n{files_str}"
|
||||
else:
|
||||
return f"No files found in branch: `{branch}`"
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
def _get_repository_files(self, branch: str, path: str = "") -> List[str]:
|
||||
repo_contents = self.gitlab_repo_instance.repository_tree(ref=branch, path=path)
|
||||
|
||||
files: List[str] = []
|
||||
for content in repo_contents:
|
||||
if content["type"] == "tree":
|
||||
files.extend(self._get_repository_files(branch, content["path"]))
|
||||
else:
|
||||
files.append(content["path"])
|
||||
|
||||
return files
|
||||
|
||||
def create_branch(self, proposed_branch_name: str) -> str:
    """
    Create a new branch in the repository and set it as the active branch

    The branch is created from the current active branch. If the proposed
    name is already taken, the names `<name>_v1`, `<name>_v2`, ... are
    tried in turn, for at most 100 attempts in total.

    Parameters:
        proposed_branch_name (str): The name of the new branch to be created

    Returns:
        str: A success or failure message

    Raises:
        Exception: if branch creation fails for any reason other than the
            name already existing
    """
    from gitlab import GitlabCreateError

    max_attempts = 100
    new_branch_name = proposed_branch_name
    for attempt in range(max_attempts):
        try:
            response = self.gitlab_repo_instance.branches.create(
                {
                    "branch": new_branch_name,
                    "ref": self.gitlab_branch,
                }
            )

            self.gitlab_branch = response.name
            return (
                f"Branch '{response.name}' "
                "created successfully, and set as current active branch."
            )

        except GitlabCreateError as e:
            if (
                e.response_code == 400
                and "Branch already exists" in e.error_message
            ):
                # Name taken: derive the next candidate from the attempt
                # count (first retry is `_v1`).
                new_branch_name = f"{proposed_branch_name}_v{attempt + 1}"
            else:
                # Handle any other exceptions
                print(f"Failed to create branch. Error: {e}")  # noqa: T201
                raise Exception(
                    "Unable to create branch name from proposed_branch_name: "
                    f"{proposed_branch_name}"
                ) from e

    return (
        f"Unable to create branch. At least {max_attempts} branches exist "
        f"with names derived from "
        f"proposed_branch_name: `{proposed_branch_name}`"
    )
|
||||
|
||||
def list_branches_in_repo(self) -> str:
    """
    Report every branch name in the repository

    Returns:
        str: The number of branches followed by one branch name per line,
        or a notice when the repository has no branches
    """
    all_branches = self.gitlab_repo_instance.branches.list(all=True)
    names = [item.name for item in all_branches]
    if not names:
        return "No branches found in the repository"

    count = str(len(names))
    joined_names = "\n".join(names)
    return f"Found {count} branches in the repository:" f"\n{joined_names}"
|
||||
|
||||
def set_active_branch(self, branch_name: str) -> str:
    """Equivalent to `git checkout branch_name` for this Agent.

    The active branch is only changed when `branch_name` matches the name of
    a branch returned by the repository's branch listing.

    Parameters:
        branch_name (str): Name of the branch to make active
    Returns:
        str: A confirmation message, or an error message (as a string)
        listing the existing branches if the branch doesn't exist
    """
    curr_branches = [
        branch.name
        for branch in self.gitlab_repo_instance.branches.list(
            all=True,
        )
    ]
    if branch_name not in curr_branches:
        # The trailing space keeps the two message halves from running
        # together ("exist,in") when the adjacent literals are concatenated.
        return (
            f"Error {branch_name} does not exist, "
            f"in repo with current branches: {str(curr_branches)}"
        )

    self.gitlab_branch = branch_name
    return f"Switched to branch `{branch_name}`"
|
||||
|
||||
def run(self, mode: str, query: str) -> str:
    """
    Dispatch a request to the wrapper method named by `mode`.

    Parameters:
        mode (str): Name of the operation. Every supported mode has the same
            name as the method that handles it.
        query (str): Argument for the operation. Ignored by the modes whose
            handler takes no argument; converted with `int()` for
            "get_issue".
    Returns:
        str: The handler's result; for "get_issue" the result is serialized
        with `json.dumps`.
    Raises:
        ValueError: If `mode` is not a supported mode.
    """
    # Handlers called without arguments.
    no_query_modes = (
        "get_issues",
        "list_branches_in_repo",
        "list_files_in_main_branch",
        "list_files_in_bot_branch",
    )
    # Handlers that receive `query` unchanged.
    query_modes = (
        "comment_on_issue",
        "create_file",
        "create_pull_request",
        "read_file",
        "update_file",
        "delete_file",
        "create_branch",
        "set_active_branch",
        "list_files_from_directory",
    )

    if mode == "get_issue":
        return json.dumps(self.get_issue(int(query)))
    # Attribute lookup only happens for names in the fixed tuples above, so
    # `mode` can never reach an arbitrary attribute.
    if mode in no_query_modes:
        return getattr(self, mode)()
    if mode in query_modes:
        return getattr(self, mode)(query)
    raise ValueError(f"Invalid mode: {mode}")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user