initial commit
This commit is contained in:
768
venv/Lib/site-packages/langgraph/pregel/_runner.py
Normal file
768
venv/Lib/site-packages/langgraph/pregel/_runner.py
Normal file
@@ -0,0 +1,768 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from collections.abc import (
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
|
||||
from langgraph._internal._constants import (
|
||||
CONF,
|
||||
CONFIG_KEY_CALL,
|
||||
CONFIG_KEY_SCRATCHPAD,
|
||||
ERROR,
|
||||
INTERRUPT,
|
||||
NO_WRITES,
|
||||
RESUME,
|
||||
RETURN,
|
||||
)
|
||||
from langgraph._internal._future import chain_future, run_coroutine_threadsafe
|
||||
from langgraph._internal._scratchpad import PregelScratchpad
|
||||
from langgraph._internal._typing import MISSING
|
||||
from langgraph.constants import TAG_HIDDEN
|
||||
from langgraph.errors import GraphBubbleUp, GraphInterrupt
|
||||
from langgraph.pregel._algo import Call
|
||||
from langgraph.pregel._executor import Submit
|
||||
from langgraph.pregel._retry import arun_with_retry, run_with_retry
|
||||
from langgraph.types import (
|
||||
CachePolicy,
|
||||
PregelExecutableTask,
|
||||
RetryPolicy,
|
||||
)
|
||||
|
||||
F = TypeVar("F", concurrent.futures.Future, asyncio.Future)
|
||||
E = TypeVar("E", threading.Event, asyncio.Event)
|
||||
|
||||
# List of filenames to exclude from exception traceback
|
||||
# Note: Frames will be removed if they are the last frame in traceback, recursively
|
||||
EXCLUDED_FRAME_FNAMES = (
|
||||
"langgraph/pregel/retry.py",
|
||||
"langgraph/pregel/runner.py",
|
||||
"langgraph/pregel/executor.py",
|
||||
"langgraph/utils/runnable.py",
|
||||
"langchain_core/runnables/config.py",
|
||||
"concurrent/futures/thread.py",
|
||||
"concurrent/futures/_base.py",
|
||||
)
|
||||
|
||||
SKIP_RERAISE_SET: weakref.WeakSet[concurrent.futures.Future | asyncio.Future] = (
|
||||
weakref.WeakSet()
|
||||
)
|
||||
|
||||
|
||||
class FuturesDict(Generic[F, E], dict[F, PregelExecutableTask | None]):
|
||||
event: E
|
||||
callback: weakref.ref[Callable[[PregelExecutableTask, BaseException | None], None]]
|
||||
counter: int
|
||||
done: set[F]
|
||||
lock: threading.Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event: E,
|
||||
callback: weakref.ref[
|
||||
Callable[[PregelExecutableTask, BaseException | None], None]
|
||||
],
|
||||
future_type: type[F],
|
||||
# used for generic typing, newer py supports FutureDict[...](...)
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lock = threading.Lock()
|
||||
self.event = event
|
||||
self.callback = callback
|
||||
self.counter = 0
|
||||
self.done: set[F] = set()
|
||||
|
||||
def __setitem__(
|
||||
self,
|
||||
key: F,
|
||||
value: PregelExecutableTask | None,
|
||||
) -> None:
|
||||
super().__setitem__(key, value) # type: ignore[index]
|
||||
if value is not None:
|
||||
with self.lock:
|
||||
self.event.clear()
|
||||
self.counter += 1
|
||||
key.add_done_callback(partial(self.on_done, value))
|
||||
|
||||
def on_done(
|
||||
self,
|
||||
task: PregelExecutableTask,
|
||||
fut: F,
|
||||
) -> None:
|
||||
try:
|
||||
if cb := self.callback():
|
||||
cb(task, _exception(fut))
|
||||
finally:
|
||||
with self.lock:
|
||||
self.done.add(fut)
|
||||
self.counter -= 1
|
||||
if self.counter == 0 or _should_stop_others(self.done):
|
||||
self.event.set()
|
||||
|
||||
|
||||
class PregelRunner:
|
||||
"""Responsible for executing a set of Pregel tasks concurrently, committing
|
||||
their writes, yielding control to caller when there is output to emit, and
|
||||
interrupting other tasks if appropriate."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
submit: weakref.ref[Submit],
|
||||
put_writes: weakref.ref[Callable[[str, Sequence[tuple[str, Any]]], None]],
|
||||
use_astream: bool = False,
|
||||
node_finished: Callable[[str], None] | None = None,
|
||||
) -> None:
|
||||
self.submit = submit
|
||||
self.put_writes = put_writes
|
||||
self.use_astream = use_astream
|
||||
self.node_finished = node_finished
|
||||
|
||||
def tick(
|
||||
self,
|
||||
tasks: Iterable[PregelExecutableTask],
|
||||
*,
|
||||
reraise: bool = True,
|
||||
timeout: float | None = None,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
get_waiter: Callable[[], concurrent.futures.Future[None]] | None = None,
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None],
|
||||
PregelExecutableTask | None,
|
||||
],
|
||||
) -> Iterator[None]:
|
||||
tasks = tuple(tasks)
|
||||
futures = FuturesDict(
|
||||
callback=weakref.WeakMethod(self.commit),
|
||||
event=threading.Event(),
|
||||
future_type=concurrent.futures.Future,
|
||||
)
|
||||
# give control back to the caller
|
||||
yield
|
||||
# fast path if single task with no timeout and no waiter
|
||||
if len(tasks) == 0:
|
||||
return
|
||||
elif len(tasks) == 1 and timeout is None and get_waiter is None:
|
||||
t = tasks[0]
|
||||
try:
|
||||
run_with_retry(
|
||||
t,
|
||||
retry_policy,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_call,
|
||||
weakref.ref(t),
|
||||
retry_policy=retry_policy,
|
||||
futures=weakref.ref(futures),
|
||||
schedule_task=schedule_task,
|
||||
submit=self.submit,
|
||||
),
|
||||
},
|
||||
)
|
||||
self.commit(t, None)
|
||||
except Exception as exc:
|
||||
self.commit(t, exc)
|
||||
if reraise and futures:
|
||||
# will be re-raised after futures are done
|
||||
fut: concurrent.futures.Future = concurrent.futures.Future()
|
||||
fut.set_exception(exc)
|
||||
futures.done.add(fut)
|
||||
elif reraise:
|
||||
if tb := exc.__traceback__:
|
||||
while tb.tb_next is not None and any(
|
||||
tb.tb_frame.f_code.co_filename.endswith(name)
|
||||
for name in EXCLUDED_FRAME_FNAMES
|
||||
):
|
||||
tb = tb.tb_next
|
||||
exc.__traceback__ = tb
|
||||
raise
|
||||
if not futures: # maybe `t` scheduled another task
|
||||
return
|
||||
else:
|
||||
tasks = () # don't reschedule this task
|
||||
# add waiter task if requested
|
||||
if get_waiter is not None:
|
||||
futures[get_waiter()] = None
|
||||
# schedule tasks
|
||||
for t in tasks:
|
||||
fut = self.submit()( # type: ignore[misc]
|
||||
run_with_retry,
|
||||
t,
|
||||
retry_policy,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_call,
|
||||
weakref.ref(t),
|
||||
retry_policy=retry_policy,
|
||||
futures=weakref.ref(futures),
|
||||
schedule_task=schedule_task,
|
||||
submit=self.submit,
|
||||
),
|
||||
},
|
||||
__reraise_on_exit__=reraise,
|
||||
)
|
||||
futures[fut] = t
|
||||
# execute tasks, and wait for one to fail or all to finish.
|
||||
# each task is independent from all other concurrent tasks
|
||||
# yield updates/debug output as each task finishes
|
||||
end_time = timeout + time.monotonic() if timeout else None
|
||||
while len(futures) > (1 if get_waiter is not None else 0):
|
||||
done, inflight = concurrent.futures.wait(
|
||||
futures,
|
||||
return_when=concurrent.futures.FIRST_COMPLETED,
|
||||
timeout=(max(0, end_time - time.monotonic()) if end_time else None),
|
||||
)
|
||||
if not done:
|
||||
break # timed out
|
||||
for fut in done:
|
||||
task = futures.pop(fut)
|
||||
if task is None:
|
||||
# waiter task finished, schedule another
|
||||
if inflight and get_waiter is not None:
|
||||
futures[get_waiter()] = None
|
||||
else:
|
||||
# remove references to loop vars
|
||||
del fut, task
|
||||
# maybe stop other tasks
|
||||
if _should_stop_others(done):
|
||||
break
|
||||
# give control back to the caller
|
||||
yield
|
||||
# wait for done callbacks
|
||||
futures.event.wait(
|
||||
timeout=(max(0, end_time - time.monotonic()) if end_time else None)
|
||||
)
|
||||
# give control back to the caller
|
||||
yield
|
||||
# panic on failure or timeout
|
||||
try:
|
||||
_panic_or_proceed(
|
||||
futures.done.union(f for f, t in futures.items() if t is not None),
|
||||
panic=reraise,
|
||||
)
|
||||
except Exception as exc:
|
||||
if tb := exc.__traceback__:
|
||||
while tb.tb_next is not None and any(
|
||||
tb.tb_frame.f_code.co_filename.endswith(name)
|
||||
for name in EXCLUDED_FRAME_FNAMES
|
||||
):
|
||||
tb = tb.tb_next
|
||||
exc.__traceback__ = tb
|
||||
raise
|
||||
|
||||
async def atick(
|
||||
self,
|
||||
tasks: Iterable[PregelExecutableTask],
|
||||
*,
|
||||
reraise: bool = True,
|
||||
timeout: float | None = None,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
get_waiter: Callable[[], asyncio.Future[None]] | None = None,
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None],
|
||||
Awaitable[PregelExecutableTask | None],
|
||||
],
|
||||
) -> AsyncIterator[None]:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
tasks = tuple(tasks)
|
||||
futures = FuturesDict(
|
||||
callback=weakref.WeakMethod(self.commit),
|
||||
event=asyncio.Event(),
|
||||
future_type=asyncio.Future,
|
||||
)
|
||||
# give control back to the caller
|
||||
yield
|
||||
# fast path if single task with no waiter and no timeout
|
||||
if len(tasks) == 0:
|
||||
return
|
||||
elif len(tasks) == 1 and get_waiter is None and timeout is None:
|
||||
t = tasks[0]
|
||||
try:
|
||||
await arun_with_retry(
|
||||
t,
|
||||
retry_policy,
|
||||
stream=self.use_astream,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_acall,
|
||||
weakref.ref(t),
|
||||
stream=self.use_astream,
|
||||
retry_policy=retry_policy,
|
||||
futures=weakref.ref(futures),
|
||||
schedule_task=schedule_task,
|
||||
submit=self.submit,
|
||||
loop=loop,
|
||||
),
|
||||
},
|
||||
)
|
||||
self.commit(t, None)
|
||||
except Exception as exc:
|
||||
self.commit(t, exc)
|
||||
if reraise and futures:
|
||||
# will be re-raised after futures are done
|
||||
fut: asyncio.Future = loop.create_future()
|
||||
fut.set_exception(exc)
|
||||
futures.done.add(fut)
|
||||
elif reraise:
|
||||
if tb := exc.__traceback__:
|
||||
while tb.tb_next is not None and any(
|
||||
tb.tb_frame.f_code.co_filename.endswith(name)
|
||||
for name in EXCLUDED_FRAME_FNAMES
|
||||
):
|
||||
tb = tb.tb_next
|
||||
exc.__traceback__ = tb
|
||||
raise
|
||||
if not futures: # maybe `t` scheduled another task
|
||||
return
|
||||
else:
|
||||
tasks = () # don't reschedule this task
|
||||
# add waiter task if requested
|
||||
if get_waiter is not None:
|
||||
futures[get_waiter()] = None
|
||||
# schedule tasks
|
||||
for t in tasks:
|
||||
fut = cast(
|
||||
asyncio.Future,
|
||||
self.submit()( # type: ignore[misc]
|
||||
arun_with_retry,
|
||||
t,
|
||||
retry_policy,
|
||||
stream=self.use_astream,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_acall,
|
||||
weakref.ref(t),
|
||||
retry_policy=retry_policy,
|
||||
stream=self.use_astream,
|
||||
futures=weakref.ref(futures),
|
||||
schedule_task=schedule_task,
|
||||
submit=self.submit,
|
||||
loop=loop,
|
||||
),
|
||||
},
|
||||
__name__=t.name,
|
||||
__cancel_on_exit__=True,
|
||||
__reraise_on_exit__=reraise,
|
||||
),
|
||||
)
|
||||
futures[fut] = t
|
||||
# execute tasks, and wait for one to fail or all to finish.
|
||||
# each task is independent from all other concurrent tasks
|
||||
# yield updates/debug output as each task finishes
|
||||
end_time = timeout + loop.time() if timeout else None
|
||||
while len(futures) > (1 if get_waiter is not None else 0):
|
||||
done, inflight = await asyncio.wait(
|
||||
futures,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
timeout=(max(0, end_time - loop.time()) if end_time else None),
|
||||
)
|
||||
if not done:
|
||||
break # timed out
|
||||
for fut in done:
|
||||
task = futures.pop(fut)
|
||||
if task is None:
|
||||
# waiter task finished, schedule another
|
||||
if inflight and get_waiter is not None:
|
||||
futures[get_waiter()] = None
|
||||
else:
|
||||
# remove references to loop vars
|
||||
del fut, task
|
||||
# maybe stop other tasks
|
||||
if _should_stop_others(done):
|
||||
break
|
||||
# give control back to the caller
|
||||
yield
|
||||
# wait for done callbacks
|
||||
await asyncio.wait_for(
|
||||
futures.event.wait(),
|
||||
timeout=(max(0, end_time - loop.time()) if end_time else None),
|
||||
)
|
||||
# give control back to the caller
|
||||
yield
|
||||
# cancel waiter task
|
||||
for fut in futures:
|
||||
fut.cancel()
|
||||
# panic on failure or timeout
|
||||
try:
|
||||
_panic_or_proceed(
|
||||
futures.done.union(f for f, t in futures.items() if t is not None),
|
||||
timeout_exc_cls=asyncio.TimeoutError,
|
||||
panic=reraise,
|
||||
)
|
||||
except Exception as exc:
|
||||
if tb := exc.__traceback__:
|
||||
while tb.tb_next is not None and any(
|
||||
tb.tb_frame.f_code.co_filename.endswith(name)
|
||||
for name in EXCLUDED_FRAME_FNAMES
|
||||
):
|
||||
tb = tb.tb_next
|
||||
exc.__traceback__ = tb
|
||||
raise
|
||||
|
||||
def commit(
|
||||
self,
|
||||
task: PregelExecutableTask,
|
||||
exception: BaseException | None,
|
||||
) -> None:
|
||||
if isinstance(exception, asyncio.CancelledError):
|
||||
# for cancelled tasks, also save error in task,
|
||||
# so loop can finish super-step
|
||||
task.writes.append((ERROR, exception))
|
||||
self.put_writes()(task.id, task.writes) # type: ignore[misc]
|
||||
elif exception:
|
||||
if isinstance(exception, GraphInterrupt):
|
||||
# save interrupt to checkpointer
|
||||
if exception.args[0]:
|
||||
writes = [(INTERRUPT, exception.args[0])]
|
||||
if resumes := [w for w in task.writes if w[0] == RESUME]:
|
||||
writes.extend(resumes)
|
||||
self.put_writes()(task.id, writes) # type: ignore[misc]
|
||||
elif isinstance(exception, GraphBubbleUp):
|
||||
# exception will be raised in _panic_or_proceed
|
||||
pass
|
||||
else:
|
||||
# save error to checkpointer
|
||||
task.writes.append((ERROR, exception))
|
||||
self.put_writes()(task.id, task.writes) # type: ignore[misc]
|
||||
else:
|
||||
if self.node_finished and (
|
||||
task.config is None or TAG_HIDDEN not in task.config.get("tags", [])
|
||||
):
|
||||
self.node_finished(task.name)
|
||||
if not task.writes:
|
||||
# add no writes marker
|
||||
task.writes.append((NO_WRITES, None))
|
||||
# save task writes to checkpointer
|
||||
self.put_writes()(task.id, task.writes) # type: ignore[misc]
|
||||
|
||||
|
||||
def _should_stop_others(
|
||||
done: set[F],
|
||||
) -> bool:
|
||||
"""Check if any task failed, if so, cancel all other tasks.
|
||||
GraphInterrupts are not considered failures."""
|
||||
for fut in done:
|
||||
if fut.cancelled():
|
||||
continue
|
||||
elif exc := fut.exception():
|
||||
if not isinstance(exc, GraphBubbleUp) and fut not in SKIP_RERAISE_SET:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _exception(
|
||||
fut: concurrent.futures.Future[Any] | asyncio.Future[Any],
|
||||
) -> BaseException | None:
|
||||
"""Return the exception from a future, without raising CancelledError."""
|
||||
if fut.cancelled():
|
||||
if isinstance(fut, asyncio.Future):
|
||||
return asyncio.CancelledError()
|
||||
else:
|
||||
return concurrent.futures.CancelledError()
|
||||
else:
|
||||
return fut.exception()
|
||||
|
||||
|
||||
def _panic_or_proceed(
|
||||
futs: set[concurrent.futures.Future] | set[asyncio.Future],
|
||||
*,
|
||||
timeout_exc_cls: type[Exception] = TimeoutError,
|
||||
panic: bool = True,
|
||||
) -> None:
|
||||
"""Cancel remaining tasks if any failed, re-raise exception if panic is True."""
|
||||
done: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
|
||||
inflight: set[concurrent.futures.Future[Any] | asyncio.Future[Any]] = set()
|
||||
for fut in futs:
|
||||
if fut.cancelled():
|
||||
continue
|
||||
elif fut.done():
|
||||
done.add(fut)
|
||||
else:
|
||||
inflight.add(fut)
|
||||
interrupts: list[GraphInterrupt] = []
|
||||
while done:
|
||||
# if any task failed
|
||||
fut = done.pop()
|
||||
if exc := _exception(fut):
|
||||
# cancel all pending tasks
|
||||
while inflight:
|
||||
inflight.pop().cancel()
|
||||
# raise the exception
|
||||
if panic:
|
||||
if isinstance(exc, GraphInterrupt):
|
||||
# collect interrupts
|
||||
interrupts.append(exc)
|
||||
elif fut not in SKIP_RERAISE_SET:
|
||||
raise exc
|
||||
# raise combined interrupts
|
||||
if interrupts:
|
||||
raise GraphInterrupt(tuple(i for exc in interrupts for i in exc.args[0]))
|
||||
if inflight:
|
||||
# if we got here means we timed out
|
||||
while inflight:
|
||||
# cancel all pending tasks
|
||||
inflight.pop().cancel()
|
||||
# raise timeout error
|
||||
raise timeout_exc_cls("Timed out")
|
||||
|
||||
|
||||
def _call(
|
||||
task: weakref.ref[PregelExecutableTask],
|
||||
func: Callable[[Any], Awaitable[Any] | Any],
|
||||
input: Any,
|
||||
*,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
callbacks: Callbacks = None,
|
||||
futures: weakref.ref[FuturesDict],
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None], PregelExecutableTask | None
|
||||
],
|
||||
submit: weakref.ref[Submit],
|
||||
) -> concurrent.futures.Future[Any]:
|
||||
if inspect.iscoroutinefunction(func):
|
||||
raise RuntimeError("In an sync context async tasks cannot be called")
|
||||
|
||||
fut: concurrent.futures.Future | None = None
|
||||
# schedule PUSH tasks, collect futures
|
||||
scratchpad: PregelScratchpad = task().config[CONF][CONFIG_KEY_SCRATCHPAD] # type: ignore[union-attr]
|
||||
# schedule the next task, if the callback returns one
|
||||
if next_task := schedule_task(
|
||||
task(), # type: ignore[arg-type]
|
||||
scratchpad.call_counter(),
|
||||
Call(
|
||||
func,
|
||||
input,
|
||||
retry_policy=retry_policy,
|
||||
cache_policy=cache_policy,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
):
|
||||
if fut := next(
|
||||
(
|
||||
f
|
||||
for f, t in list(futures().items()) # type: ignore[union-attr]
|
||||
if t is not None and t == next_task.id
|
||||
),
|
||||
None,
|
||||
):
|
||||
# if the parent task was retried,
|
||||
# the next task might already be running
|
||||
pass
|
||||
elif next_task.writes:
|
||||
# if it already ran, return the result
|
||||
fut = concurrent.futures.Future()
|
||||
ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
|
||||
if ret is not MISSING:
|
||||
fut.set_result(ret)
|
||||
elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
|
||||
fut.set_exception(
|
||||
exc if isinstance(exc, BaseException) else Exception(exc)
|
||||
)
|
||||
else:
|
||||
fut.set_result(None)
|
||||
else:
|
||||
# schedule the next task
|
||||
fut = submit()( # type: ignore[misc]
|
||||
run_with_retry,
|
||||
next_task,
|
||||
retry_policy,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_call,
|
||||
weakref.ref(next_task),
|
||||
futures=futures,
|
||||
retry_policy=retry_policy,
|
||||
callbacks=callbacks,
|
||||
schedule_task=schedule_task,
|
||||
submit=submit,
|
||||
),
|
||||
},
|
||||
__reraise_on_exit__=False,
|
||||
# starting a new task in the next tick ensures
|
||||
# updates from this tick are committed/streamed first
|
||||
__next_tick__=True,
|
||||
)
|
||||
# exceptions for call() tasks are raised into the parent task
|
||||
# so we should not re-raise at the end of the tick
|
||||
SKIP_RERAISE_SET.add(fut)
|
||||
futures()[fut] = next_task # type: ignore[index]
|
||||
fut = cast(asyncio.Future | concurrent.futures.Future, fut)
|
||||
# return a chained future to ensure commit() callback is called
|
||||
# before the returned future is resolved, to ensure stream order etc
|
||||
return chain_future(fut, concurrent.futures.Future())
|
||||
|
||||
|
||||
def _acall(
|
||||
task: weakref.ref[PregelExecutableTask],
|
||||
func: Callable[[Any], Awaitable[Any] | Any],
|
||||
input: Any,
|
||||
*,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
callbacks: Callbacks = None,
|
||||
# injected dependencies
|
||||
futures: weakref.ref[FuturesDict],
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None],
|
||||
Awaitable[PregelExecutableTask | None],
|
||||
],
|
||||
submit: weakref.ref[Submit],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
stream: bool = False,
|
||||
) -> asyncio.Future[Any] | concurrent.futures.Future[Any]:
|
||||
# return a chained future to ensure commit() callback is called
|
||||
# before the returned future is resolved, to ensure stream order etc
|
||||
try:
|
||||
in_async = asyncio.current_task() is not None
|
||||
except RuntimeError:
|
||||
in_async = False
|
||||
# if in async context return an async future, otherwise return a sync future
|
||||
if in_async:
|
||||
fut: asyncio.Future[Any] | concurrent.futures.Future[Any] = asyncio.Future(
|
||||
loop=loop
|
||||
)
|
||||
else:
|
||||
fut = concurrent.futures.Future()
|
||||
# schedule the next task
|
||||
run_coroutine_threadsafe(
|
||||
_acall_impl(
|
||||
fut,
|
||||
task,
|
||||
func,
|
||||
input,
|
||||
retry_policy=retry_policy,
|
||||
cache_policy=cache_policy,
|
||||
callbacks=callbacks,
|
||||
futures=futures,
|
||||
schedule_task=schedule_task,
|
||||
submit=submit,
|
||||
loop=loop,
|
||||
stream=stream,
|
||||
),
|
||||
loop,
|
||||
lazy=False,
|
||||
)
|
||||
return fut
|
||||
|
||||
|
||||
async def _acall_impl(
|
||||
destination: asyncio.Future[Any] | concurrent.futures.Future[Any],
|
||||
task: weakref.ref[PregelExecutableTask],
|
||||
func: Callable[[Any], Awaitable[Any] | Any],
|
||||
input: Any,
|
||||
*,
|
||||
retry_policy: Sequence[RetryPolicy] | None = None,
|
||||
cache_policy: CachePolicy | None = None,
|
||||
callbacks: Callbacks = None,
|
||||
# injected dependencies
|
||||
futures: weakref.ref[FuturesDict[asyncio.Future, asyncio.Event]],
|
||||
schedule_task: Callable[
|
||||
[PregelExecutableTask, int, Call | None],
|
||||
Awaitable[PregelExecutableTask | None],
|
||||
],
|
||||
submit: weakref.ref[Submit],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
stream: bool = False,
|
||||
) -> None:
|
||||
try:
|
||||
fut: asyncio.Future | None = None
|
||||
# schedule PUSH tasks, collect futures
|
||||
scratchpad: PregelScratchpad = task().config[CONF][CONFIG_KEY_SCRATCHPAD] # type: ignore[union-attr]
|
||||
# schedule the next task, if the callback returns one
|
||||
if next_task := await schedule_task(
|
||||
task(), # type: ignore[arg-type]
|
||||
scratchpad.call_counter(),
|
||||
Call(
|
||||
func,
|
||||
input,
|
||||
retry_policy=retry_policy,
|
||||
cache_policy=cache_policy,
|
||||
callbacks=callbacks,
|
||||
),
|
||||
):
|
||||
if fut := next(
|
||||
(
|
||||
f
|
||||
for f, t in list(futures().items()) # type: ignore[union-attr]
|
||||
if t is not None and t == next_task.id
|
||||
),
|
||||
None,
|
||||
):
|
||||
# if the parent task was retried,
|
||||
# the next task might already be running
|
||||
pass
|
||||
elif next_task.writes:
|
||||
# if it already ran, return the result
|
||||
fut = asyncio.Future(loop=loop)
|
||||
ret = next((v for c, v in next_task.writes if c == RETURN), MISSING)
|
||||
if ret is not MISSING:
|
||||
fut.set_result(ret)
|
||||
elif exc := next((v for c, v in next_task.writes if c == ERROR), None):
|
||||
fut.set_exception(
|
||||
exc if isinstance(exc, BaseException) else Exception(exc)
|
||||
)
|
||||
else:
|
||||
fut.set_result(None)
|
||||
else:
|
||||
# schedule the next task
|
||||
fut = cast(
|
||||
asyncio.Future,
|
||||
submit()( # type: ignore[misc]
|
||||
arun_with_retry,
|
||||
next_task,
|
||||
retry_policy,
|
||||
stream=stream,
|
||||
configurable={
|
||||
CONFIG_KEY_CALL: partial(
|
||||
_acall,
|
||||
weakref.ref(next_task),
|
||||
stream=stream,
|
||||
futures=futures,
|
||||
schedule_task=schedule_task,
|
||||
submit=submit,
|
||||
loop=loop,
|
||||
),
|
||||
},
|
||||
__name__=next_task.name,
|
||||
__cancel_on_exit__=True,
|
||||
__reraise_on_exit__=False,
|
||||
# starting a new task in the next tick ensures
|
||||
# updates from this tick are committed/streamed first
|
||||
__next_tick__=True,
|
||||
),
|
||||
)
|
||||
# exceptions for call() tasks are raised into the parent task
|
||||
# so we should not re-raise at the end of the tick
|
||||
SKIP_RERAISE_SET.add(fut)
|
||||
futures()[fut] = next_task # type: ignore[index]
|
||||
if fut is not None:
|
||||
chain_future(fut, destination)
|
||||
else:
|
||||
destination.set_exception(RuntimeError("Task not scheduled"))
|
||||
except Exception as exc:
|
||||
destination.set_exception(exc)
|
||||
Reference in New Issue
Block a user