"""Persistence layer for chat sessions and their messages.

Defines the SQLAlchemy engine, the ``Session`` / ``Message`` ORM models and a
set of CRUD helpers. Each helper opens, uses and closes its own database
session and returns plain dicts, so results stay usable after the session is
closed.
"""

import logging
import os
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone

from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text, create_engine
from sqlalchemy.orm import relationship, sessionmaker

try:
    # Canonical location since SQLAlchemy 1.4; the ext.declarative import is
    # deprecated in 2.0.
    from sqlalchemy.orm import declarative_base
except ImportError:  # older SQLAlchemy
    from sqlalchemy.ext.declarative import declarative_base

logger = logging.getLogger(__name__)

# Neon DB connection
DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    logger.warning("Warning: DATABASE_URL not set. Using sqlite fallback for local dev only.")
    DATABASE_URL = "sqlite:///./local_chat.db"

# SQLAlchemy 1.4+ no longer recognizes the legacy "postgres://" scheme, so
# rewrite it to "postgresql://". DATABASE_URL is always a non-empty string here.
if DATABASE_URL.startswith("postgres://"):
    DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://", 1)

# pool_pre_ping checks each pooled connection before handing it out, so
# connections that were dropped server-side are replaced transparently.
# sslmode is usually handled in the connection string itself.
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()


def _utcnow():
    """Return the current UTC time as a naive datetime.

    Produces the same value as the deprecated ``datetime.utcnow()``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


# Models
class Session(Base):
    """A named chat session belonging to ``user_id``."""

    __tablename__ = "sessions"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = Column(String, nullable=False, index=True)
    name = Column(String, nullable=False)
    created_at = Column(DateTime, default=_utcnow)  # naive UTC

    # Deleting a Session through the ORM also deletes its messages.
    messages = relationship("Message", back_populates="session", cascade="all, delete-orphan")


class Message(Base):
    """A single message (``role`` + ``content``) stored in a chat Session."""

    __tablename__ = "messages"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Indexed because get_messages() filters on it. Note that create_all()
    # only creates this index when the table itself is newly created.
    session_id = Column(String, ForeignKey("sessions.id"), nullable=False, index=True)
    role = Column(String, nullable=False)
    content = Column(Text, nullable=False)
    created_at = Column(DateTime, default=_utcnow)  # naive UTC

    session = relationship("Session", back_populates="messages")


# Internal helpers
@contextmanager
def _session_scope():
    """Yield a new database session and always close it afterwards.

    Closing a session ends any transaction still in progress without
    committing it, so an operation that raises before ``commit()`` leaves
    no partial writes behind.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def _isoformat(value):
    """Return ``value.isoformat()``, or None when the timestamp is unset."""
    return value.isoformat() if value is not None else None


def _session_to_dict(s):
    """Serialize a Session row into a plain dict."""
    return {
        "id": s.id,
        "user_id": s.user_id,
        "name": s.name,
        "created_at": _isoformat(s.created_at),
    }


def _message_to_dict(m):
    """Serialize a Message row into a plain dict."""
    return {
        "id": m.id,
        "session_id": m.session_id,
        "role": m.role,
        "content": m.content,
        "created_at": _isoformat(m.created_at),
    }


# Database functions
def init_db():
    """Create all tables that do not exist yet.

    Existing tables are left unchanged.
    """
    Base.metadata.create_all(bind=engine)


def get_db():
    """Return a new, open database session.

    The caller owns the returned session and is responsible for calling
    ``close()`` on it when finished.
    """
    return SessionLocal()


def create_session(user_id: str, name: str = "New Chat") -> dict:
    """Create and persist a new chat session for ``user_id``.

    Returns:
        dict with keys ``id``, ``user_id``, ``name`` and ``created_at``
        (ISO-8601 string, or None).
    """
    with _session_scope() as db:
        new_session = Session(id=str(uuid.uuid4()), user_id=user_id, name=name)
        db.add(new_session)
        db.commit()
        # Reload the committed row so column defaults such as created_at
        # are populated.
        db.refresh(new_session)
        return _session_to_dict(new_session)


def get_sessions(user_id: str) -> list:
    """Return all sessions of ``user_id`` as dicts, newest first."""
    with _session_scope() as db:
        rows = (
            db.query(Session)
            .filter(Session.user_id == user_id)
            .order_by(Session.created_at.desc())
            .all()
        )
        return [_session_to_dict(s) for s in rows]


def add_message(session_id: str, role: str, content: str) -> None:
    """Append a message with the given ``role`` and ``content`` to a session."""
    with _session_scope() as db:
        db.add(Message(session_id=session_id, role=role, content=content))
        db.commit()


def get_messages(session_id: str) -> list:
    """Return all messages of a session as dicts, oldest first."""
    with _session_scope() as db:
        rows = (
            db.query(Message)
            .filter(Message.session_id == session_id)
            # created_at alone can tie; the autoincrement id keeps insertion
            # order deterministic across databases.
            .order_by(Message.created_at.asc(), Message.id.asc())
            .all()
        )
        return [_message_to_dict(m) for m in rows]


def delete_session(session_id: str) -> None:
    """Delete a session and all its messages.

    Does nothing if the session does not exist.
    """
    with _session_scope() as db:
        row = db.query(Session).filter(Session.id == session_id).first()
        if row is not None:
            # ORM delete so the cascade on Session.messages removes the messages too.
            db.delete(row)
            db.commit()