"""COMPASS Ordinance Location Validation logic
These are primarily used to validate that a legal document applies to a
particular location.
"""
import asyncio
import logging
from elm.web.file_loader import AsyncFileLoader
from compass.llm.calling import BaseLLMCaller, ChatLLMCaller, LLMCaller
from compass.common import setup_async_decision_tree, run_async_tree
from compass.validation.graphs import (
setup_graph_correct_jurisdiction_type,
setup_graph_correct_jurisdiction_from_url,
)
from compass.utilities.enums import LLMUsageCategory
# Module-level logger used by the validators and helper functions below.
logger = logging.getLogger(__name__)
class DTreeURLJurisdictionValidator(BaseLLMCaller):
    """Validator that checks whether a URL matches a jurisdiction

    Only the URL string itself is examined; the page it points to is not
    fetched.
    """

    SYSTEM_MESSAGE = (
        "You are an expert data analyst that examines URLs to determine if "
        "they contain information about jurisdictions. Only ever answer "
        "based on the information in the URL itself."
    )

    def __init__(self, jurisdiction, **kwargs):
        """

        Parameters
        ----------
        jurisdiction : Jurisdiction
            Jurisdiction that URLs are validated against. It is forwarded
            unchanged to the decision-tree graph setup function
            (``setup_graph_correct_jurisdiction_from_url``).
        **kwargs
            Additional keyword arguments to pass to the
            :class:`~compass.llm.calling.BaseLLMCaller` instance.
        """
        super().__init__(**kwargs)
        self.jurisdiction = jurisdiction

    async def check(self, url):
        """Check if the URL matches the jurisdiction

        Parameters
        ----------
        url : str
            URL to validate. A falsy value (``None`` or ``""``) fails the
            check immediately, without any LLM call.

        Returns
        -------
        bool
            ``True`` if the decision tree produced at least one answer
            and all of its answers are truthy, ``False`` otherwise.
        """
        if not url:
            return False

        chat_llm_caller = ChatLLMCaller(
            llm_service=self.llm_service,
            system_message=self.SYSTEM_MESSAGE,
            usage_tracker=self.usage_tracker,
            **self.kwargs,
        )
        tree = setup_async_decision_tree(
            setup_graph_correct_jurisdiction_from_url,
            usage_sub_label=LLMUsageCategory.URL_JURISDICTION_VALIDATION,
            jurisdiction=self.jurisdiction,
            url=url,
            chat_llm_caller=chat_llm_caller,
        )
        out = await run_async_tree(tree, response_as_json=True)
        return self._parse_output(out)

    def _parse_output(self, props):  # noqa: PLR6301
        """Parse LLM response and return boolean validation result

        ``props`` is the JSON-parsed decision-tree output, assumed to be
        a mapping of answer names to values (TODO confirm the schema
        against the graph definition). The URL passes only if the
        mapping is non-empty and every value is truthy.
        """
        logger.debug(
            "Parsing URL jurisdiction validation output:\n\t%s", props
        )
        return len(props) > 0 and all(props.values())
class DTreeJurisdictionValidator(BaseLLMCaller):
    """Jurisdiction Validation using a decision tree

    Asks an LLM, through a decision tree, whether a piece of text (for
    example one of a document's raw pages) applies to the given
    jurisdiction.
    """

    # Key under which the page-vote score is stored in ``doc.attrs`` by
    # ``_validator_check_for_doc``.
    META_SCORE_KEY = "Jurisdiction Validation Score"
    SYSTEM_MESSAGE = (
        "You are a legal expert assisting a user with determining the scope "
        "of applicability for their legal ordinance documents."
    )

    def __init__(self, jurisdiction, **kwargs):
        """

        Parameters
        ----------
        jurisdiction : Jurisdiction
            Jurisdiction that text is validated against. It is forwarded
            unchanged to the decision-tree graph setup function
            (``setup_graph_correct_jurisdiction_type``).
        **kwargs
            Additional keyword arguments to pass to the
            :class:`~compass.llm.calling.BaseLLMCaller` instance.
        """
        super().__init__(**kwargs)
        self.jurisdiction = jurisdiction

    async def check(self, content):
        """Check if the content passes the validation

        Parameters
        ----------
        content : str
            Document content to validate.

        Returns
        -------
        bool or None
            ``False`` (without an LLM call) if `content` is empty.
            Otherwise, the value of the ``"correct_jurisdiction"`` key
            in the decision-tree output, or ``None`` if that key is
            missing. A ``None`` result is treated as an abstention when
            page votes are weighted.
        """
        if not content:
            return False

        chat_llm_caller = ChatLLMCaller(
            llm_service=self.llm_service,
            system_message=self.SYSTEM_MESSAGE,
            usage_tracker=self.usage_tracker,
            **self.kwargs,
        )
        tree = setup_async_decision_tree(
            setup_graph_correct_jurisdiction_type,
            usage_sub_label=LLMUsageCategory.DOCUMENT_JURISDICTION_VALIDATION,
            jurisdiction=self.jurisdiction,
            text=content,
            chat_llm_caller=chat_llm_caller,
        )
        out = await run_async_tree(tree, response_as_json=True)
        return self._parse_output(out)

    def _parse_output(self, props):  # noqa: PLR6301
        """Parse LLM response and return the validation verdict

        Returns ``props["correct_jurisdiction"]`` if present, otherwise
        ``None``. ``props`` is assumed to be a mapping (TODO confirm the
        schema against the graph definition).
        """
        logger.debug(
            "Parsing county jurisdiction validation output:\n\t%s", props
        )
        return props.get("correct_jurisdiction")
class JurisdictionValidator:
    """COMPASS Ordinance Jurisdiction validator

    Combines the logic of several validators into a single class.

    Purpose:
        Determine whether a document pertains to a specific jurisdiction
        (e.g. a county).
    Responsibilities:
        1. Check the document's source URL and, if that is not
           conclusive, the document's content (page by page) to
           determine whether or not it pertains to a particular
           jurisdiction.
    Key Relationships:
        Delegates sub-validation to
        :class:`DTreeURLJurisdictionValidator` (source URL) and
        :class:`DTreeJurisdictionValidator` (document content).
    """

    def __init__(self, score_thresh=0.8, text_splitter=None, **kwargs):
        """

        Parameters
        ----------
        score_thresh : float, optional
            Score threshold to exceed when voting on content from raw
            pages. By default, ``0.8``.
        text_splitter : langchain.text_splitter.TextSplitter, optional
            Optional text splitter instance to attach to doc (used for
            splitting out pages in an HTML document).
            By default, ``None``.
        **kwargs
            Additional keyword arguments to pass to the
            :class:`~compass.llm.calling.BaseLLMCaller` instance.
        """
        self.score_thresh = score_thresh
        self.text_splitter = text_splitter
        self.kwargs = kwargs

    async def check(self, doc, jurisdiction):
        """Check if the document belongs to the jurisdiction

        Parameters
        ----------
        doc : :class:`elm.web.document.BaseDocument`
            Document instance. Should contain a "source" key in the
            ``attrs`` that contains a URL (used for the URL validation
            check). Raw content will be parsed for county name and
            correct jurisdiction.
        jurisdiction : Jurisdiction
            Jurisdiction to validate the document against. Forwarded to
            the URL and content validators.

        Returns
        -------
        bool
            `True` if the doc contents pertain to the input jurisdiction.
            `False` otherwise.
        """
        if not hasattr(doc, "text_splitter") or self.text_splitter is None:
            return await self._check(doc, jurisdiction)

        # Swap in this validator's splitter only for the duration of the
        # check, and always restore the caller's splitter, even if
        # validation raises or is cancelled.
        old_splitter = doc.text_splitter
        doc.text_splitter = self.text_splitter
        try:
            return await self._check(doc, jurisdiction)
        finally:
            doc.text_splitter = old_splitter

    async def _check(self, doc, jurisdiction):
        """Check if the document belongs to the jurisdiction

        Any text-splitter swap is handled by :meth:`check`. This method
        deliberately does not touch ``doc.text_splitter``.
        """
        url = doc.attrs.get("source")
        if url:
            logger.debug("Checking URL (%s) for jurisdiction name...", url)
            url_validator = DTreeURLJurisdictionValidator(
                jurisdiction, **self.kwargs
            )
            url_is_correct_jurisdiction = await url_validator.check(url)
            if url_is_correct_jurisdiction:
                # A matching URL is sufficient; skip content validation.
                return True

        logger.info("Validating document from source: %s", url or "Unknown")
        logger.debug("Checking for correct jurisdiction...")
        jurisdiction_validator = DTreeJurisdictionValidator(
            jurisdiction, **self.kwargs
        )
        return await _validator_check_for_doc(
            validator=jurisdiction_validator,
            doc=doc,
            score_thresh=self.score_thresh,
        )
class JurisdictionWebsiteValidator:
    """COMPASS Ordinance Jurisdiction Website validator

    Decides whether a URL is the main website of a jurisdiction. The URL
    text is checked first. If that check does not pass, the page is
    fetched and its text is shown to an LLM.
    """

    WEB_PAGE_CHECK_SYSTEM_MESSAGE = (
        "You are an expert data analyst that examines website text to "
        "determine if the website is the main website for a given "
        "jurisdiction. Only ever answer based on the information from the "
        "website itself."
    )

    def __init__(
        self, browser_semaphore=None, file_loader_kwargs=None, **kwargs
    ):
        """

        Parameters
        ----------
        browser_semaphore : :class:`asyncio.Semaphore`, optional
            Semaphore instance that can be used to limit the number of
            playwright browsers open concurrently. If ``None``, no
            limits are applied. By default, ``None``.
        file_loader_kwargs : dict, optional
            Dictionary of keyword arguments pairs to initialize
            :class:`elm.web.file_loader.AsyncFileLoader`.
            By default, ``None``.
        **kwargs
            Additional keyword arguments to pass to the
            :class:`~compass.llm.calling.BaseLLMCaller` instance.
        """
        self.browser_semaphore = browser_semaphore
        self.file_loader_kwargs = file_loader_kwargs or {}
        self.kwargs = kwargs

    async def check(self, url, jurisdiction):
        """Check if the website is the main website for a jurisdiction

        Parameters
        ----------
        url : str
            URL of the website to validate.
        jurisdiction : Jurisdiction
            Jurisdiction to validate against. Its ``type`` and
            ``full_name_the_prefixed`` attributes are used in the LLM
            prompt.

        Returns
        -------
        bool
            ``True`` if the website is the main website for the given
            jurisdiction; ``False`` otherwise (including when the page
            cannot be fetched, is empty, or the LLM returns no answer).
        """
        url_validator = DTreeURLJurisdictionValidator(
            jurisdiction, **self.kwargs
        )
        url_is_correct_jurisdiction = await url_validator.check(url)
        if url_is_correct_jurisdiction:
            return True

        fl = AsyncFileLoader(
            browser_semaphore=self.browser_semaphore,
            **self.file_loader_kwargs,
        )
        try:
            doc = await fl.fetch(url)
        except KeyboardInterrupt:
            raise
        except Exception as e:
            # Best-effort: a page that cannot be loaded is simply not
            # validated, but the failure is logged with its traceback.
            msg = "Encountered error of type %r while trying to validate %s"
            err_type = type(e)
            logger.exception(msg, err_type, url)
            return False

        if doc.empty:
            return False

        prompt = (
            "Based on the website text below, is it reasonable to conclude "
            f"that this webpage is the **main** {jurisdiction.type} website "
            f"for {jurisdiction.full_name_the_prefixed}? "
            "Please start your response with either 'Yes' or 'No' and briefly "
            "explain your answer."
            f'\n\n"""\n{doc.text}\n"""'
        )
        llm_caller = LLMCaller(**self.kwargs)
        out = await llm_caller.call(
            sys_msg=self.WEB_PAGE_CHECK_SYSTEM_MESSAGE,
            content=prompt,
            usage_sub_label=(
                LLMUsageCategory.JURISDICTION_MAIN_WEBSITE_VALIDATION
            ),
        )
        # The LLM call may hand back ``None``/empty content (e.g. a
        # failed or filtered request -- TODO confirm against the service
        # implementation). Treat that as "not validated" instead of
        # raising ``AttributeError``.
        if not out:
            return False
        # Tolerate leading whitespace and markdown bold (``**Yes**``),
        # which the bold-formatted prompt tends to elicit.
        return out.lstrip().lstrip("*").casefold().startswith("yes")
async def _validator_check_for_doc(
    validator, doc, score_thresh=0.9, **kwargs
):
    """Score a document by running a validator over each of its raw pages.

    Every entry of ``doc.raw_pages`` is checked concurrently with
    ``validator.check(page, **kwargs)``. The per-page verdicts are
    combined by ``_weighted_vote``, the resulting score is written to
    ``doc.attrs[validator.META_SCORE_KEY]``, and the document passes if
    the score reaches ``score_thresh``.

    Parameters
    ----------
    validator : object
        Validator exposing an async ``check`` method and a
        ``META_SCORE_KEY`` attribute (e.g. ``DTreeJurisdictionValidator``).
    doc : :class:`elm.web.document.BaseDocument`
        Document to score. Its ``attrs`` dict is updated in place.
    score_thresh : float, optional
        Minimum score (inclusive) required to pass. By default, ``0.9``.
    **kwargs
        Extra keyword arguments forwarded to every ``validator.check``
        call.

    Returns
    -------
    bool
        ``True`` if the weighted score is ``>= score_thresh``.
    """
    # Page tasks reuse the caller's task name, presumably so that any
    # task-name-keyed bookkeeping attributes them to the caller (TODO
    # confirm).
    parent_name = asyncio.current_task().get_name()
    page_tasks = []
    for page in doc.raw_pages:
        page_tasks.append(
            asyncio.create_task(
                validator.check(page, **kwargs), name=parent_name
            )
        )
    verdicts = await asyncio.gather(*page_tasks)

    score = _weighted_vote(verdicts, doc)
    doc.attrs[validator.META_SCORE_KEY] = score
    passed = score >= score_thresh
    logger.debug(
        "%s is %.2f for doc from source %s (Pass: %s; threshold: %.2f)",
        validator.META_SCORE_KEY,
        score,
        doc.attrs.get("source", "Unknown"),
        str(passed),
        score_thresh,
    )
    return passed
def _weighted_vote(out, doc):
"""Compute weighted average of responses based on text length"""
if not doc.raw_pages:
return 0
total = weights = 0
for verdict, text in zip(out, doc.raw_pages, strict=True):
if verdict is None:
continue
weight = len(text)
logger.debug("Weight=%d, Verdict=%d", weight, int(verdict))
weights += weight
total += verdict * weight
weights = max(weights, 1)
return total / weights