Source code for dbtk.writers.database

# dbtk/writers/database.py
"""
Database writing utilities and ETL table operations.
"""

import logging
from typing import Optional, Tuple

from .base import BaseWriter
from ..database import ParamStyle
from ..utils import process_sql_parameters

logger = logging.getLogger(__name__)


[docs] class DatabaseWriter(BaseWriter): """Database writer that extends BaseWriter."""
[docs] def __init__(self, data, target_cursor, target_table: str, batch_size: int = 1000, commit_frequency: int = 10000): """ Initialize database writer. Args: data: Source cursor or list of records target_cursor: Target database cursor target_table: Name of target table batch_size: Number of records to insert per batch commit_frequency: How often to commit (in number of records) """ # Call BaseWriter with no file (we're writing to database) super().__init__(data, file=None) self.target_cursor = target_cursor self.target_table = target_table self.batch_size = batch_size self.commit_frequency = commit_frequency # Get parameter style from target connection self.paramstyle = target_cursor.connection.driver.paramstyle # Pre-generate INSERT statement and param order self.insert_sql, self.param_names = self._create_insert_statement()
def _create_insert_statement(self) -> Tuple[str, Tuple[str, ...]]: """Create INSERT statement for the target table.""" # Build with named parameters columns_str = ', '.join(self.columns) params_str = ', '.join([f':{col}' for col in self.columns]) sql = f'INSERT INTO {self.target_table} ({columns_str}) VALUES ({params_str})' # Convert to target paramstyle and get parameter order return process_sql_parameters(sql, self.paramstyle) def _write_data(self, file_obj) -> None: """Write data to database using batched inserts.""" # file_obj is None and unused - we're writing to database logger.info(f"Starting copy to {self.target_table}") logger.debug(f"Using INSERT statement: {self.insert_sql}") batch = [] try: for record in self.data_iterator: # Extract values from record values = self._row_to_tuple(record) # Handle parameter style if self.paramstyle in ParamStyle.named_styles(): # Convert to dict for named parameters, using param_names order params = {name: values[i] for i, name in enumerate(self.param_names)} batch.append(params) else: # Positional parameters - values already in correct order batch.append(tuple(values)) # Execute batch when it reaches batch_size if len(batch) >= self.batch_size: if len(batch) == 1: self.target_cursor.execute(self.insert_sql, batch[0]) else: self.target_cursor.executemany(self.insert_sql, batch) self._row_num += len(batch) batch = [] # Commit periodically if self._row_num % self.commit_frequency == 0: self.target_cursor.connection.commit() logger.info(f"Committed {self._row_num} records") # Execute remaining batch if batch: if len(batch) == 1: self.target_cursor.execute(self.insert_sql, batch[0]) else: self.target_cursor.executemany(self.insert_sql, batch) self._row_num += len(batch) # Final commit self.target_cursor.connection.commit() logger.info(f"Copy completed: {self._row_num} records inserted into {self.target_table}") except self.target_cursor.connection.driver.DatabaseError as e: logger.error(f"Error during copy: {e}") self.target_cursor.connection.rollback() raise
[docs] def write(self) -> int: """Override to bypass file handle creation and call _write_data directly.""" try: self._write_data(None) # Pass None since we don't need file_obj return self._row_num except Exception as e: logger.error(f"Error writing database data: {e}") raise
[docs] def cursor_to_cursor(source_data, target_cursor, target_table: str, batch_size: int = 1000, commit_frequency: int = 10000) -> int: """ Copy data from source cursor/results to target database table. Args: source_data: Source cursor or list of records target_cursor: Target database cursor target_table: Name of target table batch_size: Number of records to insert per batch commit_frequency: How often to commit (in number of records) Returns: Number of records inserted Example: # Copy between databases source_cursor = source_db.cursor() source_cursor.execute("SELECT * FROM users") target_cursor = target_db.cursor() count = cursor_to_cursor(source_cursor, target_cursor, 'users_copy') print(f"Copied {count} records") """ writer = DatabaseWriter( data=source_data, target_cursor=target_cursor, target_table=target_table, batch_size=batch_size, commit_frequency=commit_frequency ) return writer.write()