VLM-Lens / src /utils.py
marstin's picture
[martin-dev] add demo v1 test
d425e71
raw
history blame
1.64 kB
"""Utility functions for interacting with the SQLite database."""
import io
import logging
import sqlite3
from typing import Any, List, Optional
import torch
def select_tensors(
    db_path: str,
    table_name: str,
    keys: Optional[List[str]] = None,
    sql_where: Optional[str] = None,
    params: Any = (),
) -> List[dict]:
    """Select rows from a SQLite table and deserialize their stored tensors.

    Each selected row is returned as a dict mapping column name to value. The
    ``'tensor'`` column's bytes are passed to ``torch.load`` (onto the CPU) and
    replaced by the loaded object.

    Args:
        db_path (str): Path to the SQLite database file.
        table_name (str): Name of the table to query. Interpolated verbatim
            into the SQL, so it must come from a trusted source.
        keys (Optional[List[str]]): Columns to select. Defaults to
            ``['layer', 'pooling_method', 'tensor_dim', 'tensor']``. If
            ``'tensor'`` is missing it is appended (with a warning) to a copy;
            the caller's list is never modified. Interpolated verbatim into
            the SQL, so entries must come from a trusted source.
        sql_where (Optional[str]): Optional clause appended after the table
            name; must start with ``WHERE`` (case-insensitive). Prefer ``?`` /
            ``:name`` placeholders with ``params`` for untrusted values.
        params: Values bound to placeholders in ``sql_where`` (a sequence for
            ``?`` placeholders or a mapping for ``:name``). Defaults to none.

    Returns:
        List[dict]: One dict per row, keyed by ``keys``, where ``'tensor'``
        holds the object deserialized by ``torch.load``.

    Raises:
        ValueError: If ``sql_where`` is given but does not start with ``WHERE``.
    """
    if keys is None:
        keys = ['layer', 'pooling_method', 'tensor_dim', 'tensor']
    else:
        # Work on a copy so appending 'tensor' never mutates the caller's list.
        keys = list(keys)
    if 'tensor' not in keys:
        logging.warning("'tensor' key should be included to retrieve tensors; automatically adding it.")
        keys.append('tensor')

    # NOTE: SQL identifiers cannot be bound as parameters, so table_name, keys
    # and sql_where are interpolated as-is; never pass untrusted input here.
    query = f'SELECT {", ".join(keys)} FROM {table_name}'
    if sql_where:
        # Raise instead of assert so the check is not stripped under `python -O`.
        if not sql_where.strip().lower().startswith('where'):
            raise ValueError("sql_where should start with 'WHERE'")
        query += f' {sql_where}'

    # sqlite3.Connection's context manager only commits/rolls back and does
    # NOT close the connection, so close it explicitly.
    connection = sqlite3.connect(db_path)
    try:
        rows = connection.execute(query, params).fetchall()
    finally:
        connection.close()

    final_results = []
    for row in rows:
        result_item = dict(zip(keys, row))
        # SECURITY: torch.load unpickles the blob; only read databases you trust.
        result_item['tensor'] = torch.load(io.BytesIO(result_item['tensor']), map_location='cpu')
        final_results.append(result_item)
    return final_results