# Source code for dbtk.etl.data_surge

# dbtk/etl/data_surge.py

import logging
import time
import re
from typing import Iterable, Optional

from .base_surge import BaseSurge
from ..utils import batch_iterable
from ..record import Record

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class DataSurge(BaseSurge):
    """
    Handles bulk ETL operations by delegating to a stateful Table instance.

    Note: The Table instance's state (self.values) is modified during
    processing. Ensure the Table is not used concurrently by other operations
    or threads.

    Parameters
    ----------
    table : Table
        Table instance with column definitions and cursor
    batch_size : int, optional
        Number of records per batch (default: cursor.batch_size or 1000)
    use_transaction : bool, optional
        Wrap all operations in a transaction (default: False)
    pass_through : bool, optional
        Skip transformation and validation, using source data directly
        (default: False). Only compatible with inserts. Not compatible with
        columns that use database expressions (``db_expr``).

        **When to use:**

        - Database-to-database copies with identical schemas
        - Pre-transformed data from upstream pipelines (already validated)
        - Raw positional tuples pre-ordered for binding parameters

        **What's skipped:**
        Field mapping, type coercion, default values, null value handling,
        required field validation, and Table.set_values() overhead.

        **Warning:**
        Do NOT use if records might have missing required fields, mismatched
        field names, need type transformations, or data quality is uncertain.
        All database constraints (primary keys, foreign keys) still apply.

    Attributes
    ----------
    total_read : int
        Total rows read from source. 1-based (first row = 1). Includes both
        loaded and skipped rows.
    total_loaded : int
        Total rows successfully loaded.
    skipped : int
        Total rows skipped due to missing required fields.
    skip_details : dict
        Skip tracking grouped by reason. Key is a frozenset of missing
        required field names. Value is a dict with:

        - ``count``: total rows skipped for this reason
        - ``sample``: list of up to 20 1-based row numbers (for debugging)

        Example::

            {frozenset({'primary_name'}): {'count': 5, 'sample': [937887, 957847, ...]}}

    Examples
    --------
    Standard ETL with transformation::

        table = Table(..., cursor=cursor)
        surge = DataSurge(table, batch_size=1000, use_transaction=True)
        errors = surge.insert(records, raise_error=False)

    Fast database-to-database copy (matching schemas)::

        # Source and destination schemas match exactly
        surge = DataSurge(dest_table, batch_size=5000, pass_through=True)
        surge.insert(source_cursor)

    Pre-transformed data (already validated)::

        # Data already transformed and validated by upstream process
        surge = DataSurge(table, pass_through=True)
        surge.insert(validated_records)
    """

    def __init__(self, table, batch_size: Optional[int] = None,
                 use_transaction: bool = False, pass_through: bool = False):
        """
        Initialize DataSurge for bulk operations.

        Args:
            table: Table instance with schema metadata
            batch_size: Number of records per batch
            use_transaction: Use transaction for all operations (default: False)
            pass_through: Skip transformation/validation for trusted data (default: False)
        """
        super().__init__(table, batch_size=batch_size, pass_through=pass_through)
        self.use_transaction = use_transaction
        # Swap to positional parameters if named to save memory in bind parameters
        self.table.force_positional()
        # Only for modified SQL (merge temp table hack); everything else is
        # fetched from the table on demand in get_sql().
        self._sql_statements = {}

    def get_sql(self, operation: str) -> str:
        """Get SQL for operation, checking local modifications first.

        Args:
            operation: Operation name used as the lookup key (e.g. ``"merge"``).

        Returns:
            The locally stored statement for ``operation`` when one exists,
            otherwise whatever ``self.table.get_sql(operation)`` returns.
        """
        if operation in self._sql_statements:
            return self._sql_statements[operation]
        return self.table.get_sql(operation)

    def insert(self, records: Iterable[Record], raise_error: bool = True) -> int:
        """Perform bulk INSERT on records; returns the error count from load()."""
        return self.load(records, operation="insert", raise_error=raise_error)

    def update(self, records: Iterable[Record], raise_error: bool = True) -> int:
        """Perform bulk UPDATE on records; returns the error count from load()."""
        return self.load(records, operation="update", raise_error=raise_error)

    def delete(self, records: Iterable[Record], raise_error: bool = True) -> int:
        """Perform bulk DELETE on records; returns the error count from load()."""
        return self.load(records, operation="delete", raise_error=raise_error)

    def merge(self, records: Iterable[Record], raise_error: bool = True) -> int:
        """
        Perform bulk MERGE using either direct upsert or temporary table strategy.

        The strategy is chosen by ``self.table._should_use_upsert()``: a truthy
        result runs the merge statement through :meth:`load`, otherwise the
        records are staged in a temporary table first.

        Returns:
            Number of records counted as errors.
        """
        if self.table._should_use_upsert():
            return self.load(records, operation="merge", raise_error=raise_error)
        return self._merge_with_temp_table(records, raise_error)

    def _execute_batches(self, records, operation, sql, raise_error):
        """Execute batches with executemany.

        Args:
            records: Iterable of source records.
            operation: Operation name; used as the key into ``self.table.counts``
                and in the failure log message.
            sql: Statement passed to ``cursor.executemany``.
            raise_error: Re-raise the driver's ``DatabaseError`` on a failed
                batch instead of counting the batch as errors.

        Returns:
            Tuple ``(errors, skipped)``: ``errors`` is the number of rows in
            failed batches, ``skipped`` is the number of rows for which
            ``_transform_row`` returned ``None``.
        """
        errors = 0
        skipped = 0
        for batch in batch_iterable(records, self.batch_size):
            batch_params = []
            for record in batch:
                params = self._transform_row(record)
                if params is None:
                    # Row rejected by the transform step; never sent to the database.
                    skipped += 1
                    continue
                batch_params.append(params)

            if not batch_params:
                continue

            try:
                self.cursor.executemany(sql, batch_params)
                self.total_loaded += len(batch_params)
                self.table.counts[operation] += len(batch_params)
            except self.cursor.connection.driver.DatabaseError as e:
                logger.error(f"{operation.capitalize()} batch failed for {self.table.name}: {e}")
                if raise_error:
                    raise
                # The whole batch is counted as failed: executemany gives no
                # per-row outcome here.
                errors += len(batch_params)
        return errors, skipped

    def load(
        self,
        records: Iterable[Record],
        operation: Optional[str] = None,
        raise_error: bool = True,
    ) -> int:
        """
        Core bulk execution using executemany() — shared path for insert/update/delete/merge.

        Args:
            records: Iterable of source records.
            operation: One of ``insert``, ``update``, ``delete`` or ``merge``
                (case-insensitive). Falls back to ``self.operation`` when omitted.
            raise_error: Re-raise database errors instead of counting them.

        Returns:
            Number of records that belonged to failed batches.

        Raises:
            ValueError: If the operation is unknown, or if ``pass_through`` is
                enabled for a non-insert operation or for a table that has
                ``db_expr`` columns.
        """
        self.start_time = time.monotonic()
        operation = (operation or self.operation).lower()

        # These are plain validation failures, not exception handlers, so they
        # are logged with error(): exception() outside an ``except`` block
        # appends a meaningless "NoneType: None" traceback to the message.
        if operation not in ("insert", "update", "delete", "merge"):
            msg = f"Invalid operation: {operation}"
            logger.error(msg)
            raise ValueError(msg)

        if self.pass_through:
            if operation != "insert":
                msg = f"Operation {operation} is not compatible with pass_through mode."
                logger.error(msg)
                raise ValueError(msg)
            expr_cols = self.table.db_expr_cols()
            if expr_cols:
                msg = f"Columns with `db_expr` are incompatible with pass_through mode.  cols: {expr_cols}"
                logger.error(msg)
                raise ValueError(msg)

        self.operation = operation
        sql = self.get_sql(operation)

        if self.use_transaction:
            with self.cursor.connection.transaction():
                errors, skipped = self._execute_batches(records, operation, sql, raise_error)
        else:
            errors, skipped = self._execute_batches(records, operation, sql, raise_error)

        self.skipped += skipped
        self._log_summary()
        return errors

    def _merge_with_temp_table(self, records: Iterable[Record], raise_error: bool) -> int:
        """Perform bulk merge using temporary table (for databases requiring true MERGE).

        The records are bulk-inserted into a temporary table, then a single
        MERGE statement whose ``USING`` clause points at that table is executed
        against the target table.

        Args:
            records: Iterable of source records; materialized into a list.
            raise_error: Re-raise database errors instead of counting them.

        Returns:
            0 for empty input or a successful merge; the staging insert's error
            count if staging failed; ``len(records)`` if the MERGE statement
            itself failed and ``raise_error`` is False.
        """
        records_list = list(records)
        if not records_list:
            return 0

        dialect = self.cursor.connection.dialect
        col_info = self.table.get_column_definitions(all_cols=dialect.temp_table_all_cols)
        temp_name, create_sql = dialect.create_temp_table_ddl(self.table.name, col_info)

        # Reuse existing temp table if present, otherwise create fresh
        try:
            self.cursor.execute(f"TRUNCATE TABLE {temp_name}")
            table_exists = True
        except self.cursor.connection.driver.DatabaseError:
            table_exists = False

        if not table_exists:
            self.cursor.execute(create_sql)
            logger.debug(f"Created temp table: {create_sql}")

        # Use temporary table for bulk insert.
        # NOTE(review): imported here rather than at module level, presumably
        # to avoid a circular import with .table; confirm.
        from .table import Table
        temp_table = Table(
            name=temp_name,
            columns=self.table.columns,
            cursor=self.cursor,
            is_temp=True
        )
        temp_surge = DataSurge(temp_table, batch_size=self.batch_size)
        errors = temp_surge.insert(records_list, raise_error=raise_error)
        if errors:
            # Staging was incomplete, so skip the merge and clear the temp table.
            self.cursor.execute(dialect.cleanup_temp_table_sql(temp_name))
            if dialect.temp_table_cleanup_commit:
                self.cursor.connection.commit()
            return errors

        # Transfer record fields from temp table to main table for proper merge column exclusion
        self.table.calc_update_excludes(temp_table._record_fields)
        merge_sql = self.table.get_sql('merge')

        # Replace the USING clause to point to temp table and store modified version.
        # A callable replacement inserts the temp table name literally; a string
        # replacement would treat backslashes / \g<...> in the name as escapes.
        modified_merge = re.sub(
            r'USING\s*\(.*?\)\s*s',
            lambda _match: f'USING {temp_name} s',
            merge_sql,
            flags=re.DOTALL
        )
        self._sql_statements['merge'] = modified_merge
        logger.debug(f"Modified merge sql: {modified_merge}")

        try:
            if self.use_transaction:
                with self.cursor.connection.transaction():
                    self.cursor.execute(self.get_sql('merge'))
            else:
                self.cursor.execute(self.get_sql('merge'))
            # Count the rows actually staged: rows skipped by the staging
            # insert never reached the temp table and were not merged.
            loaded = temp_surge.total_loaded
            self.table.counts['merge'] += loaded
            logger.info(f"MERGE via temp table → {loaded:,} records into {self.table.name}")
        except self.cursor.connection.driver.DatabaseError as e:
            logger.error(f"Merge failed: {e}")
            if raise_error:
                raise
            # The single MERGE statement failed, so every record counts as an error.
            errors = len(records_list)
        finally:
            # Best-effort cleanup: a failure here must not mask the merge outcome.
            try:
                self.cursor.execute(dialect.cleanup_temp_table_sql(temp_name))
            except Exception as e:
                logger.warning(f"Failed to clear temp table {temp_name}: {e}")

        return errors