# dbtk/etl/transforms/database.py
"""
Database table lookup and validation using PreparedStatement.
"""
import logging
from typing import Any, Union, List, Optional, Callable
from ...utils import quote_identifier, validate_identifier
from ...cursors import PreparedStatement
logger = logging.getLogger(__name__)
class TableLookup:
    """
    Database table lookup with configurable caching for ETL transformations.

    Performs lookups against database tables or views using PreparedStatement for
    efficient repeated queries. Supports three caching strategies:

    - CACHE_NONE (0): No caching, always query database
    - CACHE_LAZY (1): Cache results as encountered (default)
    - CACHE_PRELOAD (2): Preload entire table into memory upfront

    Can operate in three modes:

    - Validator: No return_cols specified, returns bool indicating existence
    - Single lookup: One return column, returns scalar value
    - Multi lookup: Multiple return columns, returns row object (type depends on cursor)

    Examples
    --------
    Validator mode::

        temple_validator = TableLookup(cursor, 'temples', key_cols='temple_id')
        is_valid = temple_validator({'temple_id': 'eastern_air_temple'})  # True/False

    Single value lookup::

        state_lookup = TableLookup(cursor, 'states',
                                   key_cols='name',
                                   return_cols='abbreviation',
                                   cache=TableLookup.CACHE_PRELOAD)  # Small table, preload it
        abbrev = state_lookup({'name': 'California'})  # 'CA'

    Multi-value lookup::

        address_lookup = TableLookup(cursor, 'addresses',
                                     key_cols='address_id',
                                     return_cols=['street', 'city', 'state', 'zip'])
        address = address_lookup({'address_id': 123})  # Record/dict/namedtuple

    Multi-key lookup::

        person_lookup = TableLookup(cursor, 'people',
                                    key_cols=['first_name', 'last_name'],
                                    return_cols='person_id',
                                    cache=TableLookup.CACHE_NONE)  # Large table, don't cache
        person_id = person_lookup({'first_name': 'Aang', 'last_name': 'Avatar'})

    Use in Table config::

        table = Table('citizens', {
            'state_abbrev': {'field': 'state_name', 'fn': state_lookup}
        }, cursor=cursor)
    """

    # Cache strategy constants
    CACHE_NONE = 0     # No caching
    CACHE_LAZY = 1     # Cache as encountered (default)
    CACHE_PRELOAD = 2  # Preload entire table

    def __init__(self,
                 cursor,
                 table: str,
                 key_cols: Union[str, List[str]],
                 return_cols: Optional[Union[str, List[str]]] = None,
                 cache: int = CACHE_LAZY):
        """
        Initialize TableLookup with table schema and caching strategy.

        Parameters
        ----------
        cursor : Cursor
            Database cursor for executing queries
        table : str
            Table or view name to query
        key_cols : str or list of str
            Column name(s) to use in WHERE clause. Cannot be empty.
        return_cols : str, list of str, or None
            Column(s) to return. If None, operates as validator (returns bool).
            An empty list is rejected because it cannot form a SELECT clause.
        cache : int, default 1 (CACHE_LAZY)
            Caching strategy:

            - 0 (CACHE_NONE): No caching, always query database
            - 1 (CACHE_LAZY): Cache results as encountered
            - 2 (CACHE_PRELOAD): Preload entire table into memory

        Raises
        ------
        ValueError
            If key_cols is empty, return_cols is an empty collection, or the
            cache value is invalid
        """
        # Cheap structural checks run first so a bad call fails before any
        # SQL is built or any query is executed.
        if cache not in (self.CACHE_NONE, self.CACHE_LAZY, self.CACHE_PRELOAD):
            raise ValueError(
                f"Invalid cache value: {cache}. "
                f"Use TableLookup.CACHE_NONE (0), TableLookup.CACHE_LAZY (1), or TableLookup.CACHE_PRELOAD (2)"
            )

        key_cols = self._as_column_list(key_cols)
        if not key_cols:
            raise ValueError("key_cols cannot be empty")

        if return_cols is not None:
            return_cols = self._as_column_list(return_cols)
            if not return_cols:
                # An empty list would otherwise yield "SELECT  FROM ...".
                raise ValueError("return_cols cannot be empty; pass None for validator mode")

        self._cursor = cursor

        # Validate and quote identifiers
        validate_identifier(table)
        self._table = quote_identifier(table)

        for col in key_cols:
            validate_identifier(col)
        self._key_cols = [quote_identifier(col) for col in key_cols]
        self._key_col_names = key_cols  # Unquoted names double as bind parameter names

        if return_cols is None:
            self._return_cols = None
            self._validator_mode = True
        else:
            for col in return_cols:
                validate_identifier(col)
            self._return_cols = [quote_identifier(col) for col in return_cols]
            self._validator_mode = False

        self._cache_strategy = cache
        self._cache = {}
        self._preloaded = False
        self.exhaustive: bool = (cache == self.CACHE_PRELOAD)

        # Build SQL
        self._build_sql()

        # CACHE_NONE and CACHE_LAZY need no further setup; only preload does work here.
        if cache == self.CACHE_PRELOAD:
            # Preload entire table - user's responsibility to know table size
            self._preload_all()
            self._preloaded = True

    @staticmethod
    def _as_column_list(cols) -> list:
        """
        Return ``cols`` as a new list of column names.

        A single string becomes a one-element list. Any other iterable is
        copied, so one-shot iterables can be traversed more than once and
        later changes to the caller's list do not affect this instance.
        Falsy input (``None``, empty collection) becomes an empty list.
        """
        if isinstance(cols, str):
            return [cols]
        return list(cols) if cols else []

    @property
    def cursor(self):
        """The cursor this lookup was created with."""
        return self._cursor

    def _build_sql(self):
        """Build the SELECT and PreparedStatement."""
        if self._validator_mode:
            # For validation, just select one of the key columns
            select_clause = self._key_cols[0]
        else:
            select_clause = ', '.join(self._return_cols)

        # WHERE clause with one named parameter per key column
        where_clause = ' AND '.join(
            f"{col} = :{param_name}"
            for col, param_name in zip(self._key_cols, self._key_col_names)
        )

        query = f"SELECT {select_clause} FROM {self._table} WHERE {where_clause}"
        logger.debug("TableLookup query: %s", query)

        self._stmt = PreparedStatement(self._cursor, query=query)

    def _preload_all(self):
        """
        Preload entire lookup table into cache.

        Rows whose key contains ``None`` or ``''`` are skipped. In validator
        mode each key maps to ``True``. Otherwise the cached value is the
        scalar (one return column), the fetched row itself (when every key
        column is also a return column), or a plain tuple of the return
        columns (when extra key columns had to be selected).
        """
        if self._validator_mode:
            select_cols = ', '.join(self._key_cols)
            query = f"SELECT {select_cols} FROM {self._table}"
            self._cursor.execute(query)
            num_keys = len(self._key_cols)
            for row in self._cursor.fetchall():
                key_values = row[:num_keys]
                if any(k in (None, '') for k in key_values):
                    continue
                self._cache[tuple(key_values)] = True
            return

        # Prepend only key cols not already in return_cols to avoid duplicate columns.
        extra_key_cols = [col for col in self._key_cols if col not in self._return_cols]
        all_cols = extra_key_cols + self._return_cols
        query = f"SELECT {', '.join(all_cols)} FROM {self._table}"
        self._cursor.execute(query)

        # Column positions do not change between rows, so resolve them once.
        key_positions = [all_cols.index(col) for col in self._key_cols]
        first_return_pos = len(extra_key_cols)
        single_return = len(self._return_cols) == 1

        for row in self._cursor.fetchall():
            key_values = tuple(row[pos] for pos in key_positions)
            if any(k in (None, '') for k in key_values):
                continue
            if single_return:
                self._cache[key_values] = row[first_return_pos]
            elif not extra_key_cols:
                self._cache[key_values] = row  # row IS the return_cols Record
            else:
                self._cache[key_values] = tuple(row[first_return_pos:])

    def _make_cache_key(self, bind_vars: dict) -> tuple:
        """Create cache key from bind variables."""
        return tuple(bind_vars[name] for name in self._key_col_names)

    def _lookup(self, bind_vars: dict) -> Any:
        """
        Perform database lookup.

        Returns ``None`` when no row is fetched, ``True`` in validator mode,
        the first column for a single return column, otherwise the whole row.
        """
        self._stmt.execute(bind_vars)
        row = self._stmt.fetchone()
        if not row:
            return None
        if self._validator_mode:
            return True
        if len(self._return_cols) == 1:
            # Single column - return scalar at position 0
            return row[0]
        # Multiple columns - return whole row
        return row

    def __call__(self, bind_vars: dict) -> Any:
        """
        Perform lookup with given key values.

        Parameters
        ----------
        bind_vars : dict
            Dictionary with key column names and their values

        Returns
        -------
        bool
            If in validator mode (no return_cols)
        scalar
            If single return column
        row object
            If multiple return columns (type depends on cursor)
        None
            If no match found

        Raises
        ------
        ValueError
            If required key columns are missing from bind_vars
        """
        missing_keys = [k for k in self._key_col_names if k not in bind_vars]
        if missing_keys:
            raise ValueError(f"TableLookup for '{self._table}' missing required keys: {missing_keys}.")

        not_found = False if self._validator_mode else None

        # None/empty key values are a data quality issue, not an error
        if any(bind_vars[name] in (None, '') for name in self._key_col_names):
            return not_found

        # No caching - always query
        if self._cache_strategy == self.CACHE_NONE:
            result = self._lookup(bind_vars)
            return not_found if result is None else result

        # With caching - check cache first
        cache_key = self._make_cache_key(bind_vars)
        if cache_key in self._cache:
            return self._cache[cache_key]

        # Not in cache - query the database (this also happens after a preload)
        result = self._lookup(bind_vars)
        if result is not None:
            self._cache[cache_key] = result
            return result

        # Misses are remembered only when the cache was not preloaded
        if not self._preloaded:
            self._cache[cache_key] = not_found
        return not_found
def Lookup(table: str,
           key_cols: Union[str, List[str]],
           return_cols: Union[str, List[str]],
           *,
           cache: int = TableLookup.CACHE_LAZY,
           missing: Any = None) -> Callable[[Any], Any]:
    """
    One-liner database lookup for Table column configs.

    Parameters
    ----------
    table : str
        Table or view name to query
    key_cols : str or list of str
        Column name(s) matched against the incoming value
    return_cols : str or list of str
        Column(s) whose value is returned by the transform
    cache : int, default TableLookup.CACHE_LAZY
        Caching strategy, one of the ``TableLookup.CACHE_*`` constants
    missing : any, optional
        Value substituted when the input is None or the lookup finds nothing

    Returns
    -------
    callable
        A deferred transform; it must be bound to a cursor before it is called
    """
    deferred = _DeferredTransform.create_lookup(
        table, key_cols, return_cols, cache=cache, missing=missing
    )
    return deferred
def Validate(table: str,
             key_cols: Union[str, List[str]],
             *,
             cache: int = TableLookup.CACHE_LAZY,
             on_fail: str = 'warn') -> Callable[[Any], Any]:
    """
    One-liner validation — returns original value if key exists in table.

    Parameters
    ----------
    table : str
        Table or view name to check against
    key_cols : str or list of str
        Column name(s) matched against the incoming value
    cache : int, default TableLookup.CACHE_LAZY
        Caching strategy, one of the ``TableLookup.CACHE_*`` constants
    on_fail : str, default 'warn'
        ``'raise'`` raises ValueError on a failed validation, ``'warn'`` logs
        a warning; any other value lets the failure pass silently

    Returns
    -------
    callable
        A deferred transform; it must be bound to a cursor before it is called
    """
    deferred = _DeferredTransform.create_validator(
        table, key_cols, cache=cache, on_fail=on_fail
    )
    return deferred
# ——————————————————— Internal Implementation ———————————————————
class _DeferredTransform:
    """
    Lookup/validation transform whose TableLookup is created on ``bind``.

    Holds the positional and keyword arguments for ``TableLookup`` plus
    wrapper-only options (``missing``, ``on_fail``, ``validator``) until a
    cursor is supplied. Calling the instance before ``bind`` raises
    RuntimeError.
    """
    __slots__ = ('_args', '_kwargs', '_extra', '_bound_fn')

    def __init__(self, args, kwargs, extra=None):
        self._args = args            # positional args for TableLookup (after cursor)
        self._kwargs = kwargs        # keyword args for TableLookup
        self._extra = extra or {}    # options consumed by the wrapper, not TableLookup
        self._bound_fn = None        # set by bind()

    @classmethod
    def create_lookup(cls, table, key_cols, return_cols, *, cache=TableLookup.CACHE_LAZY, missing=None):
        """Create a deferred lookup returning ``return_cols`` (or ``missing``)."""
        # Only pass valid TableLookup args
        return cls(
            args=(table, key_cols, return_cols),
            kwargs={'cache': cache},     # only this goes to TableLookup
            extra={'missing': missing}   # wrapper-specific option
        )

    @classmethod
    def create_validator(cls, table, key_cols, *, cache=TableLookup.CACHE_LAZY, on_fail='warn'):
        """Create a deferred validator that passes the input value through."""
        return cls(
            args=(table, key_cols),
            kwargs={'cache': cache},
            extra={'on_fail': on_fail, 'validator': True}
        )

    @staticmethod
    def _parse_cache(cache_str: str) -> int:
        """Convert a cache alias from a spec string to a TableLookup constant."""
        cache_lower = cache_str.lower()
        if cache_lower in ('0', 'no_cache', 'none'):
            return TableLookup.CACHE_NONE
        if cache_lower in ('1', 'lazy'):
            return TableLookup.CACHE_LAZY
        if cache_lower in ('2', 'pre_cache', 'preload', 'precache'):
            return TableLookup.CACHE_PRELOAD
        raise ValueError(
            f"Invalid cache value: '{cache_str}'. "
            f"Use 0/'no_cache', 1/'lazy', or 2/'pre_cache'"
        )

    @staticmethod
    def _split_columns(text: str):
        """
        Parse a column field from a spec string.

        Returns the string unchanged when it holds a single name, otherwise
        a list of names with surrounding whitespace removed so that
        ``'a, b'`` and ``'a,b'`` are equivalent.
        """
        if ',' not in text:
            return text
        return [name.strip() for name in text.split(',')]

    @classmethod
    def from_string(cls, spec: str):
        """
        Parse string shorthand for Lookup/Validate transforms.

        Formats:
            'lookup:table:key_col:return_col[:cache]'
            'validate:table:key_col[:cache]'

        Multiple columns are comma-separated within a field.

        Cache can be (case-insensitive):
            - 0, 'no_cache' or 'none' → CACHE_NONE
            - 1 or 'lazy' → CACHE_LAZY (default)
            - 2, 'pre_cache', 'preload' or 'precache' → CACHE_PRELOAD

        Example
        -------
        ::

            'lookup:states:code:state'
            'lookup:states:code:state:2'
            'lookup:states:code:state:preload'
            'lookup:customers:person_id,store_id:customer_id'
            'validate:regions:name'
            'validate:regions:name:no_cache'

        Raises
        ------
        ValueError
            If the spec has the wrong number of parts, an unknown transform
            type, or an unrecognized cache value
        """
        parts = [p.strip() for p in spec.split(':')]
        if len(parts) < 2 or len(parts) > 5:
            raise ValueError(f"Invalid transform spec: '{spec}'")

        transform_type = parts[0].lower()

        if transform_type == 'lookup':
            if len(parts) < 4:
                raise ValueError(
                    f"Invalid lookup spec: '{spec}'. "
                    "Expected 'lookup:table:key_col:return_col[:cache]'"
                )
            table = parts[1]
            key_col = cls._split_columns(parts[2])
            return_col = cls._split_columns(parts[3])
            cache = cls._parse_cache(parts[4]) if len(parts) > 4 else TableLookup.CACHE_LAZY
            return cls.create_lookup(table, key_col, return_col, cache=cache)

        if transform_type == 'validate':
            # A validate spec has at most four parts; a fifth would be ignored silently.
            if len(parts) < 3 or len(parts) > 4:
                raise ValueError(
                    f"Invalid validate spec: '{spec}'. "
                    "Expected 'validate:table:key_col[:cache]'"
                )
            table = parts[1]
            key_col = cls._split_columns(parts[2])
            cache = cls._parse_cache(parts[3]) if len(parts) > 3 else TableLookup.CACHE_LAZY
            return cls.create_validator(table, key_col, cache=cache)

        raise ValueError(
            f"Unknown transform type: '{transform_type}'. "
            "Must be 'lookup' or 'validate'"
        )

    def bind(self, cursor) -> Callable[[Any], Any]:
        """
        Create the TableLookup on ``cursor`` and return the transform function.

        The function is built once; later calls return the same function
        regardless of the cursor passed.
        """
        if self._bound_fn is not None:
            return self._bound_fn

        is_validator = self._extra.get('validator', False)
        lookup = TableLookup(cursor, *self._args, **self._kwargs)
        key_cols = self._args[1]

        if is_validator:
            on_fail = self._extra.get('on_fail', 'warn')

            def validate(value):
                # None is passed through without consulting the table
                if value is None:
                    return value
                bind_vars = _make_bind_vars(key_cols, value)
                if not lookup(bind_vars):
                    msg = f"Validation failed: {bind_vars} not found in table {self._args[0]}"
                    if on_fail == 'raise':
                        raise ValueError(msg)
                    elif on_fail == 'warn':
                        logger.warning(msg)
                return value

            self._bound_fn = validate
        else:
            missing = self._extra.get('missing')

            def transform(value):
                if value is None:
                    return missing
                bind_vars = _make_bind_vars(key_cols, value)
                result = lookup(bind_vars)
                return result if result is not None else missing

            self._bound_fn = transform

        return self._bound_fn

    def __call__(self, value):
        """Apply the bound transform; raises RuntimeError if ``bind`` was never called."""
        if self._bound_fn is None:
            raise RuntimeError("Lookup/Validate used before Table cursor was bound")
        return self._bound_fn(value)
class QueryLookup:
    """
    Deferred PreparedStatement lookup for use as a Table column transform.

    Accepts a SQL query string or file path. Cursor binding is deferred until
    the Table is initialized with a cursor. Use with ``field='*'`` to pass the
    full source row as bind variables — PreparedStatement ignores extra keys.

    Parameters
    ----------
    query : str, optional
        Inline SQL string.
    filename : str or Path, optional
        Path to a SQL file.
    return_col : str, optional
        Column name to extract from the result row.
        By default, the first column will be returned. Set to '*' if you want to return
        multiple values to the next stage of the pipeline.
    missing : any, optional
        Value to return when the query returns no rows. Default ``None``.

    Example
    -------
    ::

        id_lookup_query = '''
            SELECT p.id, p.last_name, p.first_name
            FROM people p
            LEFT JOIN employees e ON e.id = p.id
            WHERE (p.email = :email
                   OR e.tax_id = :tax_id)
        '''
        emp_cols = {
            'id': {'field': '*', 'primary_key': True, 'fn': dbtk.etl.QueryLookup(query=id_lookup_query)},
            'first_name': {'field': 'first_name', 'required': True},
            ...
        }
    """

    def __init__(self, query=None, filename=None, return_col=None, missing=None):
        """Store the query source and result options; no database access happens here."""
        # At least one SQL source has to be supplied.
        if all(source is None for source in (query, filename)):
            raise ValueError("QueryLookup requires either 'query' or 'filename'")
        self.missing = missing
        self.return_col = return_col
        self.filename = filename
        self.query = query

    def bind(self, cursor):
        """
        Prepare the statement on ``cursor`` and return the lookup function.

        The returned function builds bind variables from its argument, runs
        the statement and returns ``missing`` when no row comes back, the
        whole row for ``return_col='*'``, the named column when
        ``return_col`` is set, and the first column otherwise.
        """
        prepared = PreparedStatement(cursor, query=self.query, filename=self.filename)
        wanted = self.return_col
        default = self.missing

        def fn(value):
            params = _make_bind_vars(prepared.param_names, value)
            prepared.execute(params)
            record = prepared.fetchone()
            if record is None:
                return default
            if wanted is None:
                return record[0]
            return record if wanted == '*' else record[wanted]

        return fn
def _make_bind_vars(key_cols_spec: Union[str, List[str]], value: Any) -> dict:
if isinstance(key_cols_spec, str):
return {key_cols_spec: value}
# key_cols_spec is list/tuple
if hasattr(value, 'items'):
return value
if hasattr(value, '__iter__') and not isinstance(value, str):
return dict(zip(key_cols_spec, value))
# Fallback: single value, multiple keys → use first key
return {key_cols_spec[0]: value}