|
|
""" |
|
|
Database connection module for FleetMind |
|
|
Handles PostgreSQL database connections and initialization |
|
|
""" |
|
|
|
|
|
import psycopg2 |
|
|
import psycopg2.extras |
|
|
import os |
|
|
from typing import Optional, List, Dict, Any |
|
|
from pathlib import Path |
|
|
import logging |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
# Load variables from a .env file (if one is found) into the process
# environment, so the os.getenv() lookups below can see them.
load_dotenv()

# NOTE(review): calling basicConfig at import time configures the root logger
# for the whole process; consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Connection settings, resolved once at import time. Each value is a string:
# either the environment variable's value or the literal default shown here
# (note the port is the string '5432', not an int).
DB_CONFIG = dict(
    host=os.getenv('DB_HOST', 'localhost'),
    port=os.getenv('DB_PORT', '5432'),
    database=os.getenv('DB_NAME', 'fleetmind'),
    user=os.getenv('DB_USER', 'postgres'),
    password=os.getenv('DB_PASSWORD', ''),
)
|
|
|
|
|
|
|
|
def get_db_connection() -> psycopg2.extensions.connection:
    """
    Open a new PostgreSQL connection using the settings in DB_CONFIG.

    Cursors created from the returned connection yield dict-like rows
    (``RealDictCursor``), and the connection is opened with
    ``sslmode='prefer'`` (SSL when the server supports it, plain otherwise).

    Returns:
        psycopg2.extensions.connection: A freshly opened connection; this
        function never closes it, so the caller must.

    Raises:
        psycopg2.Error: If the connection cannot be established. The error is
        logged before being re-raised.
    """
    settings = DB_CONFIG
    try:
        connection = psycopg2.connect(
            host=settings['host'],
            port=settings['port'],
            database=settings['database'],
            user=settings['user'],
            password=settings['password'],
            cursor_factory=psycopg2.extras.RealDictCursor,
            sslmode='prefer',
        )
    except psycopg2.Error as exc:
        logger.error(f"Database connection error: {exc}")
        raise

    logger.info(f"Database connection established: {settings['database']}@{settings['host']}")
    return connection
|
|
|
|
|
|
|
|
def init_database(schema_file: Optional[str] = None) -> None:
    """
    Initialize the database with a schema and commit it.

    Args:
        schema_file: Path to an SQL schema file. If it is None, or the path
            does not exist, the default ``SCHEMA_SQL`` from the package's
            ``schema`` module is executed instead (a warning is logged when a
            path was given but not found).

    Raises:
        psycopg2.Error: If connecting, executing the schema or committing
            fails. The transaction is rolled back when the connection is
            still open, and the error is logged before being re-raised.
            Errors reading the schema file propagate unchanged.
    """
    # Pre-bind so the error/cleanup paths are safe even when
    # get_db_connection() itself raises.
    conn = None
    try:
        conn = get_db_connection()
        # The cursor context manager closes the cursor on success and on error.
        with conn.cursor() as cursor:
            if schema_file and os.path.exists(schema_file):
                with open(schema_file, 'r', encoding='utf-8') as f:
                    schema_sql = f.read()
                cursor.execute(schema_sql)
                logger.info(f"Database initialized from schema file: {schema_file}")
            else:
                if schema_file:
                    # The caller asked for a specific file; make the fallback visible.
                    logger.warning(f"Schema file not found: {schema_file}; using default schema")
                from .schema import SCHEMA_SQL
                cursor.execute(SCHEMA_SQL)
                logger.info("Database initialized with default schema")

        conn.commit()
        logger.info("Database initialization completed successfully")

    except psycopg2.Error as e:
        logger.error(f"Database initialization error: {e}")
        # rollback() on a closed/broken connection would raise and hide `e`.
        if conn is not None and not conn.closed:
            conn.rollback()
        raise

    finally:
        # Always release the connection, whatever happened above.
        close_connection(conn)
|
|
|
|
|
|
|
|
def close_connection(conn: psycopg2.extensions.connection) -> None:
    """
    Close a database connection if it is still open.

    Falsy values (e.g. None) and already-closed connections are ignored.
    A psycopg2 error raised while closing is logged rather than propagated.

    Args:
        conn: The connection to close.
    """
    if not conn or conn.closed:
        return

    try:
        conn.close()
        logger.info("Database connection closed")
    except psycopg2.Error as err:
        logger.error(f"Error closing connection: {err}")
|
|
|
|
|
|
|
|
def execute_query(query: str, params: tuple = ()) -> List[Dict[str, Any]]:
    """
    Run a SELECT statement on a fresh connection and return every row.

    The connection is opened for this call only and is always closed
    afterwards. Nothing is committed.

    Args:
        query: SQL text using psycopg2 placeholders (e.g. ``%s``).
        params: Values bound to the placeholders.

    Returns:
        list: All fetched rows. With the ``RealDictCursor`` configured by
        get_db_connection, each row is dict-like.

    Raises:
        psycopg2.Error: If executing or fetching fails. The error is logged
        before being re-raised.
    """
    conn = get_db_connection()
    try:
        cur = conn.cursor()
        cur.execute(query, params)
        rows = cur.fetchall()
        cur.close()
    except psycopg2.Error as err:
        logger.error(f"Query execution error: {err}")
        raise
    else:
        return rows
    finally:
        close_connection(conn)
|
|
|
|
|
|
|
|
def execute_write(query: str, params: tuple = ()) -> int:
    """
    Execute an INSERT, UPDATE, or DELETE statement and commit it.

    A fresh connection is opened for this call and is always closed
    afterwards.

    Args:
        query: SQL statement using psycopg2 placeholders (e.g. ``%s``).
        params: Values bound to the placeholders.

    Returns:
        int: ``cursor.rowcount`` for the statement (number of rows affected).

    Raises:
        psycopg2.Error: If executing or committing fails. The error is logged
            first, then the transaction is rolled back (when the connection
            is still open), and the error is re-raised. Connection failures
            are raised (and logged) by get_db_connection.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
            rows_affected = cursor.rowcount
        conn.commit()
        return rows_affected

    except psycopg2.Error as e:
        # Log before rolling back so the original error is always recorded.
        logger.error(f"Write query error: {e}")
        # If the server dropped the connection, conn.closed is non-zero and
        # rollback() would raise InterfaceError, masking the real error.
        if not conn.closed:
            conn.rollback()
        raise

    finally:
        # Closing without commit also discards any uncommitted work left by
        # non-psycopg2 exceptions (e.g. parameter-count mismatches).
        close_connection(conn)
|
|
|
|
|
|
|
|
def execute_many(query: str, params_list: List[tuple]) -> int:
    """
    Execute one INSERT/UPDATE statement once per parameter tuple, in a single
    transaction, and commit it.

    A fresh connection is opened for this call and is always closed
    afterwards.

    Args:
        query: SQL statement using psycopg2 placeholders (e.g. ``%s``).
        params_list: One parameter tuple per execution.

    Returns:
        int: ``cursor.rowcount`` after ``executemany`` (rows affected).

    Raises:
        psycopg2.Error: If any execution or the commit fails. The error is
            logged first, then the transaction is rolled back (when the
            connection is still open), and the error is re-raised.
            Connection failures are raised (and logged) by get_db_connection.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.executemany(query, params_list)
            rows_affected = cursor.rowcount
        conn.commit()
        return rows_affected

    except psycopg2.Error as e:
        # Log before rolling back so the original error is always recorded.
        logger.error(f"Batch write error: {e}")
        # If the server dropped the connection, conn.closed is non-zero and
        # rollback() would raise InterfaceError, masking the real error.
        if not conn.closed:
            conn.rollback()
        raise

    finally:
        close_connection(conn)
|
|
|
|
|
|
|
|
def test_connection() -> bool:
    """
    Check that the database is reachable by querying the server version.

    The version string is logged on success. The connection is always closed,
    including when the query fails.

    Returns:
        bool: True if a connection was made and ``SELECT version();``
        succeeded, False otherwise (the failure is logged, never raised).
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cursor:
            cursor.execute("SELECT version();")
            version = cursor.fetchone()
        # Rows are dict-like because get_db_connection uses RealDictCursor.
        logger.info(f"PostgreSQL version: {version['version']}")
        return True
    except Exception as e:
        # Deliberately broad: this is a health check that reports via bool.
        logger.error(f"Connection test failed: {e}")
        return False
    finally:
        # close_connection ignores None, so this is safe if connecting failed.
        close_connection(conn)
|
|
|