|
|
import duckdb |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
from typing import Dict, Any, Optional, List |
|
|
from backend.core.data_catalog import get_data_catalog |
|
|
|
|
|
# Module logger, named after this module's import path so handlers/levels can target it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class GeoEngine:
    """
    Singleton wrapper around one in-memory DuckDB connection with the
    ``spatial`` extension loaded.

    Responsibilities:
      * lazily load catalog files into DuckDB tables (``ensure_table_loaded``),
      * register GeoJSON objects as tables (``register_layer``),
      * run SQL and convert the result into a GeoJSON FeatureCollection
        (``execute_spatial_query``).

    Every ``GeoEngine()`` call returns the same instance; ``__init__`` only
    performs setup the first time (guarded by ``initialized``).
    """

    _instance = None

    # Result-column names treated as the geometry column, checked before
    # falling back to a column whose DuckDB type is GEOMETRY.
    _GEOMETRY_COLUMN_NAMES = ('geom', 'geometry')

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(GeoEngine, cls).__new__(cls)
            # Checked by __init__ so repeated GeoEngine() calls skip setup.
            cls._instance.initialized = False
        return cls._instance

    def __init__(self):
        """
        Open the in-memory DuckDB connection, install/load the spatial
        extension and load the catalog's base tables. Runs only once.

        Raises:
            Exception: re-raised unchanged if the connection or the spatial
                extension cannot be set up.
        """
        if self.initialized:
            return

        logger.info("Initializing GeoEngine (DuckDB)...")
        try:
            self.con = duckdb.connect(database=':memory:')
            self.con.install_extension('spatial')
            self.con.load_extension('spatial')
            logger.info("GeoEngine initialized with Spatial extension.")
        except Exception as e:
            logger.error(f"Failed to initialize GeoEngine: {e}")
            raise

        # layer_id -> DuckDB table name, filled by register_layer().
        self.layers = {}
        self.catalog = get_data_catalog()
        self.base_tables_loaded = False
        self.initialized = True

        self.initialize_base_tables()

    # ------------------------------------------------------------------ #
    # SQL quoting helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"' + str(name).replace('"', '""') + '"'

    @staticmethod
    def _quote_literal(value: Any) -> str:
        """Return ``str(value)`` as a single-quoted SQL string literal (embedded quotes doubled)."""
        return "'" + str(value).replace("'", "''") + "'"

    # ------------------------------------------------------------------ #
    # Table loading
    # ------------------------------------------------------------------ #
    def _base_table_names(self) -> List[str]:
        """Names of catalog entries whose metadata ``category`` is ``'base'``."""
        return [
            name for name, meta in self.catalog.catalog.items()
            if meta.get('category') == 'base'
        ]

    def initialize_base_tables(self):
        """
        Load essential administrative boundary files into DuckDB tables.

        Loads every catalog entry with category ``'base'``. Runs once; later
        calls return immediately. A table that fails to load is logged by
        ``ensure_table_loaded`` and does not stop the remaining loads.
        """
        if self.base_tables_loaded:
            return

        logger.info("Loading base tables into DuckDB...")
        for table_name in self._base_table_names():
            self.ensure_table_loaded(table_name)

        self.base_tables_loaded = True
        logger.info("Base tables loaded.")

    def ensure_table_loaded(self, table_name: str) -> bool:
        """
        Ensure a table is loaded in DuckDB. If not, load it from the catalog.

        Args:
            table_name: Table / catalog entry name. It is quoted as an SQL
                identifier, so reserved words or unusual characters are safe.

        Returns:
            True if the table already existed or was loaded successfully;
            False if the catalog has no file for it, the file is missing, or
            loading failed.
        """
        quoted_name = self._quote_ident(table_name)
        try:
            self.con.execute(f"DESCRIBE {quoted_name}")
            return True
        except duckdb.Error:
            # DESCRIBE fails when the table does not exist yet; load it below.
            pass

        file_path = self.catalog.get_file_path(table_name)
        if not file_path or not file_path.exists():
            logger.warning(f"Table {table_name} not found in catalog or file missing.")
            return False

        try:
            logger.info(f"Lazy loading table: {table_name}")
            self.con.execute(
                f"CREATE OR REPLACE TABLE {quoted_name} AS "
                f"SELECT * FROM ST_Read({self._quote_literal(file_path)})"
            )
            return True
        except Exception as e:
            logger.error(f"Failed to load {table_name}: {e}")
            return False

    # ------------------------------------------------------------------ #
    # Introspection
    # ------------------------------------------------------------------ #
    def get_table_schemas(self) -> str:
        """
        Get schema of currently loaded tables for LLM context.

        Returns a text block with, per table, its row count and a list of
        ``name: type`` columns (the ``geom`` column is labelled as spatial
        data). Tables whose schema cannot be read are skipped with a warning.
        """
        parts = ["Currently Loaded Tables:\n\n"]

        try:
            tables = self.con.execute("SHOW TABLES").fetchall()
        except Exception as e:
            logger.error(f"Error getting schemas: {e}")
            return "".join(parts)

        for table in tables:
            table_name = table[0]
            quoted_name = self._quote_ident(table_name)
            try:
                columns = self.con.execute(f"DESCRIBE {quoted_name}").fetchall()
                row_count = self.con.execute(f"SELECT COUNT(*) FROM {quoted_name}").fetchone()[0]
            except Exception as e:
                # Best effort: one unreadable table must not hide the others.
                logger.warning(f"Skipping schema for {table_name}: {e}")
                continue

            parts.append(f"### {table_name} ({row_count} rows)\n")
            parts.append("Columns:\n")
            for col in columns:
                col_name, col_type = col[0], col[1]
                if col_name == 'geom':
                    parts.append(" - geom: GEOMETRY (spatial data)\n")
                else:
                    parts.append(f" - {col_name}: {col_type}\n")
            parts.append("\n")

        return "".join(parts)

    def get_table_list(self) -> List[str]:
        """
        Return list of all available table names.

        The catalog's base tables come first, then the table names of
        registered layers. (Previously this read a non-existent
        ``BASE_TABLES`` attribute and always raised AttributeError.)
        """
        tables = self._base_table_names()
        tables.extend(self.layers.values())
        return tables

    # ------------------------------------------------------------------ #
    # Layers
    # ------------------------------------------------------------------ #
    def register_layer(self, layer_id: str, geojson: Dict[str, Any]) -> str:
        """
        Registers a GeoJSON object as a table in DuckDB.

        The GeoJSON is written to a temporary ``.json`` file, read with
        ``ST_Read`` into table ``layer_<layer_id with '-' -> '_'>`` (any
        existing table of that name is dropped first), and the temporary
        file is always removed afterwards, even on failure.

        Args:
            layer_id: Identifier of the layer; also the key in ``self.layers``.
            geojson: GeoJSON mapping. Values exposing ``isoformat()`` (e.g.
                dates) are serialized via that method.

        Returns:
            The DuckDB table name.

        Raises:
            Exception: re-raised (after logging) if serialization or loading fails.
        """
        import tempfile

        table_name = f"layer_{layer_id.replace('-', '_')}"
        quoted_name = self._quote_ident(table_name)

        self.con.execute(f"DROP TABLE IF EXISTS {quoted_name}")

        def json_serial(obj):
            """JSON serializer for objects not serializable by default json code"""
            if hasattr(obj, 'isoformat'):
                return obj.isoformat()
            raise TypeError(f"Type {type(obj)} not serializable")

        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
                # Record the path first so the finally-block can clean up
                # even if json.dump fails midway.
                tmp_path = tmp.name
                json.dump(geojson, tmp, default=json_serial)

            self.con.execute(
                f"CREATE TABLE {quoted_name} AS "
                f"SELECT * FROM ST_Read({self._quote_literal(tmp_path)})"
            )

            self.layers[layer_id] = table_name
            logger.info(f"Registered layer {layer_id} as table {table_name}")
            return table_name

        except Exception as e:
            logger.error(f"Error registering layer {layer_id}: {e}")
            raise
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    # Cleanup is best effort; never mask the real outcome.
                    pass

    # ------------------------------------------------------------------ #
    # Querying
    # ------------------------------------------------------------------ #
    @classmethod
    def _pick_geometry_column(cls, columns: List[tuple]) -> Optional[str]:
        """
        Choose the geometry column from DESCRIBE rows ``(name, type, ...)``.

        The first column (in result order) named ``geom`` or ``geometry``
        wins; otherwise the first column whose type starts with GEOMETRY.
        Returns None when no column qualifies.
        """
        by_name = next(
            (c[0] for c in columns if c[0] in cls._GEOMETRY_COLUMN_NAMES), None
        )
        if by_name is not None:
            return by_name
        return next(
            (c[0] for c in columns if str(c[1]).upper().startswith('GEOMETRY')), None
        )

    @staticmethod
    def _rows_to_features(
        rows: List[tuple],
        property_names: List[str],
        has_geometry: bool = True,
    ) -> List[Dict[str, Any]]:
        """
        Convert result rows into GeoJSON Feature dicts.

        When *has_geometry* is true, ``row[0]`` is a GeoJSON geometry string
        (or None, which yields a null geometry) and the remaining values map
        positionally onto *property_names*. Otherwise every value is a
        property and each feature's geometry is None.
        """
        offset = 1 if has_geometry else 0
        features = []
        for row in rows:
            geometry = None
            if has_geometry and row[0] is not None:
                geometry = json.loads(row[0])
            features.append({
                "type": "Feature",
                "geometry": geometry,
                "properties": dict(zip(property_names, row[offset:])),
            })
        return features

    def execute_spatial_query(self, sql: str) -> Dict[str, Any]:
        """
        Executes a SQL query and returns the result as a GeoJSON FeatureCollection.

        The geometry column is the first one named ``geom``/``geometry``,
        otherwise the first GEOMETRY-typed column; all other columns become
        feature properties. NULL geometries become ``"geometry": null``. If
        the result has no geometry column at all, every column is a property
        and every feature has a null geometry (previously this case failed
        with an obscure SQL error).

        NOTE(security): *sql* is executed verbatim; callers must only pass
        trusted / already-validated SQL.

        Raises:
            Exception: re-raised (after logging) on any DuckDB or decoding error.
        """
        try:
            logger.info(f"Executing Spatial SQL: {sql}")

            # Materialize once so the columns can be inspected before
            # selecting them back out.
            self.con.execute(f"CREATE OR REPLACE TEMP TABLE query_result AS {sql}")

            columns = self.con.execute("DESCRIBE query_result").fetchall()
            geom_col = self._pick_geometry_column(columns)
            other_cols = [c[0] for c in columns if c[0] != geom_col]

            select_items = [self._quote_ident(c) for c in other_cols]
            if geom_col is not None:
                select_items.insert(0, f"ST_AsGeoJSON({self._quote_ident(geom_col)})")

            rows = self.con.execute(
                f"SELECT {', '.join(select_items)} FROM query_result"
            ).fetchall()

            features = self._rows_to_features(
                rows, other_cols, has_geometry=geom_col is not None
            )

            return {
                "type": "FeatureCollection",
                "features": features,
                "properties": {}
            }

        except Exception as e:
            logger.error(f"Spatial query failed: {e}")
            raise

    def get_table_name(self, layer_id: str) -> Optional[str]:
        """Return the DuckDB table registered for *layer_id*, or None if unknown."""
        return self.layers.get(layer_id)
|
|
|
|
|
# Cached engine handed out by get_geo_engine(); created lazily on first call.
_geo_engine = None


def get_geo_engine() -> GeoEngine:
    """Return the shared GeoEngine, constructing it the first time it is requested."""
    global _geo_engine
    if _geo_engine is not None:
        return _geo_engine
    _geo_engine = GeoEngine()
    return _geo_engine
|
|
|