"""Ordinance document content Validation logic
These are primarily used to validate that a legal document applies to a
particular technology (e.g. Large Wind Energy Conversion Systems).
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from compass.llm.calling import ChatLLMCaller, StructuredLLMCaller
from compass.validation.graphs import setup_graph_correct_document_type
from compass.common import setup_async_decision_tree, run_async_tree
from compass.utilities.enums import LLMUsageCategory
logger = logging.getLogger(__name__)
class ParseChunksWithMemory:
"""Check text chunks by sometimes looking at previous chunks
The idea behind this approach is that sometimes the context for a
setback or other ordinances is found in a previous chunk, so it may
be worthwhile (especially for validation purposes) to check a few
text chunks back for some validation pieces. In order to do this
semi-efficiently, we make use of a cache that's labeled "memory".
"""
def __init__(self, text_chunks, num_to_recall=2):
"""
Parameters
----------
text_chunks : list of str
List of strings, each of which represent a chunk of text.
The order of the strings should be the order of the text
chunks. This validator may refer to previous text chunks to
answer validation questions.
num_to_recall : int, optional
Number of chunks to check for each validation call. This
includes the original chunk! For example, if
`num_to_recall=2`, the validator will first check the chunk
at the requested index, and then the previous chunk as well.
By default, ``2``.
"""
self.text_chunks = text_chunks
self.num_to_recall = num_to_recall
self.memory = [{} for _ in text_chunks]
# fmt: off
def _inverted_mem(self, starting_ind):
"""Inverted memory"""
inverted_mem = self.memory[:starting_ind + 1:][::-1]
yield from inverted_mem[:self.num_to_recall]
# fmt: off
def _inverted_text(self, starting_ind):
"""Inverted text chunks"""
inverted_text = self.text_chunks[:starting_ind + 1:][::-1]
yield from inverted_text[:self.num_to_recall]
async def parse_from_ind(self, ind, key, llm_call_callback):
"""Validate a chunk of text
Validation occurs by querying the LLM using the input prompt and
parsing the `key` from the response JSON. The prompt should
request that the key be a boolean output. If the key retrieved
from the LLM response is False, a number of previous text chunks
are checked as well, using the same prompt. This can be helpful
in cases where the answer to the validation prompt (e.g. does
this text pertain to a large WECS?) is only found in a previous
text chunk.
Parameters
----------
ind : int
Positive integer corresponding to the chunk index.
Must be less than `len(text_chunks)`.
key : str
A key expected in the JSON output of the LLM containing the
response for the validation question. This string will also
be used to format the system prompt before it is passed to
the LLM.
llm_call_callback : callable
            Async callable that takes a `key` and `text_chunk` as inputs
            and returns a boolean indicating whether or not the text
            chunk passes the validation check.
Returns
-------
bool
            ``True`` if the LLM returned ``True`` for this text chunk or
            for any of the `num_to_recall - 1` text chunks before it;
            ``False`` otherwise.
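
        Examples
        --------
        A minimal sketch of the recall behavior. The chunk strings and
        the ``looks_like_wecs`` callback below are illustrative stand-ins
        for real ordinance text and a real LLM call:

        .. code-block:: python

            async def _demo():
                parser = ParseChunksWithMemory(
                    ["Large WECS are regulated herein.", "Setbacks: 500 ft."],
                    num_to_recall=2,
                )

                async def looks_like_wecs(key, text_chunk):
                    # stand-in for a real LLM call returning a boolean
                    return "WECS" in text_chunk

                # Chunk 1 fails on its own text, so chunk 0 is checked
                # as well and the call returns ``True``
                return await parser.parse_from_ind(1, "wecs", looks_like_wecs)

            asyncio.run(_demo())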
"""
logger.debug("Checking %r for ind %d", key, ind)
mem_text = zip(
self._inverted_mem(ind), self._inverted_text(ind), strict=False
)
for step, (mem, text) in enumerate(mem_text):
logger.debug("Mem at ind %d is %s", step, mem)
check = mem.get(key)
if check is None:
check = mem[key] = await llm_call_callback(key, text)
if check:
return check
return False
class Heuristic(ABC):
"""Perform a heuristic check for mention of a technology in text"""
_GOOD_ACRONYM_CONTEXTS = [
" {acronym} ",
" {acronym}\n",
" {acronym}.",
"\n{acronym} ",
"\n{acronym}.",
"\n{acronym}\n",
"({acronym} ",
" {acronym})",
]
def check(self, text, match_count_threshold=1):
"""Check for mention of a tech in text
This check first strips the text of any tech "look-alike" words
(e.g. "window", "windshield", etc for "wind" technology). Then,
it checks for particular keywords, acronyms, and phrases that
pertain to the tech in the text. If enough keywords are mentions
(as dictated by `match_count_threshold`), this check returns
``True``.
Parameters
----------
text : str
Input text that may or may not mention the technology of
interest.
match_count_threshold : int, optional
            Threshold for the heuristic check. The total number of
            keyword/acronym/phrase matches must be strictly greater than
            this value for the text to pass. By default, ``1``.
Returns
-------
bool
``True`` if the number of keywords/acronyms/phrases detected
exceeds the `match_count_threshold`.
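
        Examples
        --------
        A sketch using a hypothetical subclass; the keyword lists below
        are illustrative only (note that the acronyms are lowercase
        because the text is casefolded before matching):

        >>> class _WindHeuristic(Heuristic):
        ...     NOT_TECH_WORDS = ["window", "windshield"]
        ...     GOOD_TECH_KEYWORDS = ["wind energy", "turbine"]
        ...     GOOD_TECH_ACRONYMS = ["wecs"]
        ...     GOOD_TECH_PHRASES = ["wind energy conversion system"]
        >>> _WindHeuristic().check(
        ...     "Each wind energy conversion system turbine shall comply",
        ...     match_count_threshold=1,
        ... )
        True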
"""
heuristics_text = self._convert_to_heuristics_text(text)
total_keyword_matches = self._count_single_keyword_matches(
heuristics_text
)
total_keyword_matches += self._count_acronym_matches(heuristics_text)
total_keyword_matches += self._count_phrase_matches(heuristics_text)
return total_keyword_matches > match_count_threshold
def _convert_to_heuristics_text(self, text):
"""Convert text for heuristic content parsing"""
heuristics_text = text.casefold()
for word in self.NOT_TECH_WORDS:
heuristics_text = heuristics_text.replace(word, "")
return heuristics_text
def _count_single_keyword_matches(self, heuristics_text):
"""Count number of good tech keywords that appear in text"""
return sum(
keyword in heuristics_text for keyword in self.GOOD_TECH_KEYWORDS
)
def _count_acronym_matches(self, heuristics_text):
"""Count number of good tech acronyms that appear in text"""
acronym_matches = 0
for context in self._GOOD_ACRONYM_CONTEXTS:
acronym_keywords = {
context.format(acronym=acronym)
for acronym in self.GOOD_TECH_ACRONYMS
}
acronym_matches = sum(
keyword in heuristics_text for keyword in acronym_keywords
)
if acronym_matches > 0:
break
return acronym_matches
def _count_phrase_matches(self, heuristics_text):
"""Count number of good tech phrases that appear in text"""
return sum(
all(keyword in heuristics_text for keyword in phrase.split(" "))
for phrase in self.GOOD_TECH_PHRASES
)
@property
@abstractmethod
def NOT_TECH_WORDS(self): # noqa: N802
"""iter: Iterable of words that don't pertain to the tech"""
raise NotImplementedError
@property
@abstractmethod
def GOOD_TECH_KEYWORDS(self): # noqa: N802
"""iter: Iterable of keywords that pertain to the tech"""
raise NotImplementedError
@property
@abstractmethod
def GOOD_TECH_ACRONYMS(self): # noqa: N802
"""iter: Iterable of acronyms that pertain to the tech"""
raise NotImplementedError
@property
@abstractmethod
def GOOD_TECH_PHRASES(self): # noqa: N802
"""iter: Iterable of phrases that pertain to the tech"""
raise NotImplementedError
class LegalTextValidator(StructuredLLMCaller):
"""Parse chunks to determine if they contain legal text"""
SYSTEM_MESSAGE = (
"You are an AI designed to classify text excerpts based on their "
"source type. The goal is to identify text that is extracted from "
"**legally binding regulations (such as zoning ordinances or "
"enforceable bans)** and filter out text that was extracted from "
"anything other than a legal statute for an existing jurisdiction."
)
def __init__(
self, *args, score_threshold=0.8, doc_is_from_ocr=False, **kwargs
):
"""
Parameters
----------
        score_threshold : float, optional
            Minimum fraction of text chunks that have to pass the legal
            check for the whole document to be considered legal text.
            By default, ``0.8``.
        doc_is_from_ocr : bool, optional
            Flag indicating whether the document text was extracted
            using OCR. This option is passed along to the document-type
            validation decision tree. By default, ``False``.
        *args, **kwargs
            Parameters to pass to the
            :class:`~compass.llm.calling.StructuredLLMCaller`
            initializer.
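
        Examples
        --------
        A rough sketch of how the score threshold is applied. The
        ``llm_service`` argument is a placeholder for a real LLM service
        configuration, and ``_legal_text_mem`` is filled by hand here
        purely for illustration (it is normally populated by
        :meth:`check_chunk`):

        .. code-block:: python

            validator = LegalTextValidator(llm_service=llm_service)
            validator._legal_text_mem = [True, True, True, False]
            # 3 of 4 chunks passed -> score of 0.75 < 0.8, so the
            # document is not considered legal text
            assert not validator.is_legal_text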
"""
super().__init__(*args, **kwargs)
self.score_threshold = score_threshold
self._legal_text_mem = []
self.doc_is_from_ocr = doc_is_from_ocr
@property
def is_legal_text(self):
"""bool: ``True`` if text was found to be from a legal source"""
if not self._legal_text_mem:
return False
score = sum(self._legal_text_mem) / len(self._legal_text_mem)
return score >= self.score_threshold
async def check_chunk(self, chunk_parser, ind):
"""Check a chunk at a given ind to see if it contains legal text
Parameters
----------
chunk_parser : ParseChunksWithMemory
Instance of `ParseChunksWithMemory` that contains a
`parse_from_ind` method.
ind : int
Index of the chunk to check.
Returns
-------
bool
Boolean flag indicating whether or not the text in the chunk
resembles legal text.
"""
is_legal_text = await chunk_parser.parse_from_ind(
ind,
key="legal_text",
llm_call_callback=self._check_chunk_for_legal_text,
)
self._legal_text_mem.append(is_legal_text)
if is_legal_text:
logger.debug("Text at ind %d is legal text", ind)
else:
logger.debug("Text at ind %d is not legal text", ind)
return is_legal_text
async def _check_chunk_for_legal_text(self, key, text_chunk):
"""Call LLM on a chunk of text to check for legal text"""
chat_llm_caller = ChatLLMCaller(
llm_service=self.llm_service,
system_message=self.SYSTEM_MESSAGE.format(key=key),
usage_tracker=self.usage_tracker,
**self.kwargs,
)
tree = setup_async_decision_tree(
setup_graph_correct_document_type,
usage_sub_label=LLMUsageCategory.DOCUMENT_CONTENT_VALIDATION,
key=key,
text=text_chunk,
chat_llm_caller=chat_llm_caller,
doc_is_from_ocr=self.doc_is_from_ocr,
)
out = await run_async_tree(tree, response_as_json=True)
logger.debug("LLM response: %s", str(out))
return out.get(key, False)
async def parse_by_chunks(
chunk_parser,
heuristic,
legal_text_validator,
callbacks=None,
min_chunks_to_process=3,
):
"""Parse text by chunks, passing to callbacks if it's legal text
    This function goes through the chunks one by one, passing them to
    the callback parsers if the heuristic and legal-text checks pass.
    The first `min_chunks_to_process` chunks are always checked for
    legal text; if the document as a whole does not pass the legal text
    check after that point, parsing is aborted.
Parameters
----------
chunk_parser : ParseChunksWithMemory
Instance of `ParseChunksWithMemory` that contains the attributes
`text_chunks` and `num_to_recall`. The chunks in the
`text_chunks` attribute will be iterated over.
heuristic : Heuristic
Instance of `Heuristic` with a `check` method. This should be a
fast check meant to quickly dispose of chunks of text. Any chunk
that fails this check will NOT be passed to the callback
parsers.
legal_text_validator : LegalTextValidator
Instance of `LegalTextValidator` that can be used to validate
each chunk for legal text.
callbacks : list, optional
List of async callbacks that take a `chunk_parser` and `index`
as inputs and return a boolean determining whether the text
chunk was parsed successfully or not. By default, ``None``,
which does not use any callbacks.
min_chunks_to_process : int, optional
Minimum number of chunks to process before aborting due to text
not being legal. By default, ``3``.
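
    Examples
    --------
    A rough sketch of how the pieces fit together. ``WindHeuristic``,
    ``llm_service``, ``text_chunks``, and ``parse_setbacks`` are
    placeholders for a concrete :class:`Heuristic` subclass, a configured
    LLM service, the document's text chunks, and a user-supplied
    callback; none of them are defined in this module:

    .. code-block:: python

        async def parse_setbacks(chunk_parser, ind):
            # user-defined callback; return True if parsing succeeded
            ...

        async def run(text_chunks):
            chunk_parser = ParseChunksWithMemory(text_chunks, num_to_recall=2)
            validator = LegalTextValidator(llm_service=llm_service)
            await parse_by_chunks(
                chunk_parser,
                WindHeuristic(),
                validator,
                callbacks=[parse_setbacks],
                min_chunks_to_process=3,
            )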
"""
passed_heuristic_mem = []
callbacks = callbacks or []
outer_task_name = asyncio.current_task().get_name()
for ind, text in enumerate(chunk_parser.text_chunks):
passed_heuristic_mem.append(heuristic.check(text))
if ind < min_chunks_to_process:
is_legal = await legal_text_validator.check_chunk(
chunk_parser, ind
)
if not is_legal: # don't bother checking this chunk
continue
# don't bother checking this document
elif not legal_text_validator.is_legal_text:
return
# hasn't passed heuristic, so don't pass it to callbacks
elif not any(passed_heuristic_mem[-chunk_parser.num_to_recall :]):
continue
logger.debug("Processing text at ind %d", ind)
logger.debug_to_file("Text:\n%s", text)
if not callbacks:
continue
cb_futures = [
asyncio.create_task(cb(chunk_parser, ind), name=outer_task_name)
for cb in callbacks
]
cb_results = await asyncio.gather(*cb_futures)
# mask this chunk if we got a good result - this avoids forcing
# the following chunk to be checked (it will only be checked if
# it itself passes the heuristic)
passed_heuristic_mem[-1] = not any(cb_results)