Module buildstock_query.query_core
Expand source code
import boto3
import contextlib
import pathlib
from pyathena.connection import Connection
from pyathena.error import OperationalError
from pyathena.sqlalchemy.base import AthenaDialect
import sqlalchemy as sa
from pyathena.pandas.async_cursor import AsyncPandasCursor
from pyathena.pandas.cursor import PandasCursor
import os
from typing import Union, Optional, Literal, Sequence
import typing
import time
import logging
from threading import Thread
from botocore.exceptions import ClientError
import pandas as pd
import datetime
import numpy as np
from collections import OrderedDict
import types
from buildstock_query.helpers import CachedFutureDf, AthenaFutureDf, DataExistsException, CustomCompiler
from buildstock_query.helpers import save_pickle, load_pickle, read_csv
from typing import TypedDict, NewType
from botocore.config import Config
import urllib3
from buildstock_query.schema.run_params import RunParams
from buildstock_query.db_schema.db_schema_model import DBSchema
from buildstock_query.schema.utilities import DBColType, AnyColType, AnyTableType, SALabel
from pydantic import validate_arguments
import hashlib
import toml
urllib3.disable_warnings()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
FUELS = ['electricity', 'natural_gas', 'propane', 'fuel_oil', 'coal', 'wood_cord', 'wood_pellets']
class QueryException(Exception):
pass
ExeId = NewType('ExeId', str)
class BatchQueryStatusMap(TypedDict):
to_submit_ids: list[int]
all_ids: list[int]
submitted_ids: list[int]
submitted_execution_ids: list[ExeId]
submitted_queries: list[str]
queries_futures: list[Union[CachedFutureDf, AthenaFutureDf]]
class BatchQueryReportMap(TypedDict):
submitted: int
running: int
pending: int
completed: int
failed: int
class QueryCore:
def __init__(self, *, params: RunParams
) -> None:
"""
Base class to run common Athena queries for BuildStock runs and download results as pandas DataFrames.
Usually, you should just use BuildStockQuery. This class is useful if you want to extend the functionality
for Athena tables that are not part of ResStock or ComStock runs.
Args:
workgroup (str): The workgroup for Athena. The cost will be charged to this workgroup.
db_name (str): The Athena database name
buildstock_type (str, optional): 'resstock' or 'comstock' runs. Defaults to 'resstock'
table_name (Union[str, tuple[str, Optional[str], Optional[str]]]): If a single string is provided,
say, 'mfm_run', then it must correspond to tables in Athena named mfm_run_baseline and optionally
mfm_run_timeseries and mfm_run_upgrades. Alternatively, a tuple of three elements can be provided for the
table names for baseline, timeseries and upgrade. Timeseries and upgrade can be None if no such tables exist.
db_schema (str, optional): The database structure in Athena is different between ResStock and ComStock runs.
It is also different between the version in OEDI and default version from BuildStockBatch. This argument
controls the assumed schema. Allowed values are 'resstock_default', 'resstock_oedi', 'comstock_default'
and 'comstock_oedi'. Defaults to 'resstock_default' for resstock and 'comstock_default' for comstock.
sample_weight (str, optional): Specify a custom sample_weight. Otherwise, the default is 1 for ComStock and
uses sample_weight in the run for ResStock.
region_name (str, optional): the AWS region where the database exists. Defaults to 'us-west-2'.
execution_history (str, optional): A temporary file to record which executions are run by the user,
to help stop them. Will use .execution_history if not supplied. Generally, not required to supply a
custom filename.
athena_query_reuse (bool, optional): When true, Athena will make use of its built-in 7 day query cache.
When false, it will not. Defaults to True. One use case to set this to False is when you have modified
the underlying s3 data or glue schema and want to make sure you are not using the cached results.
"""
logger.info(f"Loading {params.table_name} ...")
self.run_params = params
self.workgroup = params.workgroup
self.buildstock_type = params.buildstock_type
self._query_cache: dict[str, pd.DataFrame] = {} # {"query": query_result_df} to cache queries
self._session_queries: set[str] = set()  # Set of all queries that are run in the current session.
self._aws_s3 = boto3.client('s3')
self._aws_athena = boto3.client('athena', region_name=params.region_name)
self._aws_glue = boto3.client('glue', region_name=params.region_name)
self._conn = Connection(work_group=params.workgroup, region_name=params.region_name,
cursor_class=PandasCursor, schema_name=params.db_name,
config=Config(max_pool_connections=20))
self._async_conn = Connection(work_group=params.workgroup, region_name=params.region_name,
cursor_class=AsyncPandasCursor, schema_name=params.db_name,
config=Config(max_pool_connections=20))
self.db_name = params.db_name
self.region_name = params.region_name
self._tables: dict[str, sa.Table] = OrderedDict() # Internal record of tables
self._batch_query_status_map: dict[int, BatchQueryStatusMap] = {}
self._batch_query_id = 0
db_schema_file = os.path.join(os.path.dirname(__file__), 'db_schema',
f'{params.db_schema}.toml')
db_schema_dict = toml.load(db_schema_file)
self.db_schema = DBSchema.parse_obj(db_schema_dict)
self.db_col_name = self.db_schema.column_names
self.timestamp_column_name = self.db_col_name.timestamp
self.building_id_column_name = self.db_col_name.building_id
self.sample_weight = params.sample_weight_override if params.sample_weight_override is not None else \
self.db_col_name.sample_weight
self.table_name = params.table_name
self.cache_folder = pathlib.Path(params.cache_folder)
self.athena_query_reuse = params.athena_query_reuse
os.makedirs(self.cache_folder, exist_ok=True)
self._initialize_tables()
self._initialize_book_keeping(params.execution_history)
with contextlib.suppress(FileNotFoundError):
self.load_cache()
@staticmethod
def _get_compact_cache_name(table_name: str) -> str:
table_name = str(table_name)
if len(table_name) > 64:
return hashlib.sha256(table_name.encode()).hexdigest()
else:
return table_name
def _get_cache_file_path(self) -> pathlib.Path:
return self.cache_folder / f"{self._get_compact_cache_name(self.table_name)}_query_cache.pkl"
@validate_arguments
def load_cache(self, path: Optional[str] = None):
"""Read and update query cache from pickle file.
Args:
path (str, optional): The path to the pickle file. If not provided, reads from the default cache file in the cache folder.
"""
pickle_path = pathlib.Path(path) if path else self._get_cache_file_path()
before_count = len(self._query_cache)
saved_cache = load_pickle(pickle_path)
logger.info(f"{len(saved_cache)} queries cache read from {path}.")
self._query_cache.update(saved_cache)
self.last_saved_queries = set(saved_cache)
after_count = len(self._query_cache)
if diff := after_count - before_count:
logger.info(f"{diff} queries cache is updated.")
else:
logger.info("Cache already upto date.")
@validate_arguments
def save_cache(self, path: Optional[str] = None, trim_excess: bool = False):
"""Saves queries cache to a pickle file. It is good idea to run this afer making queries so that on the next
session these queries won't have to be run on Athena and can be directly loaded from the file.
Args:
path (str, optional): The path to the pickle file. If not provided, the file will be saved on the current
directory.
trim_excess (bool, optional): If true, any queries in the cache that is not run in current session will be
remved before saving it to file. This is useful if the cache has accumulated a bunch of stray queries over
several sessions that are no longer used. Defaults to False.
"""
cached_queries = set(self._query_cache)
if self.last_saved_queries == cached_queries:
logger.info("No new queries to save.")
return
pickle_path = pathlib.Path(path) if path else self._get_cache_file_path()
if trim_excess:
if excess_queries := [key for key in self._query_cache if key not in self._session_queries]:
for query in excess_queries:
del self._query_cache[query]
logger.info(f"{len(excess_queries)} excess queries removed from cache.")
self.last_saved_queries = cached_queries
save_pickle(pickle_path, self._query_cache)
logger.info(f"{len(self._query_cache)} queries cache saved to {pickle_path}")
def _initialize_tables(self):
self.bs_table, self.ts_table, self.up_table = self._get_tables(self.table_name)
self.bs_bldgid_column = self.bs_table.c[self.building_id_column_name]
if self.ts_table is not None:
self.timestamp_column = self.ts_table.c[self.timestamp_column_name]
self.ts_bldgid_column = self.ts_table.c[self.building_id_column_name]
if self.up_table is not None:
self.up_bldgid_column = self.up_table.c[self.building_id_column_name]
self.sample_wt = self._get_sample_weight(self.sample_weight)
def _get_sample_weight(self, sample_weight):
if not sample_weight:
return sa.literal(1)
elif isinstance(sample_weight, str):
try:
return self.bs_table.c[sample_weight]
except ValueError:
logger.error("Sample weight column not found. Using weight of 1.")
return sa.literal(1)
elif isinstance(sample_weight, (int, float)):
return sa.literal(sample_weight)
else:
raise ValueError("Invalid value for sample_weight")
@typing.overload
def _get_table(self, table_name: AnyTableType, missing_ok: Literal[True]) -> Optional[sa.Table]:
...
@typing.overload
def _get_table(self, table_name: AnyTableType, missing_ok: Literal[False] = False) -> sa.Table:
...
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def _get_table(self, table_name: AnyTableType, missing_ok: bool = False) -> Optional[sa.Table]:
if not isinstance(table_name, str):
return table_name # already a table
try:
return self._tables.setdefault(table_name, sa.Table(table_name, self._meta, autoload_with=self._engine))
except sa.exc.NoSuchTableError: # type: ignore
if missing_ok:
logger.warning(f"No {table_name} table is present.")
return None
else:
raise
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def _get_column(self, column_name: AnyColType,
table_name: Optional[AnyTableType] = None) -> DBColType:
if not isinstance(column_name, str):
return column_name # already a col
if table_name is not None:
valid_tables = [self._get_table(table_name)]
else:
valid_tables = []
for tbl in [self.bs_table, self.up_table, self.ts_table]:
if tbl is not None and column_name in tbl.columns:
valid_tables.append(tbl)
if not valid_tables:
valid_tables += [table for _, table in self._tables.items()
if column_name in table.columns]
if not valid_tables:
raise ValueError(f"Column {column_name} not found in any tables {[t.name for t in self._tables.values()]}")
if len(valid_tables) > 1:
logger.warning(
f"Column {column_name} found in multiple tables {[t.name for t in valid_tables]}."
f"Using {valid_tables[0].name}")
return valid_tables[0].c[column_name]
def _get_tables(self, table_name: Union[str, tuple[str, Optional[str], Optional[str]]]):
self._engine = self._create_athena_engine(region_name=self.region_name, database=self.db_name,
workgroup=self.workgroup)
self._meta = sa.MetaData(bind=self._engine)
if isinstance(table_name, str):
baseline_table = self._get_table(f'{table_name}{self.db_schema.table_suffix.baseline}')
ts_table = self._get_table(f'{table_name}{self.db_schema.table_suffix.timeseries}', missing_ok=True)
if self.db_schema.table_suffix.upgrades == self.db_schema.table_suffix.baseline:
upgrade_table = sa.select(baseline_table).where(
sa.cast(baseline_table.c['upgrade'], sa.String) != '0').alias('upgrade')
baseline_table = sa.select(baseline_table).where(
sa.cast(baseline_table.c['upgrade'], sa.String) == '0').alias('baseline')
else:
upgrade_table = self._get_table(f'{table_name}{self.db_schema.table_suffix.upgrades}', missing_ok=True)
else:
baseline_table = self._get_table(f'{table_name[0]}')
ts_table = self._get_table(f'{table_name[1]}', missing_ok=True) if table_name[1] else None
if table_name[2] == table_name[0]:
upgrade_table = sa.select(baseline_table).where(
sa.cast(baseline_table.c['upgrade'], sa.String) != '0').alias('upgrade')
baseline_table = sa.select(baseline_table).where(
sa.cast(baseline_table.c['upgrade'], sa.String) == '0').alias('baseline')
else:
upgrade_table = self._get_table(f'{table_name[2]}', missing_ok=True) if table_name[2] else None
return baseline_table, ts_table, upgrade_table
def _initialize_book_keeping(self, execution_history):
self._execution_history_file = execution_history or self.cache_folder / '.execution_history'
self.execution_cost = {'GB': 0, 'Dollars': 0} # Tracks the cost of current session. Only used for Athena query
self.seen_execution_ids = set() # set to prevent double counting same execution id
self.last_saved_queries = set()
if os.path.exists(self._execution_history_file):
with open(self._execution_history_file, 'r') as f:
existing_entries = f.readlines()
valid_entries = []
for entry in existing_entries:
with contextlib.suppress(ValueError, TypeError):
entry_time, _ = entry.split(',')
if time.time() - float(entry_time) < 24 * 60 * 60: # discard history if more than a day old
valid_entries.append(entry)
with open(self._execution_history_file, 'w') as f:
f.writelines(valid_entries)
@property
def _execution_ids_history(self):
exe_ids: list[ExeId] = []
if os.path.exists(self._execution_history_file):
with open(self._execution_history_file, 'r') as f:
for line in f:
_, exe_id = line.split(',')
exe_ids.append(ExeId(exe_id.strip()))
return exe_ids
def _create_athena_engine(self, region_name: str, database: str, workgroup: str) -> sa.engine.Engine:
connect_args = {"cursor_class": PandasCursor, "work_group": workgroup}
engine = sa.create_engine(
f"awsathena+rest://:@athena.{region_name}.amazonaws.com:443/{database}", connect_args=connect_args
)
return engine
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def delete_table(self, table_name: str):
"""
Function to delete an Athena table.
:param table_name: Athena table name
:return:
"""
delete_table_query = f"""DROP TABLE {self.db_name}.{table_name};"""
result, reason = self.execute_raw(delete_table_query)
if result.upper() == "SUCCEEDED":
return "SUCCEEDED"
else:
raise QueryException(f"Deleting it failed. Reason: {reason}")
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def add_table(self, table_name: str, table_df: pd.DataFrame,
s3_bucket: str, s3_prefix: str, override: bool = False):
"""
Function to add a table to Athena with its data stored in s3.
:param table_name: The name of the table
:param table_df: The pandas dataframe to use as table data
:param s3_bucket: s3 bucket name
:param s3_prefix: s3 prefix to save the table to.
:param override: Whether to override an existing table.
:return:
"""
s3_location = s3_bucket + '/' + s3_prefix
s3_data = self._aws_s3.list_objects(Bucket=s3_bucket, Prefix=f'{s3_prefix}/{table_name}')
if 'Contents' in s3_data and override is False:
raise DataExistsException("Table already exists", f's3://{s3_location}/{table_name}/{table_name}.csv')
if 'Contents' in s3_data:
existing_objects = [{'Key': el['Key']} for el in s3_data['Contents']]
print(f"The following existing objects is being delete and replaced: {existing_objects}")
print(f"Saving s3://{s3_location}/{table_name}/{table_name}.parquet)")
self._aws_s3.delete_objects(Bucket=s3_bucket, Delete={"Objects": existing_objects})
print(f"Saving factors to s3 in s3://{s3_location}/{table_name}/{table_name}.parquet")
# table_df.to_parquet(f's3://{s3_location}/{table_name}/{table_name}.parquet', index=False)
self._aws_s3.put_object(Body=table_df.to_parquet(index=False), Bucket=s3_bucket,
Key=f"{s3_prefix}/{table_name}/{table_name}.parquet")
print("Saving Done.")
format_list = []
for column_name, dtype in table_df.dtypes.items():
if np.issubdtype(dtype, np.integer):
col_type = "int"
elif np.issubdtype(dtype, np.floating):
col_type = "double"
elif np.issubdtype(dtype, np.datetime64):
col_type = "timestamp"
else:
col_type = "string"
format_list.append(f"`{column_name}` {col_type}")
column_formats = ",".join(format_list)
table_create_query = f"""
CREATE EXTERNAL TABLE {self.db_name}.{table_name} ({column_formats})
STORED AS PARQUET
LOCATION 's3://{s3_location}/{table_name}/'
TBLPROPERTIES ('has_encrypted_data'='false');
"""
print(f"Running create table query.\n {table_create_query}")
result, reason = self.execute_raw(table_create_query)
if result.lower() == "failed" and 'alreadyexists' in reason.lower():
if not override:
existing_data = read_csv(f's3://{s3_location}/{table_name}/{table_name}.csv')
raise DataExistsException("Table already exists", existing_data)
print(f"There was existing table {table_name} in Athena which was deleted and recreated.")
delete_table_query = f"""
DROP TABLE {self.db_name}.{table_name};
"""
result, reason = self.execute_raw(delete_table_query)
if result.upper() != "SUCCEEDED":
raise QueryException(f"There was an existing table named {table_name}. Deleting it failed."
f" Reason: {reason}")
result, reason = self.execute_raw(table_create_query)
if result.upper() == "SUCCEEDED":
return "SUCCEEDED"
else:
raise QueryException(f"There was an existing table named {table_name} which is now successfully "
f"deleted but new table failed to be created. Reason: {reason}")
elif result.upper() == "SUCCEEDED":
return "SUCCEEDED"
else:
raise QueryException(f"Failed to create the table. Reason: {reason}")
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def execute_raw(self, query, db: Optional[str] = None, run_async: bool = False):
"""
Directly executes the supplied query in Athena.
:param query: The query string to execute.
:param db: The Athena database to run the query against. Defaults to the instance's db_name.
:param run_async: If True, returns the query execution id immediately instead of waiting for completion.
:return: The execution id if run_async is True; otherwise a (state, reason) tuple for the completed query.
"""
if not db:
db = self.db_name
response = self._aws_athena.start_query_execution(
QueryString=query,
QueryExecutionContext={
'Database': db
},
WorkGroup=self.workgroup)
query_execution_id = ExeId(response['QueryExecutionId'])
if run_async:
return query_execution_id
start_time = time.time()
while time.time() - start_time < 30*60: # 30 minute timeout
query_stat = self._aws_athena.get_query_execution(QueryExecutionId=query_execution_id)
if query_stat['QueryExecution']['Status']['State'].lower() not in ['pending', 'running', 'queued']:
reason = query_stat['QueryExecution']['Status'].get('StateChangeReason', '')
return query_stat['QueryExecution']['Status']['State'], reason
time.sleep(1)
raise TimeoutError("Query failed to complete within 30 mins.")
def _save_execution_id(self, execution_id):
with open(self._execution_history_file, 'a') as f:
f.write(f'{time.time()},{execution_id}\n')
def _log_execution_cost(self, execution_id: ExeId):
if execution_id == "CACHED":
# Can't log cost for cached query
return
res = self._aws_athena.get_query_execution(QueryExecutionId=execution_id)
scanned_GB = res['QueryExecution']['Statistics']['DataScannedInBytes'] / 1e9
cost = scanned_GB * 5 / 1e3 # 5$ per TB scanned
if execution_id not in self.seen_execution_ids:
self.execution_cost['Dollars'] += cost
self.execution_cost['GB'] += scanned_GB
self.seen_execution_ids.add(execution_id)
logger.info(f"{execution_id} cost {scanned_GB:.1f} GB (${cost:.1f}). Session total:"
f" {self.execution_cost['GB']:.1f} GB (${self.execution_cost['Dollars']:.1f})")
def _compile(self, query) -> str:
compiled_query = CustomCompiler(AthenaDialect(), query).process(query, literal_binds=True)
return compiled_query
@typing.overload
def execute(self, query, *, run_async: Literal[False] = False) -> pd.DataFrame:
...
@typing.overload
def execute(self, query, *,
run_async: Literal[True],
) -> Union[tuple[Literal["CACHED"], CachedFutureDf], tuple[ExeId, AthenaFutureDf]]:
...
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def execute(self, query, run_async: bool = False) -> Union[pd.DataFrame, tuple[Literal["CACHED"], CachedFutureDf],
tuple[ExeId, AthenaFutureDf]]:
"""
Executes a query
Args:
query: The SQL query to run in Athena
run_async: Whether to wait until the query completes (run_async=False) or return immediately
(run_async=True).
Returns:
if run_async is False, returns the results dataframe.
if run_async is True, returns a (query_execution_id, future) tuple; the execution id is the string "CACHED" when the result is served from the in-memory query cache.
"""
if not isinstance(query, str):
query = self._compile(query)
self._session_queries.add(query)
if run_async:
if query in self._query_cache:
return "CACHED", CachedFutureDf(self._query_cache[query].copy())
# in case of asynchronous run, you get the execution id and futures object
exe_id, result_future = self._async_conn.cursor().execute(query,
result_reuse_enable=self.athena_query_reuse,
result_reuse_minutes=60 * 24 * 7,
na_values=['']) # type: ignore
exe_id = ExeId(exe_id)
def get_pandas(future):
res = future.result()
if res.state != 'SUCCEEDED':
raise OperationalError(f"{res.state}: {res.state_change_reason}")
if query in self._query_cache:
return self._query_cache[query]
return res.as_pandas()
result_future.as_pandas = types.MethodType(get_pandas, result_future)
result_future.add_done_callback(lambda f: self._query_cache.update({query: f.as_pandas()}))
self._save_execution_id(exe_id)
return exe_id, AthenaFutureDf(result_future)
else:
if query not in self._query_cache:
self._query_cache[query] = self._conn.cursor().execute(query,
result_reuse_enable=self.athena_query_reuse,
result_reuse_minutes=60 * 24 * 7,
).as_pandas()
return self._query_cache[query].copy()
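# A minimal sketch of both execution modes; `bsq` is a hypothetical initialized instance
# and the queries are illustrative (raw SQL strings or sqlalchemy selectables both work):
#   df = bsq.execute("SELECT COUNT(*) FROM my_run_baseline")      # blocking, returns a DataFrame
#   exe_id, future = bsq.execute(big_query, run_async=True)       # non-blocking
#   df = future.as_pandas()                                       # fetch the result when ready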
def print_all_batch_query_status(self) -> None:
"""Prints the status of all batch queries.
"""
for count in self._batch_query_status_map.keys():
print(f'Query {count}: {self.get_batch_query_report(count)}\n')
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def stop_batch_query(self, batch_id: int) -> None:
"""
Stops all the queries running under a batch query
Args:
batch_id: The batch_id of the batch_query. Returned by :py:submit_batch_query
Returns:
None
"""
if batch_id not in self._batch_query_status_map:
raise ValueError("Batch id not found")
self._batch_query_status_map[batch_id]['to_submit_ids'].clear()
for exec_id in self._batch_query_status_map[batch_id]['submitted_execution_ids']:
self.stop_query(exec_id)
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def get_failed_queries(self, batch_id: int) -> tuple[Sequence[ExeId], Sequence[str]]:
"""_summary_
Args:
batch_id (int): Batch query id returned by :py:sumbit_batch_query
Returns:
_type_: tuple of list of failed query execution ids and list of failed queries
"""
stats = self._batch_query_status_map.get(batch_id, None)
failed_query_ids: list[ExeId] = []
failed_queries: list[str] = []
if stats:
for i, exe_id in enumerate(stats['submitted_execution_ids']):
completion_stat = self.get_query_status(exe_id)
if completion_stat in ['FAILED', 'CANCELLED']:
failed_query_ids.append(exe_id)
failed_queries.append(stats['submitted_queries'][i])
return failed_query_ids, failed_queries
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def print_failed_query_errors(self, batch_id: int) -> None:
"""Print the error messages for all queries that failed in batch query.
Args:
batch_id (int): Batch query id
"""
failed_ids, failed_queries = self.get_failed_queries(batch_id)
for exe_id, query in zip(failed_ids, failed_queries):
print(f"Query id: {exe_id}. \n Query string: {query}. Query Ended with: {self.get_query_status(exe_id)}"
f"\nError: {self.get_query_error(exe_id)}\n")
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def get_ids_for_failed_queries(self, batch_id: int) -> Sequence[str]:
"""Returns the list of execution ids for failed queries in batch query.
Args:
batch_id (int): batch query id
Returns:
Sequence[str]: List of failed execution ids.
"""
failed_ids = []
for i, exe_id in enumerate(self._batch_query_status_map[batch_id]['submitted_execution_ids']):
completion_stat = self.get_query_status(exe_id)
if completion_stat in ['FAILED', 'CANCELLED']:
failed_ids.append(exe_id)
return failed_ids
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def get_batch_query_report(self, batch_id: int) -> BatchQueryReportMap:
"""
Returns the status of the queries running under a batch query.
Args:
batch_id: The batch_id of the batch_query.
Returns:
A dictionary detailing status of the queries.
"""
if not (stats := self._batch_query_status_map.get(batch_id, None)):
raise ValueError(f"{batch_id=} not found.")
success_count = 0
fail_count = 0
running_count = 0
other = 0
for exe_id in stats['submitted_execution_ids']:
if exe_id == 'CACHED':
completion_stat = "SUCCEEDED"
else:
completion_stat = self.get_query_status(exe_id)
if completion_stat == 'RUNNING':
running_count += 1
elif completion_stat == 'SUCCEEDED':
success_count += 1
elif completion_stat in ['FAILED', 'CANCELLED']:
fail_count += 1
else:
# for example: QUEUED
other += 1
result: BatchQueryReportMap = {'submitted': len(stats['submitted_ids']),
'running': running_count,
'pending': len(stats['to_submit_ids']) + other,
'completed': success_count,
'failed': fail_count
}
return result
@validate_arguments
def did_batch_query_complete(self, batch_id: int):
"""
Checks whether all the queries in a batch query have completed or not.
Args:
batch_id: The batch_id for the batch_query
Returns:
True or False
"""
status = self.get_batch_query_report(batch_id)
if status['pending'] > 0 or status['running'] > 0:
return False
else:
return True
@validate_arguments
def wait_for_batch_query(self, batch_id: int):
"""Waits until batch query completes.
Args:
batch_id (int): The batch query id.
"""
sleep_time = 0.5 # start here and keep doubling until max_sleep_time
max_sleep_time = 20
last_time = time.time()
last_report = None
while True:
report = self.get_batch_query_report(batch_id)
if time.time() - last_time > 60 or last_report is None or report != last_report:
logger.info(report)
last_report = report
last_time = time.time()
if report['pending'] == 0 and report['running'] == 0:
break
time.sleep(sleep_time)
sleep_time = min(sleep_time * 2, max_sleep_time)
@typing.overload
def get_batch_query_result(self, batch_id: int, *, no_block: bool = False,
combine: Literal[True] = True) -> pd.DataFrame:
...
@typing.overload
def get_batch_query_result(self, batch_id: int, *, no_block: bool = False,
combine: Literal[False]) -> list[pd.DataFrame]:
...
@validate_arguments
def get_batch_query_result(self, batch_id: int, *, combine: bool = True, no_block: bool = False):
"""
Concatenates and returns the results of all the queries of a batch query
Args:
batch_id (int): The batch_id for the batch_query
no_block (bool): Whether to wait until all queries have completed or return immediately. If you use
no_block=True and the batch hasn't completed, a QueryException is raised.
combine: Whether to combine the individual query result into a single dataframe
Returns:
The concatenated dataframe of the results of all the queries in a batch query.
"""
if no_block and self.did_batch_query_complete(batch_id) is False:
raise QueryException('Batch query not completed yet.')
self.wait_for_batch_query(batch_id)
logger.info("Batch query completed. ")
report = self.get_batch_query_report(batch_id)
query_exe_ids = self._batch_query_status_map[batch_id]['submitted_execution_ids']
query_futures = self._batch_query_status_map[batch_id]['queries_futures']
if report['failed'] > 0:
logger.warning(f"{report['failed']} queries failed. Redoing them")
failed_ids, failed_queries = self.get_failed_queries(batch_id)
new_batch_id = self.submit_batch_query(failed_queries)
new_exe_ids = self._batch_query_status_map[new_batch_id]['submitted_execution_ids']
self.wait_for_batch_query(new_batch_id)
new_exe_ids_map = {entry[0]: entry[1] for entry in zip(failed_ids, new_exe_ids)}
new_report = self.get_batch_query_report(new_batch_id)
if new_report['failed'] > 0:
self.print_failed_query_errors(new_batch_id)
raise QueryException("Queries failed again. Sorry!")
logger.info("The queries succeeded this time. Gathering all the results.")
# replace the old failed exe_ids with new successful exe_ids
for indx, old_exe_id in enumerate(query_exe_ids):
query_exe_ids[indx] = new_exe_ids_map.get(old_exe_id, old_exe_id)
if len(query_exe_ids) == 0:
raise ValueError("No query was submitted successfully")
res_df_array: list[pd.DataFrame] = []
for index, exe_id in enumerate(query_exe_ids):
df = query_futures[index].as_pandas().copy()
if combine:
if len(df) > 0:
df['query_id'] = index
logger.info(f"Got result from Query [{index}] ({exe_id})")
self._log_execution_cost(exe_id)
res_df_array.append(df)
if not combine:
return res_df_array
logger.info("Concatenating the results.")
# return res_df_array
return pd.concat(res_df_array)
@validate_arguments
def submit_batch_query(self, queries: Sequence[str]):
"""
Submit multiple related queries
Args:
queries: List of queries to submit. Setting the `get_query_only` flag while making calls to aggregation
functions is the easiest way to obtain queries.
Returns:
An integer representing the batch_query id. The id can be used with other batch_query functions.
"""
queries = list(queries)
to_submit_ids = list(range(len(queries)))
id_list = list(to_submit_ids) # make a copy
submitted_ids: list[int] = []
submitted_execution_ids: list[ExeId] = []
submitted_queries: list[str] = []
queries_futures: list = []
self._batch_query_id += 1
batch_query_id = self._batch_query_id
self._batch_query_status_map[batch_query_id] = {'to_submit_ids': to_submit_ids,
'all_ids': list(id_list),
'submitted_ids': submitted_ids,
'submitted_execution_ids': submitted_execution_ids,
'submitted_queries': submitted_queries,
'queries_futures': queries_futures
}
def run_queries():
while to_submit_ids:
current_id = to_submit_ids[0] # get the first one
current_query = queries[0]
try:
execution_id, future = self.execute(current_query, run_async=True)
logger.info(f"Submitted queries[{current_id}] ({execution_id})")
to_submit_ids.pop(0) # if query queued successfully, remove it from the list
queries.pop(0)
submitted_ids.append(current_id)
submitted_execution_ids.append(ExeId(execution_id))
submitted_queries.append(current_query)
queries_futures.append(future)
except ClientError as e:
if e.response['Error']['Code'] == 'TooManyRequestsException':
logger.info("Athena complained about too many requests. Waiting for a minute.")
time.sleep(60) # wait for a minute before submitting another query
elif e.response['Error']['Code'] == 'InvalidRequestException':
logger.info(f"Queries[{current_id}] is Invalid: {e.response['Message']} \n {current_query}")
to_submit_ids.pop(0) # query failed, so remove it from the list
queries.pop(0)
raise
else:
raise
query_runner = Thread(target=run_queries)
query_runner.start()
return batch_query_id
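# A hedged sketch of the batch-query workflow built from the methods in this class;
# `bsq` is a hypothetical initialized instance and the queries are hypothetical:
#   batch_id = bsq.submit_batch_query(["SELECT 1", "SELECT 2"])
#   bsq.wait_for_batch_query(batch_id)                 # or poll get_batch_query_report(batch_id)
#   combined_df = bsq.get_batch_query_result(batch_id)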
def _get_query_result(self, query_id):
return self.get_athena_query_result(execution_id=query_id)
@validate_arguments
def get_athena_query_result(self, execution_id: ExeId, timeout_minutes: int = 30) -> pd.DataFrame:
"""Returns the query result
Args:
execution_id (str): Query execution id.
timeout_minutes (int, optional): Timeout in minutes to wait for query to finish. Defaults to 30.
Raises:
QueryException: If query fails for some reason.
Returns:
pd.DataFrame: Query result as dataframe.
"""
t = time.time()
while time.time() - t < timeout_minutes * 60:
stat = self.get_query_status(execution_id)
if stat.upper() == 'SUCCEEDED':
result = self.get_result_from_s3(execution_id)
self._log_execution_cost(execution_id)
return result
elif stat.upper() == 'FAILED':
error = self.get_query_error(execution_id)
raise QueryException(error)
else:
logger.info(f"Query status is {stat}")
time.sleep(30)
raise QueryException(f'Query timed-out. {self.get_query_status(execution_id)}')
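# A hedged sketch pairing execute_raw(run_async=True) with get_athena_query_result;
# the query is hypothetical:
#   exe_id = bsq.execute_raw("SELECT COUNT(*) FROM my_run_baseline", run_async=True)
#   df = bsq.get_athena_query_result(exe_id)           # blocks until the query finishes or times out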
@validate_arguments
def get_result_from_s3(self, query_execution_id: ExeId) -> pd.DataFrame:
"""Returns query result from s3 location.
Args:
query_execution_id (str): The query execution ID
Raises:
QueryException: If query had failed.
Returns:
pd.DataFrame: The query result.
"""
query_status = self.get_query_status(query_execution_id)
if query_status == 'SUCCEEDED':
path = self.get_query_output_location(query_execution_id)
bucket = path.split('/')[2]
key = '/'.join(path.split('/')[3:])
response = self._aws_s3.get_object(Bucket=bucket, Key=key)
df = read_csv(response['Body'])
return df
# If failed, return error message
elif query_status == 'FAILED':
raise QueryException(self.get_query_error(query_execution_id))
elif query_status in ['RUNNING', 'QUEUED', 'PENDING']:
raise QueryException(f"Query still {query_status}")
else:
raise QueryException(f"Query has unkown status {query_status}")
@validate_arguments
def get_query_output_location(self, query_id: ExeId) -> str:
"""Get query output location in s3.
Args:
query_id (str): Query execution id.
Returns:
str: The query location in s3.
"""
stat = self._aws_athena.get_query_execution(QueryExecutionId=query_id)
output_path = stat['QueryExecution']['ResultConfiguration']['OutputLocation']
return output_path
@validate_arguments
def get_query_status(self, query_id: ExeId) -> str:
"""Get status of the query
Args:
query_id (str): Query execution id
Returns:
str: Status of the query.
"""
stat = self._aws_athena.get_query_execution(QueryExecutionId=query_id)
return stat['QueryExecution']['Status']['State']
@validate_arguments
def get_query_error(self, query_id: ExeId) -> str:
"""Returns the error message if query has failed.
Args:
query_id (str): Query execution id.
Returns:
str: Error message for the query.
"""
stat = self._aws_athena.get_query_execution(QueryExecutionId=query_id)
return stat['QueryExecution']['Status']['StateChangeReason']
def get_all_running_queries(self) -> list[ExeId]:
"""
Gives the list of all running queries (for this instance)
Returns:
List of query execution ids of all the queries that are currently running in Athena.
"""
exe_ids = self._aws_athena.list_query_executions(WorkGroup=self.workgroup)['QueryExecutionIds']
exe_ids = [ExeId(i) for i in exe_ids]
running_ids = [i for i in exe_ids if i in self._execution_ids_history and
self.get_query_status(i) == "RUNNING"]
return running_ids
def stop_all_queries(self) -> None:
"""
Stops all queries that are running in Athena for this instance.
Returns:
Nothing
"""
for count, stat in self._batch_query_status_map.items():
stat['to_submit_ids'].clear()
running_ids = self.get_all_running_queries()
for i in running_ids:
self.stop_query(execution_id=i)
logger.info(f"Stopped {len(running_ids)} queries")
@validate_arguments
def stop_query(self, execution_id: ExeId) -> str:
"""
Stops a running query.
Args:
execution_id: The execution id of the query being run.
Returns:
The response from the Athena stop_query_execution API call.
"""
return self._aws_athena.stop_query_execution(QueryExecutionId=execution_id)
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def get_cols(self, table: AnyTableType, fuel_type=None) -> Sequence[DBColType]:
"""
Returns the columns for a particular table.
Args:
table: Name of the table. One of 'baseline' or 'timeseries'
fuel_type: Get only the columns for this fuel_type ('electricity', 'gas' etc)
Returns:
A list of column objects for the table.
"""
table = self._get_table(table)
if table == self.ts_table and self.ts_table is not None:
cols = [c for c in self.ts_table.columns]
if fuel_type:
cols = [c for c in cols if c.name not in [self.ts_bldgid_column.name, self.timestamp_column.name]]
cols = [c for c in cols if fuel_type in c.name]
return cols
elif table in ['baseline', 'bs']:
cols = [c for c in self.bs_table.columns]
if fuel_type:
cols = [c for c in cols if 'simulation_output_report' in c.name]
cols = [c for c in cols if fuel_type in c.name]
return cols
else:
tbl = self._get_table(table)
return [col for col in tbl.columns]
def _simple_label(self, label: str):
label = label.removeprefix(self.db_schema.column_prefix.characteristics)
label = label.removeprefix(self.db_schema.column_prefix.output)
return label
def _add_restrict(self, query, restrict, bs_only=False):
if not restrict:
return query
where_clauses = []
for col_str, criteria in restrict:
col = self._get_column(col_str, table_name=self.bs_table) if bs_only else self._get_column(col_str)
if isinstance(criteria, (list, tuple)):
if len(criteria) > 1:
where_clauses.append(self._get_column(col).in_(criteria))
continue
else:
criteria = criteria[0]
where_clauses.append(col == criteria)
query = query.where(*where_clauses)
return query
def _get_name(self, col):
if isinstance(col, tuple):
return col[1]
if isinstance(col, str):
return col
if isinstance(col, (sa.Column, SALabel)):
return col.name
raise ValueError(f"Can't get name for {col} of type {type(col)}")
def _add_join(self, query, join_list):
for new_table_name, baseline_column_name, new_column_name in join_list:
new_tbl = self._get_table(new_table_name)
baseline_column = self._get_column(baseline_column_name, table_name=self.bs_table)
new_column = self._get_column(new_column_name, table_name=new_tbl)
query = query.join(new_tbl, baseline_column == new_column)
return query
def _add_group_by(self, query, group_by_selection):
if group_by_selection:
selected_cols = list(query.selected_columns)
a = [sa.text(str(selected_cols.index(g) + 1)) for g in group_by_selection]
query = query.group_by(*a)
return query
def _add_order_by(self, query, order_by_selection):
if order_by_selection:
selected_cols = list(query.selected_columns)
a = [sa.text(str(selected_cols.index(g) + 1)) for g in order_by_selection]
query = query.order_by(*a)
return query
def _get_weight(self, weights):
total_weight = self.sample_wt
for weight_col in weights:
if isinstance(weight_col, tuple):
tbl = self._get_table(weight_col[1])
total_weight *= tbl.c[weight_col[0]]
else:
total_weight *= self._get_column(weight_col)
return total_weight
def delete_everything(self):
"""Deletes the athena tables and data in s3 for the run.
"""
info = self._aws_glue.get_table(DatabaseName=self.db_name, Name=self.bs_table.name)
self.pth = pathlib.Path(info['Table']['StorageDescriptor']['Location']).parent
tables_to_delete = [self.bs_table.name]
if self.ts_table is not None:
tables_to_delete.append(self.ts_table.name)
if self.up_table is not None:
tables_to_delete.append(self.up_table.name)
print(f"Will delete the following tables {tables_to_delete} and the {self.pth} folder")
while True:
curtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
confirm = input(f"Enter {curtime} to confirm.")
if confirm == "":
print("Abandoned the idea.")
break
if confirm != curtime:
print(f"Please pass {curtime} as confirmation to confirm you want to delete everything.")
continue
self._aws_glue.batch_delete_table(DatabaseName=self.db_name, TablesToDelete=tables_to_delete)
print("Deleted the table from athena, now will delete the data in s3")
s3 = boto3.resource('s3')
bucket = s3.Bucket(self.pth.parts[1]) # type: ignore
prefix = str(pathlib.Path(*self.pth.parts[2:]))
total_files = [file.key for file in bucket.objects.filter(Prefix=prefix)]
print(f"There are {len(total_files)} files to be deleted. Deleting them now")
bucket.objects.filter(Prefix=prefix).delete()
print("Delete from s3 completed")
break
Classes
class BatchQueryReportMap (*args, **kwargs)
A TypedDict summarizing the status counts of a batch query, as returned by QueryCore.get_batch_query_report().
Ancestors
- builtins.dict
Class variables
var completed : int
var failed : int
var pending : int
var running : int
var submitted : int
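A report of this shape is returned by QueryCore.get_batch_query_report(); a batch is considered finished when both pending and running drop to zero, which is the condition wait_for_batch_query() polls on. A minimal sketch with hypothetical counts:
report: BatchQueryReportMap = {'submitted': 10, 'running': 2, 'pending': 1, 'completed': 6, 'failed': 1}
if report['pending'] == 0 and report['running'] == 0:
    print(f"Batch finished: {report['completed']} succeeded, {report['failed']} failed")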
class BatchQueryStatusMap (*args, **kwargs)
A TypedDict tracking the submission state of a batch query created by QueryCore.submit_batch_query().
Ancestors
- builtins.dict
Class variables
var all_ids : list[int]
var queries_futures : list[typing.Union[CachedFutureDf, AthenaFutureDf]]
var submitted_execution_ids : list[buildstock_query.query_core.ExeId]
var submitted_ids : list[int]
var submitted_queries : list[str]
var to_submit_ids : list[int]
class QueryCore (*, params: RunParams)
Base class to run common Athena queries for BuildStock runs and download results as pandas DataFrames. Usually, you should just use BuildStockQuery. This class is useful if you want to extend the functionality for Athena tables that are not part of ResStock or ComStock runs.
Args
workgroup : str
The workgroup for Athena. The cost will be charged to this workgroup.
db_name : str
The Athena database name.
buildstock_type : str, optional
'resstock' or 'comstock' runs. Defaults to 'resstock'.
table_name : Union[str, tuple[str, Optional[str], Optional[str]]]
If a single string is provided, say, 'mfm_run', then it must correspond to tables in Athena named mfm_run_baseline and optionally mfm_run_timeseries and mfm_run_upgrades. Alternatively, a tuple of three elements can be provided for the table names for baseline, timeseries and upgrade. Timeseries and upgrade can be None if no such tables exist.
db_schema : str, optional
The database structure in Athena is different between ResStock and ComStock runs. It is also different between the version in OEDI and the default version from BuildStockBatch. This argument controls the assumed schema. Allowed values are 'resstock_default', 'resstock_oedi', 'comstock_default' and 'comstock_oedi'. Defaults to 'resstock_default' for resstock and 'comstock_default' for comstock.
sample_weight : str, optional
Specify a custom sample_weight. Otherwise, the default is 1 for ComStock and uses sample_weight in the run for ResStock.
region_name : str, optional
The AWS region where the database exists. Defaults to 'us-west-2'.
execution_history : str, optional
A temporary file to record which executions are run by the user, to help stop them. Will use .execution_history if not supplied. Generally, not required to supply a custom filename.
athena_query_reuse : bool, optional
When true, Athena will make use of its built-in 7 day query cache. When false, it will not. Defaults to True. One use case to set this to False is when you have modified the underlying s3 data or glue schema and want to make sure you are not using the cached results.
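A hedged construction sketch follows. QueryCore is normally used through BuildStockQuery, but it can be instantiated directly from a RunParams object; the workgroup, database, and table names below are hypothetical, and it is assumed that RunParams accepts these fields as keyword arguments (they are read as attributes in __init__).
from buildstock_query.query_core import QueryCore
from buildstock_query.schema.run_params import RunParams

# Hypothetical run parameters; field names mirror the attributes read in QueryCore.__init__
params = RunParams(
    workgroup='my_athena_workgroup',
    db_name='my_buildstock_db',
    table_name='mfm_run',          # expects mfm_run_baseline (+ optional timeseries/upgrades tables)
    buildstock_type='resstock',
    region_name='us-west-2',
)
bsq = QueryCore(params=params)
df = bsq.execute("SELECT COUNT(*) FROM mfm_run_baseline")   # blocking query, returns a pandas DataFrame
bsq.save_cache()                                            # persist the session's query cache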
Args: batch_id (int): batch query id Returns: Sequence[str]: List of failed execution ids. """ failed_ids = [] for i, exe_id in enumerate(self._batch_query_status_map[batch_id]['submitted_execution_ids']): completion_stat = self.get_query_status(exe_id) if completion_stat in ['FAILED', 'CANCELLED']: failed_ids.append(exe_id) return failed_ids @validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) def get_batch_query_report(self, batch_id: int) -> BatchQueryReportMap: """ Returns the status of the queries running under a batch query. Args: batch_id: The batch_id of the batch_query. Returns: A dictionary detailing status of the queries. """ if not (stats := self._batch_query_status_map.get(batch_id, None)): raise ValueError(f"{batch_id=} not found.") success_count = 0 fail_count = 0 running_count = 0 other = 0 for exe_id in stats['submitted_execution_ids']: if exe_id == 'CACHED': completion_stat = "SUCCEEDED" else: completion_stat = self.get_query_status(exe_id) if completion_stat == 'RUNNING': running_count += 1 elif completion_stat == 'SUCCEEDED': success_count += 1 elif completion_stat in ['FAILED', 'CANCELLED']: fail_count += 1 else: # for example: QUEUED other += 1 result: BatchQueryReportMap = {'submitted': len(stats['submitted_ids']), 'running': running_count, 'pending': len(stats['to_submit_ids']) + other, 'completed': success_count, 'failed': fail_count } return result @validate_arguments def did_batch_query_complete(self, batch_id: int): """ Checks if all the queries in a batch query has completed or not. Args: batch_id: The batch_id for the batch_query Returns: True or False """ status = self.get_batch_query_report(batch_id) if status['pending'] > 0 or status['running'] > 0: return False else: return True @validate_arguments def wait_for_batch_query(self, batch_id: int): """Waits until batch query completes. Args: batch_id (int): The batch query id. """ sleep_time = 0.5 # start here and keep doubling until max_sleep_time max_sleep_time = 20 while True: last_time = time.time() last_report = None report = self.get_batch_query_report(batch_id) if time.time() - last_time > 60 or last_report is None or report != last_report: logger.info(report) last_report = report last_time = time.time() if report['pending'] == 0 and report['running'] == 0: break time.sleep(sleep_time) sleep_time = min(sleep_time * 2, max_sleep_time) @typing.overload def get_batch_query_result(self, batch_id: int, *, no_block: bool = False, combine: Literal[True] = True) -> pd.DataFrame: ... @typing.overload def get_batch_query_result(self, batch_id: int, *, no_block: bool = False, combine: Literal[False]) -> list[pd.DataFrame]: ... @validate_arguments def get_batch_query_result(self, batch_id: int, *, combine: bool = True, no_block: bool = False): """ Concatenates and returns the results of all the queries of a batchquery Args: batch_id (int): The batch_id for the batch_query no_block (bool): Whether to wait until all queries have completed or return immediately. If you use no_block = true and the batch hasn't completed, it will throw BatchStillRunning exception. combine: Whether to combine the individual query result into a single dataframe Returns: The concatenated dataframe of the results of all the queries in a batch query. """ if no_block and self.did_batch_query_complete(batch_id) is False: raise QueryException('Batch query not completed yet.') self.wait_for_batch_query(batch_id) logger.info("Batch query completed. 
") report = self.get_batch_query_report(batch_id) query_exe_ids = self._batch_query_status_map[batch_id]['submitted_execution_ids'] query_futures = self._batch_query_status_map[batch_id]['queries_futures'] if report['failed'] > 0: logger.warning(f"{report['failed']} queries failed. Redoing them") failed_ids, failed_queries = self.get_failed_queries(batch_id) new_batch_id = self.submit_batch_query(failed_queries) new_exe_ids = self._batch_query_status_map[new_batch_id]['submitted_execution_ids'] self.wait_for_batch_query(new_batch_id) new_exe_ids_map = {entry[0]: entry[1] for entry in zip(failed_ids, new_exe_ids)} new_report = self.get_batch_query_report(new_batch_id) if new_report['failed'] > 0: self.print_failed_query_errors(new_batch_id) raise QueryException("Queries failed again. Sorry!") logger.info("The queries succeeded this time. Gathering all the results.") # replace the old failed exe_ids with new successful exe_ids for indx, old_exe_id in enumerate(query_exe_ids): query_exe_ids[indx] = new_exe_ids_map.get(old_exe_id, old_exe_id) if len(query_exe_ids) == 0: raise ValueError("No query was submitted successfully") res_df_array: list[pd.DataFrame] = [] for index, exe_id in enumerate(query_exe_ids): df = query_futures[index].as_pandas().copy() if combine: if len(df) > 0: df['query_id'] = index logger.info(f"Got result from Query [{index}] ({exe_id})") self._log_execution_cost(exe_id) res_df_array.append(df) if not combine: return res_df_array logger.info("Concatenating the results.") # return res_df_array return pd.concat(res_df_array) @validate_arguments def submit_batch_query(self, queries: Sequence[str]): """ Submit multiple related queries Args: queries: List of queries to submit. Setting `get_query_only` flag while making calls to aggregation functions is easiest way to obtain queries. Returns: An integer representing the batch_query id. The id can be used with other batch_query functions. """ queries = list(queries) to_submit_ids = list(range(len(queries))) id_list = list(to_submit_ids) # make a copy submitted_ids: list[int] = [] submitted_execution_ids: list[ExeId] = [] submitted_queries: list[str] = [] queries_futures: list = [] self._batch_query_id += 1 batch_query_id = self._batch_query_id self._batch_query_status_map[batch_query_id] = {'to_submit_ids': to_submit_ids, 'all_ids': list(id_list), 'submitted_ids': submitted_ids, 'submitted_execution_ids': submitted_execution_ids, 'submitted_queries': submitted_queries, 'queries_futures': queries_futures } def run_queries(): while to_submit_ids: current_id = to_submit_ids[0] # get the first one current_query = queries[0] try: execution_id, future = self.execute(current_query, run_async=True) logger.info(f"Submitted queries[{current_id}] ({execution_id})") to_submit_ids.pop(0) # if query queued successfully, remove it from the list queries.pop(0) submitted_ids.append(current_id) submitted_execution_ids.append(ExeId(execution_id)) submitted_queries.append(current_query) queries_futures.append(future) except ClientError as e: if e.response['Error']['Code'] == 'TooManyRequestsException': logger.info("Athena complained about too many requests. 
Waiting for a minute.") time.sleep(60) # wait for a minute before submitting another query elif e.response['Error']['Code'] == 'InvalidRequestException': logger.info(f"Queries[{current_id}] is Invalid: {e.response['Message']} \n {current_query}") to_submit_ids.pop(0) # query failed, so remove it from the list queries.pop(0) raise else: raise query_runner = Thread(target=run_queries) query_runner.start() return batch_query_id def _get_query_result(self, query_id): return self.get_athena_query_result(execution_id=query_id) @validate_arguments def get_athena_query_result(self, execution_id: ExeId, timeout_minutes: int = 30) -> pd.DataFrame: """Returns the query result Args: execution_id (str): Query execution id. timeout_minutes (int, optional): Timeout in minutes to wait for query to finish. Defaults to 30. Raises: QueryException: If query fails for some reason. Returns: pd.DataFrame: Query result as dataframe. """ t = time.time() while time.time() - t < timeout_minutes * 60: stat = self.get_query_status(execution_id) if stat.upper() == 'SUCCEEDED': result = self.get_result_from_s3(execution_id) self._log_execution_cost(execution_id) return result elif stat.upper() == 'FAILED': error = self.get_query_error(execution_id) raise QueryException(error) else: logger.info(f"Query status is {stat}") time.sleep(30) raise QueryException(f'Query timed-out. {self.get_query_status(execution_id)}') @validate_arguments def get_result_from_s3(self, query_execution_id: ExeId) -> pd.DataFrame: """Returns query result from s3 location. Args: query_execution_id (str): The query execution ID Raises: QueryException: If query had failed. Returns: pd.DataFrame: The query result. """ query_status = self.get_query_status(query_execution_id) if query_status == 'SUCCEEDED': path = self.get_query_output_location(query_execution_id) bucket = path.split('/')[2] key = '/'.join(path.split('/')[3:]) response = self._aws_s3.get_object(Bucket=bucket, Key=key) df = read_csv(response['Body']) return df # If failed, return error message elif query_status == 'FAILED': raise QueryException(self.get_query_error(query_execution_id)) elif query_status in ['RUNNING', 'QUEUED', 'PENDING']: raise QueryException(f"Query still {query_status}") else: raise QueryException(f"Query has unkown status {query_status}") @validate_arguments def get_query_output_location(self, query_id: ExeId) -> str: """Get query output location in s3. Args: query_id (str): Query execution id. Returns: str: The query location in s3. """ stat = self._aws_athena.get_query_execution(QueryExecutionId=query_id) output_path = stat['QueryExecution']['ResultConfiguration']['OutputLocation'] return output_path @validate_arguments def get_query_status(self, query_id: ExeId) -> str: """Get status of the query Args: query_id (str): Query execution id Returns: str: Status of the query. """ stat = self._aws_athena.get_query_execution(QueryExecutionId=query_id) return stat['QueryExecution']['Status']['State'] @validate_arguments def get_query_error(self, query_id: ExeId) -> str: """Returns the error message if query has failed. Args: query_id (str): Query execution id. Returns: str: Error message for the query. """ stat = self._aws_athena.get_query_execution(QueryExecutionId=query_id) return stat['QueryExecution']['Status']['StateChangeReason'] def get_all_running_queries(self) -> list[ExeId]: """ Gives the list of all running queries (for this instance) Return: List of query execution ids of all the queries that are currently running in Athena. 
""" exe_ids = self._aws_athena.list_query_executions(WorkGroup=self.workgroup)['QueryExecutionIds'] exe_ids = [ExeId(i) for i in exe_ids] running_ids = [i for i in exe_ids if i in self._execution_ids_history and self.get_query_status(i) == "RUNNING"] return running_ids def stop_all_queries(self) -> None: """ Stops all queries that are running in Athena for this instance. Returns: Nothing """ for count, stat in self._batch_query_status_map.items(): stat['to_submit_ids'].clear() running_ids = self.get_all_running_queries() for i in running_ids: self.stop_query(execution_id=i) logger.info(f"Stopped {len(running_ids)} queries") @validate_arguments def stop_query(self, execution_id: ExeId) -> str: """ Stops a running query. Args: execution_id: The execution id of the query being run. Returns: """ return self._aws_athena.stop_query_execution(QueryExecutionId=execution_id) @validate_arguments(config=dict(arbitrary_types_allowed=True)) def get_cols(self, table: AnyTableType, fuel_type=None) -> Sequence[DBColType]: """ Returns the columns of for a particular table. Args: table: Name of the table. One of 'baseline' or 'timeseries' fuel_type: Get only the columns for this fuel_type ('electricity', 'gas' etc) Returns: A list of column names as a list of strings. """ table = self._get_table(table) if table == self.ts_table and self.ts_table is not None: cols = [c for c in self.ts_table.columns] if fuel_type: cols = [c for c in cols if c.name not in [self.ts_bldgid_column.name, self.timestamp_column.name]] cols = [c for c in cols if fuel_type in c.name] return cols elif table in ['baseline', 'bs']: cols = [c for c in self.bs_table.columns] if fuel_type: cols = [c for c in cols if 'simulation_output_report' in c.name] cols = [c for c in cols if fuel_type in c.name] return cols else: tbl = self._get_table(table) return [col for col in tbl.columns] def _simple_label(self, label: str): label = label.removeprefix(self.db_schema.column_prefix.characteristics) label = label.removeprefix(self.db_schema.column_prefix.output) return label def _add_restrict(self, query, restrict, bs_only=False): if not restrict: return query where_clauses = [] for col_str, criteria in restrict: col = self._get_column(col_str, table_name=self.bs_table) if bs_only else self._get_column(col_str) if isinstance(criteria, (list, tuple)): if len(criteria) > 1: where_clauses.append(self._get_column(col).in_(criteria)) continue else: criteria = criteria[0] where_clauses.append(col == criteria) query = query.where(*where_clauses) return query def _get_name(self, col): if isinstance(col, tuple): return col[1] if isinstance(col, str): return col if isinstance(col, (sa.Column, SALabel)): return col.name raise ValueError(f"Can't get name for {col} of type {type(col)}") def _add_join(self, query, join_list): for new_table_name, baseline_column_name, new_column_name in join_list: new_tbl = self._get_table(new_table_name) baseline_column = self._get_column(baseline_column_name, table_name=self.bs_table) new_column = self._get_column(new_column_name, table_name=new_tbl) query = query.join(new_tbl, baseline_column == new_column) return query def _add_group_by(self, query, group_by_selection): if group_by_selection: selected_cols = list(query.selected_columns) a = [sa.text(str(selected_cols.index(g) + 1)) for g in group_by_selection] query = query.group_by(*a) return query def _add_order_by(self, query, order_by_selection): if order_by_selection: selected_cols = list(query.selected_columns) a = [sa.text(str(selected_cols.index(g) + 1)) for g in 
order_by_selection] query = query.order_by(*a) return query def _get_weight(self, weights): total_weight = self.sample_wt for weight_col in weights: if isinstance(weight_col, tuple): tbl = self._get_table(weight_col[1]) total_weight *= tbl.c[weight_col[0]] else: total_weight *= self._get_column(weight_col) return total_weight def delete_everything(self): """Deletes the athena tables and data in s3 for the run. """ info = self._aws_glue.get_table(DatabaseName=self.db_name, Name=self.bs_table.name) self.pth = pathlib.Path(info['Table']['StorageDescriptor']['Location']).parent tables_to_delete = [self.bs_table.name] if self.ts_table is not None: tables_to_delete.append(self.ts_table.name) if self.up_table is not None: tables_to_delete.append(self.up_table.name) print(f"Will delete the following tables {tables_to_delete} and the {self.pth} folder") while True: curtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M') confirm = input(f"Enter {curtime} to confirm.") if confirm == "": print("Abandoned the idea.") break if confirm != curtime: print(f"Please pass {curtime} as confirmation to confirm you want to delete everything.") continue self._aws_glue.batch_delete_table(DatabaseName=self.db_name, TablesToDelete=tables_to_delete) print("Deleted the table from athena, now will delete the data in s3") s3 = boto3.resource('s3') bucket = s3.Bucket(self.pth.parts[1]) # type: ignore prefix = str(pathlib.Path(*self.pth.parts[2:])) total_files = [file.key for file in bucket.objects.filter(Prefix=prefix)] print(f"There are {len(total_files)} files to be deleted. Deleting them now") bucket.objects.filter(Prefix=prefix).delete() print("Delete from s3 completed") break
Subclasses
Methods
def add_table(self, table_name: str, table_df: pandas.core.frame.DataFrame, s3_bucket: str, s3_prefix: str, override: bool = False)
-
Adds a table to Athena by uploading the dataframe to S3 as a parquet file and creating an external table over it. :param table_name: The name of the table :param table_df: The pandas dataframe to use as table data :param s3_bucket: s3 bucket name :param s3_prefix: s3 prefix to save the table to. :param override: Whether to override an existing table.
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) def add_table(self, table_name: str, table_df: pd.DataFrame, s3_bucket: str, s3_prefix: str, override: bool = False): """ Function to add a table in s3. :param table_name: The name of the table :param table_df: The pandas dataframe to use as table data :param s3_bucket: s3 bucket name :param s3_prefix: s3 prefix to save the table to. :param override: Whether to override eixsting table. :return: """ s3_location = s3_bucket + '/' + s3_prefix s3_data = self._aws_s3.list_objects(Bucket=s3_bucket, Prefix=f'{s3_prefix}/{table_name}') if 'Contents' in s3_data and override is False: raise DataExistsException("Table already exists", f's3://{s3_location}/{table_name}/{table_name}.csv') if 'Contents' in s3_data: existing_objects = [{'Key': el['Key']} for el in s3_data['Contents']] print(f"The following existing objects is being delete and replaced: {existing_objects}") print(f"Saving s3://{s3_location}/{table_name}/{table_name}.parquet)") self._aws_s3.delete_objects(Bucket=s3_bucket, Delete={"Objects": existing_objects}) print(f"Saving factors to s3 in s3://{s3_location}/{table_name}/{table_name}.parquet") # table_df.to_parquet(f's3://{s3_location}/{table_name}/{table_name}.parquet', index=False) self._aws_s3.put_object(Body=table_df.to_parquet(index=False), Bucket=s3_bucket, Key=f"{s3_prefix}/{table_name}/{table_name}.parquet") print("Saving Done.") format_list = [] for column_name, dtype in table_df.dtypes.items(): if np.issubdtype(dtype, np.integer): col_type = "int" elif np.issubdtype(dtype, np.floating): col_type = "double" elif np.issubdtype(dtype, np.datetime64): col_type = "timestamp" else: col_type = "string" format_list.append(f"`{column_name}` {col_type}") column_formats = ",".join(format_list) table_create_query = f""" CREATE EXTERNAL TABLE {self.db_name}.{table_name} ({column_formats}) STORED AS PARQUET LOCATION 's3://{s3_location}/{table_name}/' TBLPROPERTIES ('has_encrypted_data'='false'); """ print(f"Running create table query.\n {table_create_query}") result, reason = self.execute_raw(table_create_query) if result.lower() == "failed" and 'alreadyexists' in reason.lower(): if not override: existing_data = read_csv(f's3://{s3_location}/{table_name}/{table_name}.csv') raise DataExistsException("Table already exists", existing_data) print(f"There was existing table {table_name} in Athena which was deleted and recreated.") delete_table_query = f""" DROP TABLE {self.db_name}.{table_name}; """ result, reason = self.execute_raw(delete_table_query) if result.upper() != "SUCCEEDED": raise QueryException(f"There was an existing table named {table_name}. Deleting it failed." f" Reason: {reason}") result, reason = self.execute_raw(table_create_query) if result.upper() == "SUCCEEDED": return "SUCCEEDED" else: raise QueryException(f"There was an existing table named {table_name} which is now successfully " f"deleted but new table failed to be created. Reason: {reason}") elif result.upper() == "SUCCEEDED": return "SUCCEEDED" else: raise QueryException(f"Failed to create the table. Reason: {reason}")
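Below is a minimal usage sketch. The `BuildStockQuery` import path and constructor arguments, as well as the bucket, prefix, and table names, are illustrative assumptions; substitute values from your own run and AWS setup.

```python
import pandas as pd
from buildstock_query import BuildStockQuery  # assumed entry point; QueryCore subclass

# Assumed/illustrative constructor arguments.
bsq = BuildStockQuery(workgroup="my_workgroup", db_name="my_db", table_name="my_run")

# A small lookup table to register in Athena.
weights_df = pd.DataFrame({"state": ["CO", "TX"], "weight": [1.2, 0.8]})

bsq.add_table(
    table_name="state_weights",
    table_df=weights_df,
    s3_bucket="my-bucket",        # placeholder bucket
    s3_prefix="lookup_tables",    # placeholder prefix
    override=True,                # replace any existing objects/table at this location
)
```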
def delete_everything(self)
-
Deletes the athena tables and data in s3 for the run.
Expand source code
def delete_everything(self): """Deletes the athena tables and data in s3 for the run. """ info = self._aws_glue.get_table(DatabaseName=self.db_name, Name=self.bs_table.name) self.pth = pathlib.Path(info['Table']['StorageDescriptor']['Location']).parent tables_to_delete = [self.bs_table.name] if self.ts_table is not None: tables_to_delete.append(self.ts_table.name) if self.up_table is not None: tables_to_delete.append(self.up_table.name) print(f"Will delete the following tables {tables_to_delete} and the {self.pth} folder") while True: curtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M') confirm = input(f"Enter {curtime} to confirm.") if confirm == "": print("Abandoned the idea.") break if confirm != curtime: print(f"Please pass {curtime} as confirmation to confirm you want to delete everything.") continue self._aws_glue.batch_delete_table(DatabaseName=self.db_name, TablesToDelete=tables_to_delete) print("Deleted the table from athena, now will delete the data in s3") s3 = boto3.resource('s3') bucket = s3.Bucket(self.pth.parts[1]) # type: ignore prefix = str(pathlib.Path(*self.pth.parts[2:])) total_files = [file.key for file in bucket.objects.filter(Prefix=prefix)] print(f"There are {len(total_files)} files to be deleted. Deleting them now") bucket.objects.filter(Prefix=prefix).delete() print("Delete from s3 completed") break
def delete_table(self, table_name: str)
-
Deletes an Athena table. :param table_name: The Athena table name.
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def delete_table(self, table_name: str):
    """
    Function to delete athena table.
    :param table_name: Athena table name
    :return:
    """
    delete_table_query = f"""DROP TABLE {self.db_name}.{table_name};"""
    result, reason = self.execute_raw(delete_table_query)
    if result.upper() == "SUCCEEDED":
        return "SUCCEEDED"
    else:
        raise QueryException(f"Deleting it failed. Reason: {reason}")
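A short sketch of dropping a table, assuming `bsq` is an initialized instance as in the `add_table` example above and the table name is a placeholder. Note that, per the source, only the Athena table is dropped; the underlying S3 data is left in place.

```python
# Drop a helper table created earlier with add_table; raises QueryException on failure.
status = bsq.delete_table("state_weights")
print(status)  # "SUCCEEDED" once the DROP TABLE query completes
```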
def did_batch_query_complete(self, batch_id: int)
-
Checks whether all the queries in a batch query have completed or not.
Args
batch_id
- The batch_id for the batch_query
Returns
True or False
Expand source code
@validate_arguments
def did_batch_query_complete(self, batch_id: int):
    """
    Checks if all the queries in a batch query has completed or not.
    Args:
        batch_id: The batch_id for the batch_query
    Returns:
        True or False
    """
    status = self.get_batch_query_report(batch_id)
    if status['pending'] > 0 or status['running'] > 0:
        return False
    else:
        return True
def execute(self, query, run_async: bool = False) ‑> Union[pandas.core.frame.DataFrame, tuple[Literal['CACHED'], CachedFutureDf], tuple[buildstock_query.query_core.ExeId, AthenaFutureDf]]
-
Executes a query
Args
query
- The SQL query to run in Athena
run_async
- Whether to wait until the query completes (run_async=False) or return immediately
(run_async=True).
Returns
If run_async is False, returns the results dataframe. If run_async is True, returns a tuple of the query execution id and a future object for the result.
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) def execute(self, query, run_async: bool = False) -> Union[pd.DataFrame, tuple[Literal["CACHED"], CachedFutureDf], tuple[ExeId, AthenaFutureDf]]: """ Executes a query Args: query: The SQL query to run in Athena run_async: Whether to wait until the query completes (run_async=False) or return immediately (run_async=True). Returns: if run_async is False, returns the results dataframe. if run_async is True, returns the query_execution_id, futures """ if not isinstance(query, str): query = self._compile(query) self._session_queries.add(query) if run_async: if query in self._query_cache: return "CACHED", CachedFutureDf(self._query_cache[query].copy()) # in case of asynchronous run, you get the execution id and futures object exe_id, result_future = self._async_conn.cursor().execute(query, result_reuse_enable=self.athena_query_reuse, result_reuse_minutes=60 * 24 * 7, na_values=['']) # type: ignore exe_id = ExeId(exe_id) def get_pandas(future): res = future.result() if res.state != 'SUCCEEDED': raise OperationalError(f"{res.state}: {res.state_change_reason}") if query in self._query_cache: return self._query_cache[query] return res.as_pandas() result_future.as_pandas = types.MethodType(get_pandas, result_future) result_future.add_done_callback(lambda f: self._query_cache.update({query: f.as_pandas()})) self._save_execution_id(exe_id) return exe_id, AthenaFutureDf(result_future) else: if query not in self._query_cache: self._query_cache[query] = self._conn.cursor().execute(query, result_reuse_enable=self.athena_query_reuse, result_reuse_minutes=60 * 24 * 7, ).as_pandas() return self._query_cache[query].copy()
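A usage sketch, assuming `bsq` is an initialized instance and the table name in the SQL is a placeholder. Results are cached per query string, so re-running the same query within a session is answered locally.

```python
sql = 'SELECT COUNT(*) AS n FROM "my_run_baseline"'  # placeholder SQL

# Blocking call: waits for Athena and returns a pandas DataFrame.
df = bsq.execute(sql)

# Non-blocking call: returns (execution_id, future). If the same query was already
# run in this session, execution_id is the literal string "CACHED".
exe_id, future = bsq.execute(sql, run_async=True)
result_df = future.as_pandas()  # blocks here until the query result is available
```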
def execute_raw(self, query, db: Optional[str] = None, run_async: bool = False)
-
Directly executes the supplied query in Athena, bypassing the query cache. :param query: The SQL query string to run. :param db: The database to run the query in; defaults to the instance's database. :param run_async: If True, return the query execution id immediately instead of waiting for completion.
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) def execute_raw(self, query, db: Optional[str] = None, run_async: bool = False): """ Directly executes the supplied query in Athena. :param query: :param db: :param run_async: :return: """ if not db: db = self.db_name response = self._aws_athena.start_query_execution( QueryString=query, QueryExecutionContext={ 'Database': db }, WorkGroup=self.workgroup) query_execution_id = ExeId(response['QueryExecutionId']) if run_async: return query_execution_id start_time = time.time() while time.time() - start_time < 30*60: # 30 minute timeout query_stat = self._aws_athena.get_query_execution(QueryExecutionId=query_execution_id) if query_stat['QueryExecution']['Status']['State'].lower() not in ['pending', 'running', 'queued']: reason = query_stat['QueryExecution']['Status'].get('StateChangeReason', '') return query_stat['QueryExecution']['Status']['State'], reason time.sleep(1) raise TimeoutError("Query failed to complete within 30 mins.")
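A sketch of both modes, assuming `bsq` as above; the SQL is a placeholder.

```python
# Synchronous: blocks (up to 30 minutes) and returns the final (state, reason) pair.
state, reason = bsq.execute_raw("SHOW TABLES")
print(state, reason)  # e.g. "SUCCEEDED", ""

# Asynchronous: returns the Athena query execution id immediately.
exe_id = bsq.execute_raw("SELECT 1", run_async=True)
df = bsq.get_athena_query_result(exe_id)  # poll until done, then fetch the result
```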
def get_all_running_queries(self) ‑> list[buildstock_query.query_core.ExeId]
-
Gives the list of all running queries submitted by this instance.
Return
List of query execution ids of all the queries that are currently running in Athena.
Expand source code
def get_all_running_queries(self) -> list[ExeId]:
    """
    Gives the list of all running queries (for this instance)
    Return:
        List of query execution ids of all the queries that are currently running in Athena.
    """
    exe_ids = self._aws_athena.list_query_executions(WorkGroup=self.workgroup)['QueryExecutionIds']
    exe_ids = [ExeId(i) for i in exe_ids]
    running_ids = [i for i in exe_ids if i in self._execution_ids_history
                   and self.get_query_status(i) == "RUNNING"]
    return running_ids
def get_athena_query_result(self, execution_id: buildstock_query.query_core.ExeId, timeout_minutes: int = 30) ‑> pandas.core.frame.DataFrame
-
Returns the query result
Args
execution_id
:str
- Query execution id.
timeout_minutes
:int, optional
- Timeout in minutes to wait for query to finish. Defaults to 30.
Raises
QueryException
- If query fails for some reason.
Returns
pd.DataFrame
- Query result as dataframe.
Expand source code
@validate_arguments def get_athena_query_result(self, execution_id: ExeId, timeout_minutes: int = 30) -> pd.DataFrame: """Returns the query result Args: execution_id (str): Query execution id. timeout_minutes (int, optional): Timeout in minutes to wait for query to finish. Defaults to 30. Raises: QueryException: If query fails for some reason. Returns: pd.DataFrame: Query result as dataframe. """ t = time.time() while time.time() - t < timeout_minutes * 60: stat = self.get_query_status(execution_id) if stat.upper() == 'SUCCEEDED': result = self.get_result_from_s3(execution_id) self._log_execution_cost(execution_id) return result elif stat.upper() == 'FAILED': error = self.get_query_error(execution_id) raise QueryException(error) else: logger.info(f"Query status is {stat}") time.sleep(30) raise QueryException(f'Query timed-out. {self.get_query_status(execution_id)}')
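A sketch of fetching a result for a known execution id with a shorter timeout, assuming `bsq` as above. `QueryException` is the exception class defined in this module.

```python
from buildstock_query.query_core import QueryException

exe_id = bsq.execute_raw("SELECT 1", run_async=True)
try:
    df = bsq.get_athena_query_result(exe_id, timeout_minutes=5)
except QueryException as err:
    # Raised if the query fails or does not finish within the timeout.
    print(f"Query {exe_id} did not succeed: {err}")
```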
def get_batch_query_report(self, batch_id: int) ‑> BatchQueryReportMap
-
Returns the status of the queries running under a batch query.
Args
batch_id
- The batch_id of the batch_query.
Returns
A dictionary detailing status of the queries.
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) def get_batch_query_report(self, batch_id: int) -> BatchQueryReportMap: """ Returns the status of the queries running under a batch query. Args: batch_id: The batch_id of the batch_query. Returns: A dictionary detailing status of the queries. """ if not (stats := self._batch_query_status_map.get(batch_id, None)): raise ValueError(f"{batch_id=} not found.") success_count = 0 fail_count = 0 running_count = 0 other = 0 for exe_id in stats['submitted_execution_ids']: if exe_id == 'CACHED': completion_stat = "SUCCEEDED" else: completion_stat = self.get_query_status(exe_id) if completion_stat == 'RUNNING': running_count += 1 elif completion_stat == 'SUCCEEDED': success_count += 1 elif completion_stat in ['FAILED', 'CANCELLED']: fail_count += 1 else: # for example: QUEUED other += 1 result: BatchQueryReportMap = {'submitted': len(stats['submitted_ids']), 'running': running_count, 'pending': len(stats['to_submit_ids']) + other, 'completed': success_count, 'failed': fail_count } return result
def get_batch_query_result(self, batch_id: int, *, combine: bool = True, no_block: bool = False)
-
Concatenates and returns the results of all the queries of a batch query
Args
batch_id
:int
- The batch_id for the batch_query
no_block
:bool
- Whether to wait until all queries have completed or raise immediately. If no_block = True and the batch hasn't completed, a QueryException is raised.
combine
- Whether to combine the individual query result into a single dataframe
Returns
The concatenated dataframe of the results of all the queries in a batch query.
Expand source code
@validate_arguments def get_batch_query_result(self, batch_id: int, *, combine: bool = True, no_block: bool = False): """ Concatenates and returns the results of all the queries of a batchquery Args: batch_id (int): The batch_id for the batch_query no_block (bool): Whether to wait until all queries have completed or return immediately. If you use no_block = true and the batch hasn't completed, it will throw BatchStillRunning exception. combine: Whether to combine the individual query result into a single dataframe Returns: The concatenated dataframe of the results of all the queries in a batch query. """ if no_block and self.did_batch_query_complete(batch_id) is False: raise QueryException('Batch query not completed yet.') self.wait_for_batch_query(batch_id) logger.info("Batch query completed. ") report = self.get_batch_query_report(batch_id) query_exe_ids = self._batch_query_status_map[batch_id]['submitted_execution_ids'] query_futures = self._batch_query_status_map[batch_id]['queries_futures'] if report['failed'] > 0: logger.warning(f"{report['failed']} queries failed. Redoing them") failed_ids, failed_queries = self.get_failed_queries(batch_id) new_batch_id = self.submit_batch_query(failed_queries) new_exe_ids = self._batch_query_status_map[new_batch_id]['submitted_execution_ids'] self.wait_for_batch_query(new_batch_id) new_exe_ids_map = {entry[0]: entry[1] for entry in zip(failed_ids, new_exe_ids)} new_report = self.get_batch_query_report(new_batch_id) if new_report['failed'] > 0: self.print_failed_query_errors(new_batch_id) raise QueryException("Queries failed again. Sorry!") logger.info("The queries succeeded this time. Gathering all the results.") # replace the old failed exe_ids with new successful exe_ids for indx, old_exe_id in enumerate(query_exe_ids): query_exe_ids[indx] = new_exe_ids_map.get(old_exe_id, old_exe_id) if len(query_exe_ids) == 0: raise ValueError("No query was submitted successfully") res_df_array: list[pd.DataFrame] = [] for index, exe_id in enumerate(query_exe_ids): df = query_futures[index].as_pandas().copy() if combine: if len(df) > 0: df['query_id'] = index logger.info(f"Got result from Query [{index}] ({exe_id})") self._log_execution_cost(exe_id) res_df_array.append(df) if not combine: return res_df_array logger.info("Concatenating the results.") # return res_df_array return pd.concat(res_df_array)
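A sketch of collecting batch results, assuming `bsq` as above and placeholder SQL. With combine=True each per-query frame gets a query_id column before concatenation; with combine=False the individual frames are returned.

```python
queries = [
    'SELECT COUNT(*) AS n FROM "my_run_baseline"',    # placeholder SQL
    'SELECT COUNT(*) AS n FROM "my_run_timeseries"',  # placeholder SQL
]
batch_id = bsq.submit_batch_query(queries)

# Blocks until every query finishes; failed queries are resubmitted once before giving up.
combined_df = bsq.get_batch_query_result(batch_id)             # one DataFrame with a query_id column
frames = bsq.get_batch_query_result(batch_id, combine=False)   # list of per-query DataFrames
```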
def get_cols(self, table: Union[sqlalchemy.sql.schema.Table, str, sqlalchemy.sql.selectable.Subquery], fuel_type=None) ‑> Sequence[Union[sqlalchemy.sql.elements.Label, sqlalchemy.sql.schema.Column]]
-
Returns the columns for a particular table.
Args
table
- Name of the table. One of 'baseline' or 'timeseries'
fuel_type
- Get only the columns for this fuel_type ('electricity', 'gas' etc)
Returns
A list of column objects for the requested table.
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True)) def get_cols(self, table: AnyTableType, fuel_type=None) -> Sequence[DBColType]: """ Returns the columns of for a particular table. Args: table: Name of the table. One of 'baseline' or 'timeseries' fuel_type: Get only the columns for this fuel_type ('electricity', 'gas' etc) Returns: A list of column names as a list of strings. """ table = self._get_table(table) if table == self.ts_table and self.ts_table is not None: cols = [c for c in self.ts_table.columns] if fuel_type: cols = [c for c in cols if c.name not in [self.ts_bldgid_column.name, self.timestamp_column.name]] cols = [c for c in cols if fuel_type in c.name] return cols elif table in ['baseline', 'bs']: cols = [c for c in self.bs_table.columns] if fuel_type: cols = [c for c in cols if 'simulation_output_report' in c.name] cols = [c for c in cols if fuel_type in c.name] return cols else: tbl = self._get_table(table) return [col for col in tbl.columns]
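A sketch of listing fuel-specific columns, following the docstring above; 'baseline' and 'electricity' are example values, and the returned objects are SQLAlchemy columns, so use .name for the string names.

```python
# Electricity-related columns in the baseline (annual results) table.
elec_cols = bsq.get_cols('baseline', fuel_type='electricity')
print([c.name for c in elec_cols])
```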
def get_failed_queries(self, batch_id: int) ‑> tuple[typing.Sequence[buildstock_query.query_core.ExeId], typing.Sequence[str]]
-
Returns the failed queries of a batch query along with their execution ids.
Args
batch_id
:int
- Batch query id returned by submit_batch_query
Returns
tuple[Sequence[ExeId], Sequence[str]]
- A tuple of the list of failed query execution ids and the list of the corresponding failed query strings.
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) def get_failed_queries(self, batch_id: int) -> tuple[Sequence[ExeId], Sequence[str]]: """_summary_ Args: batch_id (int): Batch query id returned by :py:sumbit_batch_query Returns: _type_: tuple of list of failed query execution ids and list of failed queries """ stats = self._batch_query_status_map.get(batch_id, None) failed_query_ids: list[ExeId] = [] failed_queries: list[str] = [] if stats: for i, exe_id in enumerate(stats['submitted_execution_ids']): completion_stat = self.get_query_status(exe_id) if completion_stat in ['FAILED', 'CANCELLED']: failed_query_ids.append(exe_id) failed_queries.append(stats['submitted_queries'][i]) return failed_query_ids, failed_queries
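A sketch of inspecting failures, assuming `batch_id` was returned earlier by submit_batch_query.

```python
# batch_id: id returned earlier by submit_batch_query(...)
failed_ids, failed_queries = bsq.get_failed_queries(batch_id)
for exe_id, sql in zip(failed_ids, failed_queries):
    print(exe_id, bsq.get_query_error(exe_id))

# Or print a full error report for the batch in one call:
bsq.print_failed_query_errors(batch_id)
```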
def get_ids_for_failed_queries(self, batch_id: int) ‑> Sequence[str]
-
Returns the list of execution ids for failed queries in batch query.
Args
batch_id
:int
- batch query id
Returns
Sequence[str]
- List of failed execution ids.
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) def get_ids_for_failed_queries(self, batch_id: int) -> Sequence[str]: """Returns the list of execution ids for failed queries in batch query. Args: batch_id (int): batch query id Returns: Sequence[str]: List of failed execution ids. """ failed_ids = [] for i, exe_id in enumerate(self._batch_query_status_map[batch_id]['submitted_execution_ids']): completion_stat = self.get_query_status(exe_id) if completion_stat in ['FAILED', 'CANCELLED']: failed_ids.append(exe_id) return failed_ids
def get_query_error(self, query_id: buildstock_query.query_core.ExeId) ‑> str
-
Returns the error message if the query has failed.
Args
query_id
:str
- Query execution id.
Returns
str
- Error message for the query.
Expand source code
@validate_arguments
def get_query_error(self, query_id: ExeId) -> str:
    """Returns the error message if query has failed.
    Args:
        query_id (str): Query execution id.
    Returns:
        str: Error message for the query.
    """
    stat = self._aws_athena.get_query_execution(QueryExecutionId=query_id)
    return stat['QueryExecution']['Status']['StateChangeReason']
def get_query_output_location(self, query_id: buildstock_query.query_core.ExeId) ‑> str
-
Get query output location in s3.
Args
query_id
:str
- Query execution id.
Returns
str
- The query location in s3.
Expand source code
@validate_arguments
def get_query_output_location(self, query_id: ExeId) -> str:
    """Get query output location in s3.
    Args:
        query_id (str): Query execution id.
    Returns:
        str: The query location in s3.
    """
    stat = self._aws_athena.get_query_execution(QueryExecutionId=query_id)
    output_path = stat['QueryExecution']['ResultConfiguration']['OutputLocation']
    return output_path
def get_query_status(self, query_id: buildstock_query.query_core.ExeId) ‑> str
-
Get status of the query
Args
query_id
:str
- Query execution id
Returns
str
- Status of the query.
Expand source code
@validate_arguments
def get_query_status(self, query_id: ExeId) -> str:
    """Get status of the query
    Args:
        query_id (str): Query execution id
    Returns:
        str: Status of the query.
    """
    stat = self._aws_athena.get_query_execution(QueryExecutionId=query_id)
    return stat['QueryExecution']['Status']['State']
def get_result_from_s3(self, query_execution_id: buildstock_query.query_core.ExeId) ‑> pandas.core.frame.DataFrame
-
Returns query result from s3 location.
Args
query_execution_id
:str
- The query execution ID
Raises
QueryException
- If query had failed.
Returns
pd.DataFrame
- The query result.
Expand source code
@validate_arguments def get_result_from_s3(self, query_execution_id: ExeId) -> pd.DataFrame: """Returns query result from s3 location. Args: query_execution_id (str): The query execution ID Raises: QueryException: If query had failed. Returns: pd.DataFrame: The query result. """ query_status = self.get_query_status(query_execution_id) if query_status == 'SUCCEEDED': path = self.get_query_output_location(query_execution_id) bucket = path.split('/')[2] key = '/'.join(path.split('/')[3:]) response = self._aws_s3.get_object(Bucket=bucket, Key=key) df = read_csv(response['Body']) return df # If failed, return error message elif query_status == 'FAILED': raise QueryException(self.get_query_error(query_execution_id)) elif query_status in ['RUNNING', 'QUEUED', 'PENDING']: raise QueryException(f"Query still {query_status}") else: raise QueryException(f"Query has unkown status {query_status}")
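A sketch of retrieving a completed query's output directly from S3 when only the execution id is known, assuming `bsq` as above.

```python
exe_id = bsq.execute_raw("SELECT 1", run_async=True)
# ... later, once the query has finished:
if bsq.get_query_status(exe_id) == "SUCCEEDED":
    print(bsq.get_query_output_location(exe_id))  # s3://... path of the raw CSV output
    df = bsq.get_result_from_s3(exe_id)
```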
def load_cache(self, path: Optional[str] = None)
-
Reads a query cache from a pickle file and merges it into the in-memory cache.
Args
path
:str, optional
- The path to the pickle file. If not provided, reads from the current directory.
Expand source code
@validate_arguments def load_cache(self, path: Optional[str] = None): """Read and update query cache from pickle file. Args: path (str, optional): The path to the pickle file. If not provided, reads from current directory. """ pickle_path = pathlib.Path(path) if path else self._get_cache_file_path() before_count = len(self._query_cache) saved_cache = load_pickle(pickle_path) logger.info(f"{len(saved_cache)} queries cache read from {path}.") self._query_cache.update(saved_cache) self.last_saved_queries = set(saved_cache) after_count = len(self._query_cache) if diff := after_count - before_count: logger.info(f"{diff} queries cache is updated.") else: logger.info("Cache already upto date.")
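A sketch of restoring a previously saved cache at the start of a session; the pickle path is a placeholder.

```python
# Load cached results saved in an earlier session so repeated queries are
# answered locally instead of being re-run on Athena.
bsq.load_cache("my_run_queries.pkl")  # or bsq.load_cache() for the default cache file
```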
def print_all_batch_query_status(self) ‑> None
-
Prints the status of all batch queries.
Expand source code
def print_all_batch_query_status(self) -> None:
    """Prints the status of all batch queries.
    """
    for count in self._batch_query_status_map.keys():
        print(f'Query {count}: {self.get_batch_query_report(count)}\n')
def print_failed_query_errors(self, batch_id: int) ‑> None
-
Print the error messages for all queries that failed in batch query.
Args
batch_id
:int
- Batch query id
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) def print_failed_query_errors(self, batch_id: int) -> None: """Print the error messages for all queries that failed in batch query. Args: batch_id (int): Batch query id """ failed_ids, failed_queries = self.get_failed_queries(batch_id) for exe_id, query in zip(failed_ids, failed_queries): print(f"Query id: {exe_id}. \n Query string: {query}. Query Ended with: {self.get_query_status(exe_id)}" f"\nError: {self.get_query_error(exe_id)}\n")
def save_cache(self, path: Optional[str] = None, trim_excess: bool = False)
-
Saves the query cache to a pickle file. It is a good idea to run this after making queries so that in the next session these queries won't have to be re-run on Athena and can be loaded directly from the file.
Args
path
:str, optional
- The path to the pickle file. If not provided, the file will be saved in the current directory.
trim_excess
:bool, optional
- If true, any queries in the cache that were not run in the current session will be removed before saving to the file. This is useful if the cache has accumulated a bunch of stray queries over several sessions that are no longer used. Defaults to False.
Expand source code
@validate_arguments def save_cache(self, path: Optional[str] = None, trim_excess: bool = False): """Saves queries cache to a pickle file. It is good idea to run this afer making queries so that on the next session these queries won't have to be run on Athena and can be directly loaded from the file. Args: path (str, optional): The path to the pickle file. If not provided, the file will be saved on the current directory. trim_excess (bool, optional): If true, any queries in the cache that is not run in current session will be remved before saving it to file. This is useful if the cache has accumulated a bunch of stray queries over several sessions that are no longer used. Defaults to False. """ cached_queries = set(self._query_cache) if self.last_saved_queries == cached_queries: logger.info("No new queries to save.") return pickle_path = pathlib.Path(path) if path else self._get_cache_file_path() if trim_excess: if excess_queries := [key for key in self._query_cache if key not in self._session_queries]: for query in excess_queries: del self._query_cache[query] logger.info(f"{len(excess_queries)} excess queries removed from cache.") self.last_saved_queries = cached_queries save_pickle(pickle_path, self._query_cache) logger.info(f"{len(self._query_cache)} queries cache saved to {pickle_path}")
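A sketch of persisting the cache at the end of a session; the path is a placeholder.

```python
# Save everything queried so far; trim_excess drops cached entries that were
# not used in the current session before writing the file.
bsq.save_cache("my_run_queries.pkl", trim_excess=True)
```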
def stop_all_queries(self) ‑> None
-
Stops all queries that are running in Athena for this instance.
Returns
Nothing
Expand source code
def stop_all_queries(self) -> None:
    """
    Stops all queries that are running in Athena for this instance.
    Returns:
        Nothing
    """
    for count, stat in self._batch_query_status_map.items():
        stat['to_submit_ids'].clear()
    running_ids = self.get_all_running_queries()
    for i in running_ids:
        self.stop_query(execution_id=i)
    logger.info(f"Stopped {len(running_ids)} queries")
def stop_batch_query(self, batch_id: int) ‑> None
-
Stops all the queries running under a batch query
Args
batch_id
- The batch_id of the batch_query, as returned by submit_batch_query
Returns
None
Expand source code
@validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True))
def stop_batch_query(self, batch_id: int) -> None:
    """
    Stops all the queries running under a batch query
    Args:
        batch_id: The batch_id of the batch_query. Returned by submit_batch_query
    Returns:
        None
    """
    if batch_id not in self._batch_query_status_map:
        raise ValueError("Batch id not found")
    self._batch_query_status_map[batch_id]['to_submit_ids'].clear()
    for exec_id in self._batch_query_status_map[batch_id]['submitted_execution_ids']:
        self.stop_query(exec_id)
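A sketch of cancelling work, assuming `batch_id` was returned by submit_batch_query.

```python
# batch_id: id returned earlier by submit_batch_query(...)
# Cancel one batch: clears its pending queue and stops its submitted queries.
bsq.stop_batch_query(batch_id)

# Or stop every query this instance currently has running in Athena.
bsq.stop_all_queries()
```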
def stop_query(self, execution_id: buildstock_query.query_core.ExeId) ‑> str
-
Stops a running query.
Args
execution_id
- The execution id of the query being run.
Returns:
Expand source code
@validate_arguments
def stop_query(self, execution_id: ExeId) -> str:
    """
    Stops a running query.
    Args:
        execution_id: The execution id of the query being run.
    Returns:
    """
    return self._aws_athena.stop_query_execution(QueryExecutionId=execution_id)
def submit_batch_query(self, queries: Sequence[str])
-
Submit multiple related queries
Args
queries
- List of queries to submit. Setting the get_query_only flag while making calls to aggregation functions is the easiest way to obtain queries.
Returns
An integer representing the batch_query id. The id can be used with other batch_query functions.
Expand source code
@validate_arguments def submit_batch_query(self, queries: Sequence[str]): """ Submit multiple related queries Args: queries: List of queries to submit. Setting `get_query_only` flag while making calls to aggregation functions is easiest way to obtain queries. Returns: An integer representing the batch_query id. The id can be used with other batch_query functions. """ queries = list(queries) to_submit_ids = list(range(len(queries))) id_list = list(to_submit_ids) # make a copy submitted_ids: list[int] = [] submitted_execution_ids: list[ExeId] = [] submitted_queries: list[str] = [] queries_futures: list = [] self._batch_query_id += 1 batch_query_id = self._batch_query_id self._batch_query_status_map[batch_query_id] = {'to_submit_ids': to_submit_ids, 'all_ids': list(id_list), 'submitted_ids': submitted_ids, 'submitted_execution_ids': submitted_execution_ids, 'submitted_queries': submitted_queries, 'queries_futures': queries_futures } def run_queries(): while to_submit_ids: current_id = to_submit_ids[0] # get the first one current_query = queries[0] try: execution_id, future = self.execute(current_query, run_async=True) logger.info(f"Submitted queries[{current_id}] ({execution_id})") to_submit_ids.pop(0) # if query queued successfully, remove it from the list queries.pop(0) submitted_ids.append(current_id) submitted_execution_ids.append(ExeId(execution_id)) submitted_queries.append(current_query) queries_futures.append(future) except ClientError as e: if e.response['Error']['Code'] == 'TooManyRequestsException': logger.info("Athena complained about too many requests. Waiting for a minute.") time.sleep(60) # wait for a minute before submitting another query elif e.response['Error']['Code'] == 'InvalidRequestException': logger.info(f"Queries[{current_id}] is Invalid: {e.response['Message']} \n {current_query}") to_submit_ids.pop(0) # query failed, so remove it from the list queries.pop(0) raise else: raise query_runner = Thread(target=run_queries) query_runner.start() return batch_query_id
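An end-to-end sketch of the batch workflow, assuming `bsq` as above. The SQL strings are placeholders; in practice they typically come from aggregation functions called with the get_query_only flag.

```python
import time

queries = [f"SELECT '{fuel}' AS fuel, COUNT(*) AS n FROM my_run_baseline"
           for fuel in ("electricity", "natural_gas")]  # placeholder SQL
batch_id = bsq.submit_batch_query(queries)

# Queries are submitted from a background thread; poll the report while they run.
while not bsq.did_batch_query_complete(batch_id):
    print(bsq.get_batch_query_report(batch_id))  # submitted/running/pending/completed/failed counts
    time.sleep(10)

result_df = bsq.get_batch_query_result(batch_id)
```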
def wait_for_batch_query(self, batch_id: int)
-
Waits until batch query completes.
Args
batch_id
:int
- The batch query id.
Expand source code
@validate_arguments def wait_for_batch_query(self, batch_id: int): """Waits until batch query completes. Args: batch_id (int): The batch query id. """ sleep_time = 0.5 # start here and keep doubling until max_sleep_time max_sleep_time = 20 while True: last_time = time.time() last_report = None report = self.get_batch_query_report(batch_id) if time.time() - last_time > 60 or last_report is None or report != last_report: logger.info(report) last_report = report last_time = time.time() if report['pending'] == 0 and report['running'] == 0: break time.sleep(sleep_time) sleep_time = min(sleep_time * 2, max_sleep_time)
class QueryException (*args, **kwargs)
-
Common base class for all non-exit exceptions.
Expand source code
class QueryException(Exception): pass
Ancestors
- builtins.Exception
- builtins.BaseException