initial commit
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
"""Amadeus tools."""
|
||||
|
||||
from langchain_community.tools.amadeus.closest_airport import AmadeusClosestAirport
|
||||
from langchain_community.tools.amadeus.flight_search import AmadeusFlightSearch
|
||||
|
||||
__all__ = [
|
||||
"AmadeusClosestAirport",
|
||||
"AmadeusFlightSearch",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,19 @@
|
||||
"""Base class for Amadeus tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import Field
|
||||
|
||||
from langchain_community.tools.amadeus.utils import authenticate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amadeus import Client
|
||||
|
||||
|
||||
class AmadeusBaseTool(BaseTool):
|
||||
"""Base Tool for Amadeus."""
|
||||
|
||||
client: Client = Field(default_factory=authenticate)
|
||||
@@ -0,0 +1,62 @@
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_community.tools.amadeus.base import AmadeusBaseTool
|
||||
|
||||
|
||||
class ClosestAirportSchema(BaseModel):
|
||||
"""Schema for the AmadeusClosestAirport tool."""
|
||||
|
||||
location: str = Field(
|
||||
description=(
|
||||
" The location for which you would like to find the nearest airport "
|
||||
" along with optional details such as country, state, region, or "
|
||||
" province, allowing for easy processing and identification of "
|
||||
" the closest airport. Examples of the format are the following:\n"
|
||||
" Cali, Colombia\n "
|
||||
" Lincoln, Nebraska, United States\n"
|
||||
" New York, United States\n"
|
||||
" Sydney, New South Wales, Australia\n"
|
||||
" Rome, Lazio, Italy\n"
|
||||
" Toronto, Ontario, Canada\n"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AmadeusClosestAirport(AmadeusBaseTool):
|
||||
"""Tool for finding the closest airport to a particular location."""
|
||||
|
||||
name: str = "closest_airport"
|
||||
description: str = (
|
||||
"Use this tool to find the closest airport to a particular location."
|
||||
)
|
||||
args_schema: Type[ClosestAirportSchema] = ClosestAirportSchema
|
||||
|
||||
llm: Optional[BaseLanguageModel] = Field(default=None)
|
||||
"""Tool's llm used for calculating the closest airport. Defaults to `ChatOpenAI`."""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_llm(cls, values: Dict[str, Any]) -> Any:
|
||||
if not values.get("llm"):
|
||||
# For backward-compatibility
|
||||
values["llm"] = ChatOpenAI(temperature=0)
|
||||
return values
|
||||
|
||||
def _run(
|
||||
self,
|
||||
location: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
content = (
|
||||
f" What is the nearest airport to {location}? Please respond with the "
|
||||
" airport's International Air Transport Association (IATA) Location "
|
||||
' Identifier in the following JSON format. JSON: "iataCode": "IATA '
|
||||
' Location Identifier" '
|
||||
)
|
||||
|
||||
return self.llm.invoke(content) # type: ignore[union-attr]
|
||||
@@ -0,0 +1,153 @@
|
||||
import logging
|
||||
from datetime import datetime as dt
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_community.tools.amadeus.base import AmadeusBaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FlightSearchSchema(BaseModel):
|
||||
"""Schema for the AmadeusFlightSearch tool."""
|
||||
|
||||
originLocationCode: str = Field(
|
||||
description=(
|
||||
" The three letter International Air Transport "
|
||||
" Association (IATA) Location Identifier for the "
|
||||
" search's origin airport. "
|
||||
)
|
||||
)
|
||||
destinationLocationCode: str = Field(
|
||||
description=(
|
||||
" The three letter International Air Transport "
|
||||
" Association (IATA) Location Identifier for the "
|
||||
" search's destination airport. "
|
||||
)
|
||||
)
|
||||
departureDateTimeEarliest: str = Field(
|
||||
description=(
|
||||
" The earliest departure datetime from the origin airport "
|
||||
" for the flight search in the following format: "
|
||||
' "YYYY-MM-DDTHH:MM:SS", where "T" separates the date and time '
|
||||
' components. For example: "2023-06-09T10:30:00" represents '
|
||||
" June 9th, 2023, at 10:30 AM. "
|
||||
)
|
||||
)
|
||||
departureDateTimeLatest: str = Field(
|
||||
description=(
|
||||
" The latest departure datetime from the origin airport "
|
||||
" for the flight search in the following format: "
|
||||
' "YYYY-MM-DDTHH:MM:SS", where "T" separates the date and time '
|
||||
' components. For example: "2023-06-09T10:30:00" represents '
|
||||
" June 9th, 2023, at 10:30 AM. "
|
||||
)
|
||||
)
|
||||
page_number: int = Field(
|
||||
default=1,
|
||||
description="The specific page number of flight results to retrieve",
|
||||
)
|
||||
|
||||
|
||||
class AmadeusFlightSearch(AmadeusBaseTool):
|
||||
"""Tool for searching for a single flight between two airports."""
|
||||
|
||||
name: str = "single_flight_search"
|
||||
description: str = (
|
||||
" Use this tool to search for a single flight between the origin and "
|
||||
" destination airports at a departure between an earliest and "
|
||||
" latest datetime. "
|
||||
)
|
||||
args_schema: Type[FlightSearchSchema] = FlightSearchSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
originLocationCode: str,
|
||||
destinationLocationCode: str,
|
||||
departureDateTimeEarliest: str,
|
||||
departureDateTimeLatest: str,
|
||||
page_number: int = 1,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> list:
|
||||
try:
|
||||
from amadeus import ResponseError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import amadeus, please install with `pip install amadeus`."
|
||||
) from e
|
||||
|
||||
RESULTS_PER_PAGE = 10
|
||||
|
||||
# Authenticate and retrieve a client
|
||||
client = self.client
|
||||
|
||||
# Check that earliest and latest dates are in the same day
|
||||
earliestDeparture = dt.strptime(departureDateTimeEarliest, "%Y-%m-%dT%H:%M:%S")
|
||||
latestDeparture = dt.strptime(departureDateTimeLatest, "%Y-%m-%dT%H:%M:%S")
|
||||
|
||||
if earliestDeparture.date() != latestDeparture.date():
|
||||
logger.error(
|
||||
" Error: Earliest and latest departure dates need to be the "
|
||||
" same date. If you're trying to search for round-trip "
|
||||
" flights, call this function for the outbound flight first, "
|
||||
" and then call again for the return flight. "
|
||||
)
|
||||
return [None]
|
||||
|
||||
# Collect all results from the Amadeus Flight Offers Search API
|
||||
response = None
|
||||
try:
|
||||
response = client.shopping.flight_offers_search.get(
|
||||
originLocationCode=originLocationCode,
|
||||
destinationLocationCode=destinationLocationCode,
|
||||
departureDate=latestDeparture.strftime("%Y-%m-%d"),
|
||||
adults=1,
|
||||
)
|
||||
except ResponseError as error:
|
||||
print(error) # noqa: T201
|
||||
|
||||
# Generate output dictionary
|
||||
output = []
|
||||
if response is not None:
|
||||
for offer in response.data:
|
||||
itinerary: Dict = {}
|
||||
itinerary["price"] = {}
|
||||
itinerary["price"]["total"] = offer["price"]["total"]
|
||||
currency = offer["price"]["currency"]
|
||||
currency = response.result["dictionaries"]["currencies"][currency]
|
||||
itinerary["price"]["currency"] = {}
|
||||
itinerary["price"]["currency"] = currency
|
||||
|
||||
segments = []
|
||||
for segment in offer["itineraries"][0]["segments"]:
|
||||
flight = {}
|
||||
flight["departure"] = segment["departure"]
|
||||
flight["arrival"] = segment["arrival"]
|
||||
flight["flightNumber"] = segment["number"]
|
||||
carrier = segment["carrierCode"]
|
||||
carrier = response.result["dictionaries"]["carriers"][carrier]
|
||||
flight["carrier"] = carrier
|
||||
|
||||
segments.append(flight)
|
||||
|
||||
itinerary["segments"] = []
|
||||
itinerary["segments"] = segments
|
||||
|
||||
output.append(itinerary)
|
||||
|
||||
# Filter out flights after latest departure time
|
||||
for index, offer in enumerate(output):
|
||||
offerDeparture = dt.strptime(
|
||||
offer["segments"][0]["departure"]["at"], "%Y-%m-%dT%H:%M:%S"
|
||||
)
|
||||
|
||||
if offerDeparture > latestDeparture:
|
||||
output.pop(index)
|
||||
|
||||
# Return the paginated results
|
||||
startIndex = (page_number - 1) * RESULTS_PER_PAGE
|
||||
endIndex = startIndex + RESULTS_PER_PAGE
|
||||
|
||||
return output[startIndex:endIndex]
|
||||
@@ -0,0 +1,43 @@
|
||||
"""O365 tool utils."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from amadeus import Client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def authenticate() -> Client:
|
||||
"""Authenticate using the Amadeus API"""
|
||||
try:
|
||||
from amadeus import Client
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Cannot import amadeus. Please install the package with "
|
||||
"`pip install amadeus`."
|
||||
) from e
|
||||
|
||||
if "AMADEUS_CLIENT_ID" in os.environ and "AMADEUS_CLIENT_SECRET" in os.environ:
|
||||
client_id = os.environ["AMADEUS_CLIENT_ID"]
|
||||
client_secret = os.environ["AMADEUS_CLIENT_SECRET"]
|
||||
else:
|
||||
logger.error(
|
||||
"Error: The AMADEUS_CLIENT_ID and AMADEUS_CLIENT_SECRET environmental "
|
||||
"variables have not been set. Visit the following link on how to "
|
||||
"acquire these authorization tokens: "
|
||||
"https://developers.amadeus.com/register"
|
||||
)
|
||||
return None
|
||||
|
||||
hostname = "test" # Default hostname
|
||||
if "AMADEUS_HOSTNAME" in os.environ:
|
||||
hostname = os.environ["AMADEUS_HOSTNAME"]
|
||||
|
||||
client = Client(client_id=client_id, client_secret=client_secret, hostname=hostname)
|
||||
|
||||
return client
|
||||
Reference in New Issue
Block a user