# Source code for dbtk.etl.managers

# dbtk/etl/managers.py

"""
Orchestration tools for multi-stage, resumable ETL processes.

IdentityManager provides lightweight, incremental identity resolution
for imports where a reliable source primary key exists in source data.
"""

import datetime as dt
import json
import logging
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Union

from ..cursors import PreparedStatement
from .transforms.database import TableLookup
from ..record import Record
from ..utils import ErrorDetail

# Module logger, named after this module so it sits in the package's logger
# hierarchy and picks up configuration applied to its parent loggers.
logger = logging.getLogger(name=__name__)


class _EntityStatusMeta(type):
    """
    Metaclass that makes the :class:`EntityStatus` class itself iterable.

    ``iter(X)`` looks ``__iter__`` up on ``type(X)``. For a class, that is its
    metaclass, so an ``__iter__`` defined only in the class body makes
    *instances* iterable but not the class. Defining it here lets callers write
    ``for status in EntityStatus`` and ``value in EntityStatus``.
    """

    def __iter__(cls):
        return iter(cls.VALUES)


class EntityStatus(metaclass=_EntityStatusMeta):
    """
    Status constants for the entity resolution lifecycle.

    Iterating the class (``list(EntityStatus)``) yields every status value in
    :attr:`VALUES` order.

    Attributes
    ----------
    PENDING : str
        Entity has been registered but resolution has not yet been attempted.
    RESOLVED : str
        Entity was successfully matched; ``target_key`` is populated.
    STAGED : str
        Entity exists in a staging table but has not yet been matched to the
        target system (e.g. an ERP record not yet confirmed).
    ERROR : str
        An error occurred while creating or updating the entity. Possibly
        impacting downstream processing.
    SKIPPED : str
        Resolution was intentionally bypassed for this entity.
    NOT_FOUND : str
        Resolution was attempted but no matching record was found in the target.
    VALUES : tuple of str
        All of the above values, in declaration order.
    """

    PENDING = "pending"
    RESOLVED = "resolved"
    STAGED = "staged"
    ERROR = "error"
    SKIPPED = "skipped"
    NOT_FOUND = "not_found"

    VALUES = (
        PENDING,
        RESOLVED,
        STAGED,
        ERROR,
        SKIPPED,
        NOT_FOUND,
    )

    @classmethod
    def __iter__(cls):
        # Kept so instances (``iter(EntityStatus())``) stay iterable as before;
        # class-level iteration is provided by the metaclass.
        yield from cls.VALUES
class IdentityManager:
    """
    Lightweight, resumable identity-resolution cache for ETL imports.

    Maps source-system primary keys to target-system identifiers using a SQL
    resolver query. Resolved entities are stored as :class:`dbtk.record.Record`
    instances keyed by ``source_key`` and enriched with status, messages,
    errors, and any configured ``alternate_keys``.

    State can be persisted to JSON between runs with :meth:`save_state` and
    restored with :meth:`load_state`, allowing long-running and multi-stage
    imports to be resumed without re-querying already-resolved entities.

    Parameters
    ----------
    source_key : str
        Field name for the source-system primary key (e.g. ``'student_id'``).
    target_key : str
        Field name for the target-system primary key that the resolver
        returns (e.g. ``'erp_person_id'``).
    resolver : PreparedStatement or TableLookup, optional
        Query used to look up a ``target_key`` from a ``source_key``. Can be
        set or replaced later via the ``resolver`` property.
    alternate_keys : list of str, optional
        Additional key fields to track per entity
        (e.g. ``['staging_id', 'erp_vendor_id']``). These are persisted
        alongside ``target_key`` in saved state.

    Attributes
    ----------
    entities : dict
        Mapping of source_key value → resolved :class:`dbtk.record.Record`.
        Each Record contains all resolver columns plus ``_status``,
        ``_errors``, ``_messages``, and any ``alternate_keys``.

    Example
    -------
    ::

        stmt = cursor.prepare_file('sql/resolve_student.sql')
        im = IdentityManager('student_id', 'erp_person_id', resolver=stmt,
                             alternate_keys=['banner_id'])

        for row in reader:
            entity = im.resolve(row)
            if entity['_status'] == EntityStatus.RESOLVED:
                table.set_values(row)
                if table.execute('insert'):  # returns 1 on DB error
                    im.add_error(row['student_id'], table.last_error)
            else:
                im.add_error(row['student_id'],
                             ErrorDetail('Not found', field='student_id'))

        im.save_state('state/students.json')

        # to initialize from saved state
        em = IdentityManager.load_state('state/students.json', resolver=stmt)
        em.batch_resolve([EntityStatus.STAGED])
    """

    def __init__(
        self,
        source_key: str,
        target_key: str,
        resolver: Optional[Union['PreparedStatement', 'TableLookup']] = None,
        alternate_keys: Optional[List[str]] = None,
    ):
        """
        Initialize IdentityManager.

        Parameters
        ----------
        source_key : str
            Field name of the source-system primary key.
        target_key : str
            Field name of the target-system primary key. Must be returned by
            the resolver. Set equal to source_key to skip ID resolution.
        resolver : PreparedStatement or TableLookup, optional
            Resolution query. Accepts either type; ``TableLookup`` is
            unwrapped to its underlying ``PreparedStatement``.
        alternate_keys : list of str, optional
            Additional key fields to persist and track per entity. The list is
            copied, so later changes to the caller's list have no effect.
        """
        self.source_key = source_key
        self.target_key = target_key
        # Copy so mutating the caller's list cannot silently change what is tracked/saved.
        self.alternate_keys: List[str] = list(alternate_keys) if alternate_keys else []
        self._lookup: Optional['TableLookup'] = None  # kept when resolver is a TableLookup
        self.resolver = resolver  # goes through the property setter
        self.entities: Dict[Any, 'Record'] = {}
        self._record_factory: Optional[type] = None  # created lazily on first resolution

    @property
    def resolver(self) -> Optional['PreparedStatement']:
        """The active PreparedStatement used for resolution queries."""
        return self._resolver

    @resolver.setter
    def resolver(self, value: Optional[Union['PreparedStatement', 'TableLookup']]):
        """
        Set the resolver, accepting either a PreparedStatement or TableLookup.

        Pass ``None`` to clear the resolver (useful when loading state for
        inspection without re-querying).

        Raises
        ------
        ValueError
            If ``value`` is neither None, a PreparedStatement nor a TableLookup.
        """
        if value is None:
            self._resolver = None
            self._lookup = None
        elif isinstance(value, TableLookup):
            # Keep the TableLookup too: its cache enables the exhaustive-miss shortcut.
            self._resolver = value._stmt
            self._lookup = value
        elif isinstance(value, PreparedStatement):
            self._resolver = value
            self._lookup = None
        else:
            raise ValueError('Resolver must be either a PreparedStatement or TableLookup')

    @staticmethod
    def _make_entity_class(fields: List[str]) -> type:
        """Build a fresh ``EntityRecord`` Record subclass with the given ordered fields."""
        record_class = type('EntityRecord', (Record,), {})
        record_class.set_fields(fields)
        return record_class

    def _has_target_field(self, record_class: type) -> bool:
        """Return True if ``record_class`` defines ``target_key`` (raw or normalized name)."""
        return (self.target_key in record_class._fields
                or self.target_key in record_class._fields_normalized)

    def _setup_record_class(self, record: Optional['Record'] = None):
        """
        Create and cache the EntityRecord subclass used for all entities.

        Derives the field list from ``record`` (a freshly resolved row or the
        caller's mapping). When ``record`` is None, the fields come from the
        resolver cursor's current record factory, or are just ``source_key`` /
        ``target_key`` without a usable resolver. Appends any
        ``alternate_keys`` not already present, then the ``_status``,
        ``_errors`` and ``_messages`` sentinel fields.

        Called automatically on first resolution; idempotent after that.

        Returns
        -------
        type
            The cached entity Record class.

        Raises
        ------
        ValueError
            If a resolver is configured but its result lacks ``target_key``.
            Without a resolver, a missing ``target_key`` is added as a field
            instead, since there is no resolver to return it.
        """
        if self._record_factory:
            return self._record_factory

        if record is None:
            if self.resolver and not self.resolver.cursor._row_factory_invalid:
                temp_class = self.resolver.cursor.record_factory
            else:
                temp_class = type('tempEntityRecord', (Record,), {})
                keys = [self.source_key]
                if self.target_key != self.source_key:
                    keys.append(self.target_key)
                temp_class.set_fields(keys)
            record = temp_class()

        base_fields = list(record.keys())
        base_fields += [fld for fld in self.alternate_keys if fld not in record]
        sentinels = ['_status', '_errors', '_messages']

        record_class = self._make_entity_class(base_fields + sentinels)
        if not self._has_target_field(record_class):
            if self.resolver is not None:
                raise ValueError(f'{self.target_key} must be returned by the primary resolver')
            # Staging-only mode: the caller's mapping need not carry target_key yet.
            record_class = self._make_entity_class(base_fields + [self.target_key] + sentinels)

        self._record_factory = record_class
        return record_class

    def resolve(self, value: Any) -> Optional['Record']:
        """
        Resolve a source key to a target entity, caching the result.

        Parameters
        ----------
        value : scalar, dict, or Record
            * **scalar** — treated as the raw ``source_key`` value. The
              resolver is called with ``{source_key: value}`` and the returned
              entity is cached but the caller's record is *not* mutated.
            * **dict or Record** — ``source_key`` is extracted from the
              mapping. On a successful resolution the ``target_key`` is
              written back into the caller's record.

        Returns
        -------
        Record or None
            The cached/resolved entity Record, or ``None`` when no source key
            is available (``value`` is ``None``, or a mapping whose
            ``source_key`` is missing or ``None``).

        Raises
        ------
        ValueError
            If the resolved ``target_key`` conflicts with a value already
            present in the caller's record.
        """
        if isinstance(value, (dict, Record)):
            source_id = value.get(self.source_key)
            record = value  # caller's mapping; receives target_key on success
        else:
            source_id = value
            record = None
        if source_id is None:
            return None

        cached = self.entities.get(source_id)
        if cached is not None and cached.get('_status') == EntityStatus.RESOLVED:
            self._write_back(record, cached[self.target_key])
            return cached

        if self.resolver is None:
            return self._register_without_resolver(source_id, record)
        return self._run_resolver(source_id, record, cached)

    def _write_back(self, record: Optional[Any], resolved_id: Any) -> None:
        """
        Copy ``resolved_id`` into the caller's ``record`` under ``target_key``.

        No-op when ``record`` is None (scalar lookups never mutate caller data).
        Raises ValueError if the record already holds a different non-None value.
        """
        if record is None:
            return
        existing = record.get(self.target_key)
        if existing is not None and existing != resolved_id:
            raise ValueError(
                f"Conflict on {self.target_key}: existing={existing!r}, resolved={resolved_id!r}"
            )
        record[self.target_key] = resolved_id

    def _register_without_resolver(self, source_id: Any, record: Optional[Any]) -> 'Record':
        """Cache an entity with no resolver: RESOLVED if source/target keys coincide, else STAGED."""
        status = EntityStatus.RESOLVED if self.target_key == self.source_key else EntityStatus.STAGED
        if not self._record_factory:
            self._setup_record_class(record)
        if record:
            entity = self._record_factory(**record)
            entity['_status'] = status
            entity['_messages'] = []
            entity['_errors'] = []
        else:
            entity = self._record_factory(**{
                self.source_key: source_id,
                '_status': status,
                '_messages': [],
                '_errors': [],
            })
        self.entities[source_id] = entity
        return entity

    def _run_resolver(self, source_id: Any, record: Optional[Any], cached: Optional['Record']) -> 'Record':
        """Query the resolver for ``source_id``, cache the entity and write back target_key."""
        # Prefer passing any existing entity data so alternate_keys and partial
        # results from prior lookups are available to the query.
        if record:
            bind_vars = self.resolver.cursor.prepare_params(self.resolver.param_names, record)
        elif cached is not None:
            bind_vars = self.resolver.cursor.prepare_params(self.resolver.param_names, cached)
        else:
            bind_vars = {self.source_key: source_id}

        # If the resolver is backed by an exhaustive preloaded cache, a miss means the
        # record definitively does not exist — skip the DB round-trip. Only applies
        # once _record_factory is warm (first hit establishes it via the normal path).
        if self._lookup and self._lookup.exhaustive and self._record_factory:
            cache_key = tuple(bind_vars.get(name) for name in self._lookup._key_col_names)
            if cache_key not in self._lookup._cache:
                entity = self._record_factory(**{self.source_key: source_id})
                entity['_messages'] = []
                entity['_errors'] = []
                entity['_status'] = EntityStatus.NOT_FOUND
                self.entities[source_id] = entity
                return entity

        self.resolver.execute(bind_vars)
        resolved_raw = self.resolver.fetchone()
        if self._record_factory is None:
            self._setup_record_class(resolved_raw)
        if resolved_raw is None:
            resolved_raw = {self.source_key: source_id}

        entity = self._record_factory(**resolved_raw)
        entity['_messages'] = []
        entity['_errors'] = []
        if entity.get(self.target_key) is None:
            entity['_status'] = EntityStatus.NOT_FOUND
        else:
            entity['_status'] = EntityStatus.RESOLVED
        self.entities[source_id] = entity

        if entity['_status'] == EntityStatus.RESOLVED:
            self._write_back(record, entity[self.target_key])
        return entity

    def add_message(self, source_id: Any, message: str):
        """
        Append an informational message to an entity's ``_messages`` list.

        Parameters
        ----------
        source_id : Any
            Source-system key identifying the entity (must already be cached;
            otherwise KeyError is raised).
        message : str
            Message text to append.
        """
        entity = self.entities[source_id]
        if entity.get('_messages') is None:
            entity['_messages'] = []
        entity['_messages'].append(message)

    def add_error(self, source_id: Any, error: 'ErrorDetail'):
        """
        Append an :class:`dbtk.utils.ErrorDetail` to an entity's ``_errors`` list.

        Parameters
        ----------
        source_id : Any
            Source-system key identifying the entity (must already be cached;
            otherwise KeyError is raised).
        error : ErrorDetail
            Structured error to attach to the entity.
        """
        entity = self.entities[source_id]
        if entity.get('_errors') is None:
            entity['_errors'] = []
        entity['_errors'].append(error)

    def _check_id_type(self, id_type: str) -> None:
        """Raise ValueError unless ``id_type`` is ``target_key`` or a registered alternate key."""
        if id_type not in self.alternate_keys and id_type != self.target_key:
            raise ValueError('id_type must be either the target_key or one of the alternate_keys')

    def set_id(self, source_id: Any, id_type: str, value: Any):
        """
        Store a target or alternate key value for a cached entity.

        Parameters
        ----------
        source_id : Any
            Source-system key identifying the entity (must already be cached).
        id_type : str
            Either ``target_key`` or one of ``alternate_keys``.
        value : Any
            The identifier value to store.

        Raises
        ------
        ValueError
            If ``id_type`` is not the ``target_key`` or a registered ``alternate_key``.
        """
        self._check_id_type(id_type)
        self.entities[source_id][id_type] = value

    def get_id(self, source_id: Any, id_type: str):
        """
        Retrieve a target or alternate key value for a cached entity.

        Parameters
        ----------
        source_id : Any
            Source-system key identifying the entity (must already be cached).
        id_type : str
            Either ``target_key`` or one of ``alternate_keys``.

        Returns
        -------
        Any or None
            The stored identifier value, or ``None`` if not yet set.

        Raises
        ------
        ValueError
            If ``id_type`` is not the ``target_key`` or a registered ``alternate_key``.
        """
        self._check_id_type(id_type)
        return self.entities[source_id].get(id_type)

    def batch_resolve(self, additional_statuses: Optional[List[str]] = None):
        """
        Re-run resolution for all entities whose status is PENDING or NOT_FOUND.

        Useful after bulk-loading staging data when some entities could not
        be resolved on first pass. Initializes the record factory (after a
        dry-run resolver call, when a resolver is set) if it has not yet been
        set up.

        Parameters
        ----------
        additional_statuses : optional list of str
            Additional statuses to resolve in addition to
            EntityStatus.NOT_FOUND and EntityStatus.PENDING.
        """
        if not self._record_factory:
            # The dry run lets the cursor build its record factory from the
            # result description; skip it (instead of crashing) with no resolver.
            if self.resolver is not None:
                self.resolver.execute({})
            self._setup_record_class(None)

        statuses = {EntityStatus.PENDING, EntityStatus.NOT_FOUND}
        if additional_statuses:
            statuses.update(additional_statuses)

        # Snapshot the ids first: resolve() reassigns entries in self.entities.
        to_resolve = [sid for sid, ent in self.entities.items() if ent.get('_status') in statuses]
        for source_id in to_resolve:
            self.resolve(source_id)

    def calc_stats(self):
        """
        Count entities by status.

        Returns
        -------
        dict
            Mapping of each :class:`EntityStatus` value to the number of
            entities currently at that status. Any status not in
            ``EntityStatus.VALUES`` (e.g. from a hand-edited state file) is
            counted under its own key instead of raising KeyError.

        Example
        -------
        ::

            stats = im.calc_stats()
            print(stats)
            # {'pending': 0, 'resolved': 142, 'staged': 5, 'error': 3, ...}
        """
        counts = {status: 0 for status in EntityStatus.VALUES}
        for entity in self.entities.values():
            status = entity.get('_status')
            if status:
                counts[status] = counts.get(status, 0) + 1
        return counts

    def save_state(self, path: Union[str, Path]):
        """
        Persist the current entity cache to a JSON file.

        The file captures ``source_key``, ``target_key``, ``alternate_keys``,
        ``field_order`` (for factory reconstruction; ``null`` when no entity
        schema is known yet), summary stats, a UTC timestamp, and the full
        entity dict. :class:`dbtk.utils.ErrorDetail` objects are serialized to
        ``{"message": ..., "field": ..., "code": ...}`` dicts; any other
        non-JSON value is stored as ``str(value)``.

        Parameters
        ----------
        path : str or Path
            Destination file path. Parent directory must exist.
        """
        if self._record_factory:
            field_order = self._record_factory._fields  # exact ordered list
        elif self.entities:
            # Rare fallback (e.g. state loaded without a factory): the first
            # entity's fields are the best schema information available.
            first_entity = next(iter(self.entities.values()))
            field_order = list(first_entity.keys())
        else:
            # Nothing is known about the entity schema yet. Saving null lets
            # load_state build the factory from a resolver or lazily on first
            # resolution, instead of from a guessed field list.
            field_order = None

        def _serialize(obj):
            if isinstance(obj, ErrorDetail):
                return {'message': obj.message, 'field': obj.field, 'code': obj.code}
            return str(obj)

        # Naive-UTC isoformat + "Z": same text as the deprecated utcnow() produced.
        timestamp = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None).isoformat() + "Z"
        data = {
            "timestamp": timestamp,
            "source_key": self.source_key,
            "target_key": self.target_key,
            "alternate_keys": self.alternate_keys,
            "field_order": field_order,
            "stats": self.calc_stats(),
            "entities": {
                str(source_pk): entity.to_dict(normalized=False)
                for source_pk, entity in self.entities.items()
            },
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, default=_serialize)

        logger.info("IdentityManager saved state to %s", path)

    def _restore_key(self, saved_key: str, entity_data: Dict[str, Any]) -> Any:
        """
        Recover the original entity-cache key for a JSON object key.

        JSON object keys are always strings, but the entity's own
        ``source_key`` value keeps its JSON type. A numeric value whose
        ``str()`` equals ``saved_key`` is returned in place of the string.
        Otherwise ``saved_key`` is returned unchanged.
        """
        candidate = entity_data.get(self.source_key)
        if isinstance(candidate, (int, float)) and str(candidate) == saved_key:
            return candidate
        return saved_key

    @classmethod
    def load_state(
        cls,
        path: Union[str, Path],
        resolver: Optional[Union['PreparedStatement', 'TableLookup']] = None,
    ) -> 'IdentityManager':
        """
        Restore an IdentityManager from a previously saved JSON file.

        Re-creates the entity Record factory from ``field_order`` stored in
        the file. Deserializes ``_errors`` lists back to
        :class:`dbtk.utils.ErrorDetail` instances. Numeric source keys, which
        JSON stores as strings, are restored to their original type when the
        entity's stored ``source_key`` value matches. This lets lookups with
        the original typed ids hit the cache.

        Parameters
        ----------
        path : str or Path
            Path to the JSON file written by :meth:`save_state`.
        resolver : PreparedStatement or TableLookup, optional
            Resolver to attach to the restored instance. If the saved file has
            no ``field_order``, the resolver is used as a fallback to
            initialize the record factory.

        Returns
        -------
        IdentityManager
            Fully restored instance with all entities re-hydrated.
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        instance = cls(
            source_key=data["source_key"],
            target_key=data["target_key"],
            alternate_keys=data.get("alternate_keys", []),
            resolver=resolver,
        )

        # Re-create factory from saved field_order
        field_order = data.get("field_order")
        if field_order:
            instance._record_factory = instance._make_entity_class(field_order)
        elif resolver:
            logger.warning("No field_order in saved state — falling back to resolver")
            instance.resolver.execute({})
            instance._setup_record_class()
        else:
            logger.warning("No field_order in saved state and no resolver — entity factory unavailable")

        instance.entities = {}
        for saved_key, entity_data in data["entities"].items():
            errors = entity_data.get('_errors')
            if isinstance(errors, list):
                entity_data['_errors'] = [
                    ErrorDetail(**e) if isinstance(e, dict) else e
                    for e in errors
                ]
            if instance._record_factory:
                entity = instance._record_factory(**entity_data)
            else:
                entity = Record(**entity_data)
            instance.entities[instance._restore_key(saved_key, entity_data)] = entity

        logger.info("IdentityManager loaded state from %s", path)
        return instance
class ValidationCollector:
    """
    Callable collector/enricher for fn pipelines.

    During row-wise processing:

    - Collects unique codes
    - Optionally enriches them with descriptions using TableLookup
    - Can return a specific field from the lookup result instead of the raw code

    Supports:

    - Preload mode: instant enrichment, perfect valid/new split
    - Lazy mode: enrich on first encounter
    - No lookup: pure collection

    Set ``return_col`` to the field name you want returned; ``None`` (default)
    returns the raw code.

    Calling the collector accepts:

    - ``None``, returned unchanged.
    - A comma-separated string. Blank parts are ignored and a comma-joined
      string (or ``None`` if empty) is returned.
    - A list/tuple/set of codes. ``None`` and ``""`` entries are skipped and a
      list is returned.
    - A single scalar code, which returns a single value.
    """

    def __init__(
        self,
        lookup: Optional['TableLookup'] = None,
        return_col: Optional[str] = None,
    ):
        self.lookup = lookup
        self.return_col = return_col
        self.existing: Dict[Any, Any] = {}  # code -> raw lookup result
        self.added: Dict[Any, Any] = {}  # new codes: None until annotated, then dict
        self._recently_added: bool = False  # did the most recent __call__ add a new code?
        self._preload_synced: bool = False  # has `existing` been filled from lookup._cache?
        if lookup:
            key_cols = lookup._key_col_names
            self.key_name = key_cols[0] if isinstance(key_cols, (list, tuple)) else key_cols
            if lookup._cache_strategy == TableLookup.CACHE_PRELOAD and lookup._preloaded:
                self._preload_all()

    def _preload_all(self):
        """Populate ``existing`` from the lookup's cache, storing raw results."""
        for cache_key, result in self.lookup._cache.items():
            code = cache_key[0] if isinstance(cache_key, tuple) else cache_key
            self.existing[code] = result
        self._preload_synced = True

    @staticmethod
    def _split_codes(value: Any) -> List[Any]:
        """Normalize a call value into a list of individual codes."""
        if isinstance(value, str):
            return [c.strip() for c in value.split(",") if c.strip()]
        if isinstance(value, (list, tuple, set)):
            # Skip only missing/blank entries; falsy codes such as 0 are valid.
            return [c for c in value if c is not None and c != ""]
        return [value]

    def _extract_col(self, result: Any) -> Optional[str]:
        """Extract return_col from a lookup result, or str-ify a scalar."""
        if self.return_col is None:
            return None
        if isinstance(result, (str, int, float)):
            return str(result)
        if hasattr(result, "get"):
            val = result.get(self.return_col)
            return str(val) if val is not None else None
        if isinstance(result, (tuple, list)):
            # Fall back to second element when result is a plain sequence
            return str(result[1]) if len(result) > 1 else None
        return str(result)

    def _lookup_code(self, code: Any) -> Optional[str]:
        """Classify one code as existing or new and return its extracted column (or None)."""
        if code in self.existing:
            return self._extract_col(self.existing[code])
        if code in self.added:
            data = self.added[code]
            return self._extract_col(data) if data is not None else None

        lookup = self.lookup
        if lookup and not lookup._preloaded:
            # Lazy mode: query the lookup once per unseen code.
            result = lookup({self.key_name: code})
            if result:
                self.existing[code] = result
                return self._extract_col(result)
        elif lookup and not self._preload_synced:
            # The lookup finished preloading after this collector was built, so
            # `existing` was never filled from its cache; do that once and
            # re-check before declaring the code new.
            self._preload_all()
            if code in self.existing:
                return self._extract_col(self.existing[code])

        # No lookup, lookup miss, or preloaded-cache miss: the code is new.
        self.added[code] = None
        self._recently_added = True
        return None

    def __call__(self, value: Any) -> Any:
        # Reset first so any call that adds nothing (including value=None)
        # never leaves a stale flag behind for collect_new().
        self._recently_added = False
        if value is None:
            return value

        outputs = []
        for code in self._split_codes(value):
            col = self._lookup_code(code)
            outputs.append(col if self.return_col else code)

        # Return in original format
        if isinstance(value, str):
            joined = ",".join(o for o in outputs if o is not None)
            return joined if joined else None
        if isinstance(value, (list, tuple, set)):
            return outputs
        return outputs[0]

    def __contains__(self, value: Any) -> bool:
        """
        Support 'in' operator for validation.

        Check if a value exists in either existing or added sets.
        Useful for validating/filtering records based on collected values.

        Parameters
        ----------
        value : Any
            The value to check

        Returns
        -------
        bool
            True if value exists in either existing or added sets

        Example
        -------
        ::

            # Collect titles
            title_collector = ValidationCollector()
            for record in titles_reader:
                title_collector(record['tconst'])

            # Filter principals based on collected titles
            with get_reader('title.principals.tsv.gz') as reader:
                reader.add_filter(lambda r: r.tconst in title_collector)
                for record in reader:
                    process(record)
        """
        return value in self.existing or value in self.added

    # Reporting

    def get_valid_mapping(self) -> Dict[Any, Optional[str]]:
        """Return ``code -> extracted return_col value`` for every code confirmed by the lookup."""
        return {code: self._extract_col(result) for code, result in self.existing.items()}

    def get_all_mapping(self) -> Dict[Any, Optional[str]]:
        """
        Return ``code -> extracted value`` for existing and new codes.

        New codes map to the value extracted from their :meth:`collect_new`
        fields, or ``None`` if they were never annotated.
        """
        combined = {code: self._extract_col(result) for code, result in self.existing.items()}
        combined.update({
            code: self._extract_col(fields) if fields else None
            for code, fields in self.added.items()
        })
        return combined

    def collect_new(self, code: Any, **fields) -> None:
        """
        Attach extra fields to a newly-encountered code for later bulk insertion.

        No-ops immediately when the preceding ``__call__`` did not add a new
        code (``_recently_added`` is False), so it is safe to call
        unconditionally on every record. ``code`` is split the same way
        ``__call__`` splits its input. Only codes that were actually recorded
        as new are annotated; anything else is ignored. First annotation
        wins — subsequent calls for the same code are ignored.

        Parameters
        ----------
        code : Any
            The value that was passed to the validator (used as a cross-check
            and as the key into ``added``).
        **fields
            Extra columns to store, e.g. ``stvcipc_desc=record.cip_discipline``.

        Example
        -------
        ::

            cip_validator = ValidationCollector(lookup=cip_lookup)
            for record in reader:
                stvmajr.set_values(record)  # triggers cip_validator(record.cip_code)
                cip_validator.collect_new(record.cip_code,
                                          stvcipc_desc=record.cip_discipline)
        """
        if not self._recently_added:
            return
        self._recently_added = False
        for new_code in self._split_codes(code):
            # Cross-check: never invent entries for codes __call__ did not add.
            if new_code in self.added and self.added[new_code] is None:
                self.added[new_code] = dict(fields)

    def get_all(self) -> set:
        """
        Get all codes (existing + added) as a set.

        Useful for filtering with tools like polars that need a set/list of
        valid values rather than a callable.

        Returns
        -------
        set
            Union of existing codes and added codes

        Example
        -------
        ::

            # Collect valid titles
            title_collector = ValidationCollector()
            for record in titles:
                title_collector(record['tconst'])

            # Use with polars filtering
            all_titles = title_collector.get_all()
            df = pl.scan_csv('principals.tsv.gz').filter(
                pl.col('tconst').is_in(all_titles)
            )

            # Or with dbtk reader filtering
            reader.add_filter(lambda r: r.tconst in title_collector)  # Uses __contains__
        """
        return set(self.existing) | set(self.added)