# Source code for compass.common.tree
"""Ordinance async decision tree"""
import networkx as nx
import logging
from elm.tree import DecisionTree
from compass.utilities.enums import LLMUsageCategory
from compass.exceptions import COMPASSRuntimeError
logger = logging.getLogger(__name__)
# [docs]
class AsyncDecisionTree(DecisionTree):
    """Async class to traverse a directed graph of LLM prompts

    Nodes of this tree are prompts, and edges are transitions between
    prompts based on conditions being met in the LLM response

    Purpose:
        Represent a series of prompts that can be used in sequence to
        extract values of interest from text.
    Responsibilities:
        1. Store all prompts used to extract a particular ordinance
           value from text.
        2. Track relationships between the prompts (i.e. which prompt
           is used first, which prompt is used next depending on the
           output of the previous prompt, etc.) using a directed acyclic
           graph.
    Key Relationships:
        Inherits from :class:`~elm.tree.DecisionTree` to add ``async``
        capabilities. Uses a
        :class:`~compass.llm.calling.ChatLLMCaller` for LLM queries.
    """

    def __init__(self, graph, usage_sub_label=None):
        """

        Parameters
        ----------
        graph : nx.DiGraph
            Directed acyclic graph where nodes are LLM prompts and edges
            are logical transitions based on the response. Must have
            high-level graph attribute "chat_llm_caller" which is a
            ChatLLMCaller instance. Nodes should have attribute "prompt"
            which can have {format} named arguments that will be filled
            from the high-level graph attributes. Edges can have
            attribute "condition" that is a callable to be executed on
            the LLM response text. An edge from a node without a
            condition acts as an "else" statement if no other edge
            conditions are satisfied. A single edge from node to node
            does not need a condition.
        usage_sub_label : str, optional
            Optional label to classify LLM usage under when running this
            decision tree. If ``None``, will simply label calls made
            from this tree under "decision_tree". By default, ``None``.

        Raises
        ------
        AssertionError
            If `graph` is not an ``nx.DiGraph`` or does not carry the
            "chat_llm_caller" graph attribute.
        """
        # NOTE(review): the parent ``__init__`` is not invoked here;
        # presumably it validates different graph attributes - confirm
        # against ``elm.tree.DecisionTree`` before changing this.
        self._g = graph
        self._history = []
        self.usage_sub_label = (
            usage_sub_label or LLMUsageCategory.DECISION_TREE
        )

        # ``self.graph`` is not defined in this class; it is expected to
        # be supplied by the parent class. The exception type is kept as
        # ``AssertionError`` for backward compatibility (note that these
        # checks are skipped when Python runs with ``-O``).
        assert isinstance(self.graph, nx.DiGraph), (
            "`graph` must be an instance of `nx.DiGraph`"
        )
        assert "chat_llm_caller" in self.graph.graph, (
            'Graph is missing the required "chat_llm_caller" attribute'
        )

    @property
    def chat_llm_caller(self):
        """ChatLLMCaller: ChatLLMCaller instance for this tree"""
        return self.graph.graph["chat_llm_caller"]

    @property
    def messages(self):
        """Get a list of the conversation messages with the LLM

        Returns
        -------
        list
        """
        return self.chat_llm_caller.messages

    @property
    def all_messages_txt(self):
        """Get a printout of the full conversation with the LLM

        Each message is rendered as ``"ROLE: content"`` and messages are
        separated by a blank line.

        Returns
        -------
        str
        """
        return "\n\n".join(
            f"{msg['role'].upper()}: {msg['content']}"
            for msg in self.messages
        )

    def _log_conversation_to_debug(self):
        """Write the full conversation printout to the debug log"""
        logger.debug_to_file(
            "Error traversing trees, here's the full "
            "conversation printout:\n%s",
            self.all_messages_txt,
        )

    def _last_message_content(self):
        """Get content of the most recent message, if there is one

        Returns
        -------
        object
            The "content" value of the last conversation message, or a
            placeholder string if no such message can be retrieved.
        """
        try:
            return self.messages[-1]["content"]
        except (IndexError, KeyError, TypeError):
            # Tree traversal can fail before any message was exchanged;
            # error reporting must not raise a second, unrelated error
            # that would mask the original one.
            return "<no message available>"

    async def async_call_node(self, node0):
        """Call the LLM

        The chat will start with the prompt from the input node and will
        search the successor edges for a valid transition condition.

        Parameters
        ----------
        node0 : str
            Name of node being executed.

        Returns
        -------
        out : str
            Next node or LLM response if at a leaf node.
        """
        prompt = self._prepare_graph_call(node0)
        out = await self.chat_llm_caller.call(
            prompt, usage_sub_label=self.usage_sub_label
        )
        logger.debug_to_file(
            "Chat GPT prompt:\n%s\nChat GPT response:\n%s", prompt, out
        )
        # A falsy response (e.g. ``None``) is parsed as an empty string
        return self._parse_graph_output(node0, out or "")

    async def async_run(self, node0="init"):
        """Traverse the decision tree starting at the input node

        Parameters
        ----------
        node0 : str
            Name of starting node in the graph. This is typically called
            "init".

        Returns
        -------
        out : str | None
            Final response from LLM at the leaf node or ``None`` if an
            ``AttributeError`` was raised during execution.

        Raises
        ------
        COMPASSRuntimeError
            If any exception other than ``AttributeError`` is raised
            while executing a node. The original exception is chained
            as the cause.
        """
        self._history = []

        while True:
            try:
                out = await self.async_call_node(node0)
            except AttributeError:
                # Treated as a "no answer" outcome rather than a failure
                self._log_conversation_to_debug()
                return None
            except Exception as e:
                self._log_conversation_to_debug()
                last_message = self._last_message_content()
                msg = (
                    "Ran into an exception when traversing tree. "
                    "Last message from LLM is printed below. "
                    "See debug logs for more detail. "
                    "\nLast message: \n"
                    '"""\n%s\n"""'
                )
                logger.exception(msg, last_message)
                # 1-tuple keeps ``%`` formatting correct even when the
                # message content is itself a tuple
                raise COMPASSRuntimeError(msg % (last_message,)) from e

            # Output that is not a node name is the final answer
            if out not in self.graph:
                break
            node0 = out

        logger.info("Final decision tree output: %s", out)
        return out