jeevzz commited on
Commit
192c089
·
verified ·
1 Parent(s): 987982d

Update database.py

Browse files
Files changed (1) hide show
  1. database.py +13 -12
database.py CHANGED
@@ -5,13 +5,20 @@ from datetime import datetime
5
  import uuid
6
  import os
7
 
8
- # Use SQLite for reliable deployment on HF Spaces
9
- # The /data directory is persistent in HF Spaces
10
- DATABASE_URL = "sqlite:////data/chat_history.db"
 
 
 
 
 
 
 
11
 
12
  # Create engine
13
- # check_same_thread=False is needed for SQLite with FastAPI
14
- engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
15
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
16
  Base = declarative_base()
17
 
@@ -40,12 +47,6 @@ class Message(Base):
40
  # Database functions
41
  def init_db():
42
  """Create all tables"""
43
- # Ensure /data directory exists (for HF Spaces)
44
- if DATABASE_URL.startswith("sqlite:////data"):
45
- os.makedirs("/data", exist_ok=True)
46
- elif DATABASE_URL.startswith("sqlite:///./data"):
47
- os.makedirs("./data", exist_ok=True)
48
-
49
  Base.metadata.create_all(bind=engine)
50
 
51
  def get_db():
@@ -54,7 +55,7 @@ def get_db():
54
  try:
55
  return db
56
  finally:
57
- pass
58
 
59
  def create_session(user_id: str, name: str = "New Chat"):
60
  """Create a new chat session"""
 
5
  import uuid
6
  import os
7
 
8
+ # Neon DB connection
9
+ DATABASE_URL = os.getenv("DATABASE_URL")
10
+
11
+ if not DATABASE_URL:
12
+ print("Warning: DATABASE_URL not set. Using sqlite fallback for local dev only.")
13
+ DATABASE_URL = "sqlite:///./local_chat.db"
14
+
15
+ # Handle "postgres://" to "postgresql://" fix
16
+ if DATABASE_URL and DATABASE_URL.startswith("postgres://"):
17
+ DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://", 1)
18
 
19
  # Create engine
20
+ # sslmode is usually handled in the connection string itself
21
+ engine = create_engine(DATABASE_URL, pool_pre_ping=True)
22
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
23
  Base = declarative_base()
24
 
 
47
  # Database functions
48
  def init_db():
49
  """Create all tables"""
 
 
 
 
 
 
50
  Base.metadata.create_all(bind=engine)
51
 
52
  def get_db():
 
55
  try:
56
  return db
57
  finally:
58
+ db.close()
59
 
60
  def create_session(user_id: str, name: str = "New Chat"):
61
  """Create a new chat session"""