Source code for compass.services.provider
"""COMPASS service provider classes"""
import asyncio
import logging
import contextlib
from compass.services.queues import (
initialize_service_queue,
get_service_queue,
tear_down_service_queue,
)
from compass.exceptions import COMPASSValueError
logger = logging.getLogger(__name__)
class _RunningProvider:
"""A running provider for a single service"""
def __init__(self, service, queue):
"""
Parameters
----------
service : :class:`compass.services.base.Service`
An instance of a single async service to run.
queue : :class:`asyncio.Queue`
Queue object for the running service.
"""
self.service = service
self.queue = queue
self.jobs = set()
async def run(self):
"""Run the service."""
while True:
await self.submit_jobs()
await self.collect_responses()
async def submit_jobs(self):
"""Submit jobs from the queue to processing
The service can limit the number of submissions at a time by
implementing the ``can_process`` property.
If the queue is non-empty, the function takes jobs from it
iteratively and submits them until the ``can_process`` property
of the service returns ``False``. A call to ``can_process`` is
submitted between every job pulled from the queue, so enure that
method is performant. If the queue is empty, this function does
one of two things:
1) If there are no jobs processing, it waits on the queue
to get more jobs and submits them as they come in
(assuming the service allows it)
2) If there are jobs processing, this function returns
without waiting on more jobs from the queue.
"""
if not self.service.can_process or self._q_empty_but_still_processing:
return
while self.service.can_process and self._can_fit_jobs:
fut, outer_task_name, args, kwargs = await self.queue.get()
task = asyncio.create_task(
self.service.process_using_futures(fut, *args, **kwargs),
name=outer_task_name,
)
self.queue.task_done()
self.jobs.add(task)
await _allow_service_to_update()
return
@property
def _q_empty_but_still_processing(self):
"""bool: Queue empty but jobs still running (don't await)"""
return self.queue.empty() and self.jobs
@property
def _can_fit_jobs(self):
"""bool: Job tracker not full"""
return len(self.jobs) < self.service.MAX_CONCURRENT_JOBS
async def collect_responses(self):
"""Collect responses from the service.
This call will block further submissions to the service until
at least one job finishes.
"""
if not self.jobs:
return
complete, __ = await asyncio.wait(
self.jobs, return_when=asyncio.FIRST_COMPLETED
)
for job in complete:
self.jobs.remove(job)
[docs]
class RunningAsyncServices:
"""Async context manager for running services."""
def __init__(self, services):
"""
Parameters
----------
services : iterable
An iterable of async services to run during program
execution.
"""
self.services = services
self.__providers = []
self._validate_services()
def _validate_services(self):
"""Validate input services."""
if len(self.services) < 1:
msg = "Must provide at least one service to run!"
raise COMPASSValueError(msg)
def _reset_providers(self):
"""Reset running providers"""
for c in self.__providers:
c.cancel()
self.__providers = []
async def __aenter__(self):
for service in self.services:
logger.debug("Initializing Service: %s", service.name)
with contextlib.suppress(AttributeError):
logger.debug(
" ↪ model_name=%r, rate_limit=%d",
service.model_name,
service.rate_limit,
)
queue = initialize_service_queue(service.name)
service.acquire_resources()
task = asyncio.create_task(_RunningProvider(service, queue).run())
self.__providers.append(task)
async def __aexit__(self, exc_type, exc, tb):
try:
for service in self.services:
await get_service_queue(service.name).join()
service.release_resources()
finally:
self._reset_providers()
for service in self.services:
logger.debug("Tearing down Service: %s", service.name)
tear_down_service_queue(service.name)
[docs]
@classmethod
def run(cls, services, coroutine):
"""Run an async function that relies on services.
You can treat this function like the ``asyncio.run`` function
with an extra parameter::
openai_service = OpenAIService(...)
RunningAsyncServices.run(
[openai_service], my_async_func(*args, **kwargs)
)
Parameters
----------
services : iterable of :class:`compass.services.base.Service`
An iterable (i.e. a list) of Services that are needed to run
the asynchronous function.
coroutine : coroutine
A coroutine that should be run with the services.
Returns
-------
Any
Returns the output of the coroutine.
"""
return asyncio.run(cls._run_coroutine(services, coroutine))
@classmethod
async def _run_coroutine(cls, services, coroutine):
"""Run a coroutine under services."""
async with cls(services):
return await coroutine
async def _allow_service_to_update():
"""Switch contexts, allowing service to update if it can process"""
await asyncio.sleep(0)