# dbtk/etl/data_surge.py
import logging
import time
import re
from typing import Iterable, Optional
from .base_surge import BaseSurge
from ..utils import batch_iterable
from ..record import Record
logger = logging.getLogger(__name__)
class DataSurge(BaseSurge):
    """
    Handles bulk ETL operations by delegating to a stateful Table instance.

    Note: The Table instance's state (self.values) is modified during processing.
    Ensure the Table is not used concurrently by other operations or threads.

    Parameters
    ----------
    table : Table
        Table instance with column definitions and cursor
    batch_size : int, optional
        Number of records per batch (default: cursor.batch_size or 1000)
    use_transaction : bool, optional
        Wrap all operations in a transaction (default: False)
    pass_through : bool, optional
        Skip transformation and validation, using source data directly
        (default: False). Only supported for inserts, and not supported for
        tables that have columns with database expressions (``db_expr``).

        **When to use:**

        - Database-to-database copies with identical schemas
        - Pre-transformed data from upstream pipelines (already validated)
        - Raw positional tuples pre-ordered for binding parameters

        **What's skipped:** Field mapping, type coercion, default values, null
        value handling, required field validation, and Table.set_values()
        overhead.

        **Warning:** Do NOT use if records might have missing required fields,
        mismatched field names, need type transformations, or data quality is
        uncertain. All database constraints (primary keys, foreign keys) still
        apply.

    Attributes
    ----------
    total_read : int
        Total rows read from source. 1-based (first row = 1). Includes
        both loaded and skipped rows.
    total_loaded : int
        Total rows successfully loaded.
    skipped : int
        Total rows skipped due to missing required fields.
    skip_details : dict
        Skip tracking grouped by reason. Key is a frozenset of missing
        required field names. Value is a dict with:

        - ``count``: total rows skipped for this reason
        - ``sample``: list of up to 20 1-based row numbers (for debugging)

        Example::

            {frozenset({'primary_name'}): {'count': 5, 'sample': [937887, 957847, ...]}}

    Examples
    --------
    Standard ETL with transformation::

        table = Table(..., cursor=cursor)
        surge = DataSurge(table, batch_size=1000, use_transaction=True)
        errors = surge.insert(records, raise_error=False)

    Fast database-to-database copy (matching schemas)::

        # Source and destination schemas match exactly
        surge = DataSurge(dest_table, batch_size=5000, pass_through=True)
        surge.insert(source_cursor)

    Pre-transformed data (already validated)::

        # Data already transformed and validated by upstream process
        surge = DataSurge(table, pass_through=True)
        surge.insert(validated_records)
    """

    def __init__(self, table, batch_size: Optional[int] = None, use_transaction: bool = False,
                 pass_through: bool = False) -> None:
        """
        Initialize DataSurge for bulk operations.

        Args:
            table: Table instance with schema metadata
            batch_size: Number of records per batch
            use_transaction: Use transaction for all operations (default: False)
            pass_through: Skip transformation/validation for trusted data (default: False)
        """
        super().__init__(table, batch_size=batch_size, pass_through=pass_through)
        self.use_transaction = use_transaction
        # Swap to positional parameters if named to save memory in bind parameters
        self.table.force_positional()
        # Only holds SQL this surge has rewritten itself (the temp-table merge);
        # everything else is looked up on the table.
        self._sql_statements = {}

    def get_sql(self, operation: str) -> str:
        """
        Return the SQL for ``operation``, checking local modifications first.

        A statement stored in ``self._sql_statements`` wins; otherwise the
        request is forwarded to ``self.table.get_sql(operation)``.
        """
        if operation in self._sql_statements:
            return self._sql_statements[operation]
        return self.table.get_sql(operation)

    def insert(self, records: Iterable[Record], raise_error: bool = True) -> int:
        """Perform bulk INSERT on records; returns the error count from ``load``."""
        return self.load(records, operation="insert", raise_error=raise_error)

    def update(self, records: Iterable[Record], raise_error: bool = True) -> int:
        """Perform bulk UPDATE on records; returns the error count from ``load``."""
        return self.load(records, operation="update", raise_error=raise_error)

    def delete(self, records: Iterable[Record], raise_error: bool = True) -> int:
        """Perform bulk DELETE on records; returns the error count from ``load``."""
        return self.load(records, operation="delete", raise_error=raise_error)

    def merge(self, records: Iterable[Record], raise_error: bool = True) -> int:
        """
        Perform bulk MERGE using either direct upsert or temporary table strategy.

        The strategy is chosen by ``self.table._should_use_upsert()``: a truthy
        result runs the merge through ``load``; otherwise the records are staged
        in a temporary table and merged with a single statement.

        Returns:
            Number of records that failed to load.
        """
        if self.table._should_use_upsert():
            return self.load(records, operation="merge", raise_error=raise_error)
        return self._merge_with_temp_table(records, raise_error)

    def _execute_batches(self, records, operation, sql, raise_error) -> tuple:
        """
        Execute ``sql`` over ``records`` in batches with ``executemany``.

        Each record goes through ``self._transform_row``; a ``None`` result
        means the record is skipped and is left out of the batch.

        Args:
            records: Iterable of source records
            operation: Operation name, used as the key into ``self.table.counts``
                and in the error log message
            sql: Statement passed to ``cursor.executemany``
            raise_error: Re-raise the driver's ``DatabaseError`` instead of
                counting the failed batch

        Returns:
            ``(errors, skipped)``: records belonging to failed batches, and
            records for which ``_transform_row`` returned ``None``.
        """
        errors = 0
        skipped = 0
        for batch in batch_iterable(records, self.batch_size):
            transformed = [self._transform_row(record) for record in batch]
            batch_params = [params for params in transformed if params is not None]
            skipped += len(transformed) - len(batch_params)
            if not batch_params:
                # Every record in this batch was skipped; nothing to send.
                continue
            try:
                self.cursor.executemany(sql, batch_params)
                self.total_loaded += len(batch_params)
                self.table.counts[operation] += len(batch_params)
            except self.cursor.connection.driver.DatabaseError as e:
                logger.error(f"{operation.capitalize()} batch failed for {self.table.name}: {e}")
                if raise_error:
                    raise
                # The whole batch is counted as failed.
                errors += len(batch_params)
        return errors, skipped

    def load(
        self,
        records: Iterable[Record],
        operation: Optional[str] = None,
        raise_error: bool = True,
    ) -> int:
        """
        Core bulk execution using executemany() — shared path for insert/update/delete/merge.

        Args:
            records: Iterable of source records
            operation: One of "insert", "update", "delete", "merge" (case
                insensitive). Falls back to ``self.operation`` when omitted.
            raise_error: Re-raise database errors instead of counting them

        Returns:
            Number of records belonging to batches that failed.

        Raises:
            ValueError: If the operation is missing or unknown, or if
                ``pass_through`` is enabled for anything other than an insert
                or for a table with ``db_expr`` columns.
        """
        self.start_time = time.monotonic()

        requested = operation or self.operation
        # A missing operation is reported as an invalid one (ValueError) rather
        # than failing with AttributeError on ``None.lower()``.
        operation = requested.lower() if isinstance(requested, str) else requested

        # logger.error, not logger.exception: these checks run outside any
        # ``except`` block, so there is no active traceback to attach.
        if operation not in ("insert", "update", "delete", "merge"):
            msg = f"Invalid operation: {operation}"
            logger.error(msg)
            raise ValueError(msg)

        if self.pass_through:
            if operation != "insert":
                msg = f"Operation {operation} is not compatible with pass_through mode."
                logger.error(msg)
                raise ValueError(msg)
            expr_cols = self.table.db_expr_cols()
            if expr_cols:
                msg = f"Columns with `db_expr` are incompatible with pass_through mode. cols: {expr_cols}"
                logger.error(msg)
                raise ValueError(msg)

        self.operation = operation
        sql = self.get_sql(operation)

        if self.use_transaction:
            with self.cursor.connection.transaction():
                errors, skipped = self._execute_batches(records, operation, sql, raise_error)
        else:
            errors, skipped = self._execute_batches(records, operation, sql, raise_error)

        self.skipped += skipped
        self._log_summary()
        return errors

    def _merge_with_temp_table(self, records: Iterable[Record], raise_error: bool) -> int:
        """
        Perform bulk merge using temporary table (for databases requiring true MERGE).

        The records are bulk-inserted into a temp table by a second DataSurge,
        then merged into the target with one statement whose ``USING (...) s``
        clause is rewritten to read from that temp table.

        Args:
            records: Iterable of source records (materialized into a list)
            raise_error: Re-raise database errors instead of counting them

        Returns:
            0 on success or for empty input; otherwise the number of records
            that failed, either while staging or because the merge statement
            itself failed. Rows skipped during staging are not counted.
        """
        records_list = list(records)
        if not records_list:
            return 0

        dialect = self.cursor.connection.dialect
        col_info = self.table.get_column_definitions(all_cols=dialect.temp_table_all_cols)
        temp_name, create_sql = dialect.create_temp_table_ddl(self.table.name, col_info)

        # Reuse existing temp table if present, otherwise create fresh
        try:
            self.cursor.execute(f"TRUNCATE TABLE {temp_name}")
        except self.cursor.connection.driver.DatabaseError:
            table_exists = False
        else:
            table_exists = True
        if not table_exists:
            self.cursor.execute(create_sql)
            logger.debug(f"Created temp table: {create_sql}")

        # Use temporary table for bulk insert
        from .table import Table
        temp_table = Table(
            name=temp_name,
            columns=self.table.columns,
            cursor=self.cursor,
            is_temp=True,
        )
        temp_surge = DataSurge(temp_table, batch_size=self.batch_size)
        errors = temp_surge.insert(records_list, raise_error=raise_error)

        # Roll the staging surge's skip count up so this path reports like the
        # upsert path. Only counters this module maintains itself are copied.
        self.skipped += temp_surge.skipped

        if errors:
            self.cursor.execute(dialect.cleanup_temp_table_sql(temp_name))
            if dialect.temp_table_cleanup_commit:
                self.cursor.connection.commit()
            return errors

        # Rows that actually reached the temp table. Skipped rows are excluded,
        # so this can be smaller than len(records_list).
        staged = temp_surge.total_loaded

        # Transfer record fields from temp table to main table for proper merge column exclusion
        self.table.calc_update_excludes(temp_table._record_fields)
        merge_sql = self.table.get_sql('merge')

        # Replace the USING clause to point to temp table and store modified version.
        # A callable replacement keeps the temp table name literal (a plain
        # string would be parsed for backslash escapes and group references).
        modified_merge, replaced = re.subn(
            r'USING\s*\(.*?\)\s*s',
            lambda _match: f'USING {temp_name} s',
            merge_sql,
            flags=re.DOTALL,
        )
        if not replaced:
            logger.warning(
                f"No USING (...) s clause found in merge sql for {self.table.name}; "
                "statement left unchanged"
            )
        self._sql_statements['merge'] = modified_merge
        logger.debug(f"Modified merge sql: {modified_merge}")

        try:
            if self.use_transaction:
                with self.cursor.connection.transaction():
                    self.cursor.execute(self.get_sql('merge'))
            else:
                self.cursor.execute(self.get_sql('merge'))
        except self.cursor.connection.driver.DatabaseError as e:
            logger.error(f"Merge failed: {e}")
            if raise_error:
                raise
            # The single merge statement failed, so every staged row failed.
            errors = staged
        else:
            self.total_loaded += staged
            self.table.counts['merge'] += staged
            logger.info(f"MERGE via temp table → {staged:,} records into {self.table.name}")
        finally:
            # Best-effort cleanup: a failure here must not mask the merge result.
            try:
                self.cursor.execute(dialect.cleanup_temp_table_sql(temp_name))
            except Exception as e:
                logger.warning(f"Failed to clear temp table {temp_name}: {e}")
        return errors