Spaces:
Running
Running
File size: 4,258 Bytes
987982d 4832b3b 192c089 32f27f1 4832b3b 192c089 4832b3b 192c089 4832b3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from sqlalchemy import create_engine, Column, String, Integer, Text, DateTime, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
from datetime import datetime
import uuid
import os
# Neon DB connection: the URL comes from the environment; when it is missing
# or empty, fall back to a local SQLite file intended for development only.
DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    print("Warning: DATABASE_URL not set. Using sqlite fallback for local dev only.")
    DATABASE_URL = "sqlite:///./local_chat.db"

# Rewrite a leading legacy "postgres://" scheme to "postgresql://", the
# dialect name SQLAlchemy expects. At this point DATABASE_URL is always a
# non-empty string, so no extra truthiness check is needed.
if DATABASE_URL.startswith("postgres://"):
    DATABASE_URL = "postgresql://" + DATABASE_URL[len("postgres://"):]

# Engine and session factory. SSL settings such as sslmode are expected to be
# part of the connection string; pool_pre_ping enables SQLAlchemy's liveness
# check on pooled connections before they are handed out.
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
Base = declarative_base()
# Models
class Session(Base):
    """ORM model for one chat session, stored in the ``sessions`` table.

    NOTE(review): this class name shadows the ``Session`` concept from
    ``sqlalchemy.orm``; database handles in this module come from
    ``SessionLocal``, not from this class.
    """
    __tablename__ = "sessions"
    # String primary key; a fresh UUID4 string is generated when none is given.
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    # Owning user; indexed since sessions are listed per user.
    user_id = Column(String, nullable=False, index=True)
    # Display name of the chat.
    name = Column(String, nullable=False)
    # Insert-time default from datetime.utcnow, i.e. a naive UTC timestamp.
    created_at = Column(DateTime, default=datetime.utcnow)
    # Messages of this session; the ORM cascade deletes them (and any
    # orphaned ones) together with the session.
    messages = relationship("Message", back_populates="session", cascade="all, delete-orphan")
class Message(Base):
    """ORM model for one chat message, stored in the ``messages`` table."""
    __tablename__ = "messages"
    # Autoincrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning session. Indexed because messages are always fetched by
    # session_id (see get_messages), mirroring the index on Session.user_id.
    # NOTE: create_all only creates indexes for tables it creates itself;
    # an already-existing table needs a migration to gain this index.
    session_id = Column(String, ForeignKey("sessions.id"), nullable=False, index=True)
    # Free-form role string supplied by the caller of add_message.
    role = Column(String, nullable=False)
    # Message body.
    content = Column(Text, nullable=False)
    # Insert-time default from datetime.utcnow, i.e. a naive UTC timestamp.
    created_at = Column(DateTime, default=datetime.utcnow)
    # Back-reference to the owning Session.
    session = relationship("Session", back_populates="messages")
# Database functions
def init_db():
    """Issue CREATE TABLE for every model registered on ``Base``."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
def get_db():
    """Return a new, open database session created by ``SessionLocal``.

    The caller owns the returned session and must call ``close()`` on it
    when finished.

    The session is deliberately not closed here: closing it in a ``finally``
    around ``return`` runs before the caller ever uses the session, so it
    cleans up nothing the caller does afterwards.

    NOTE(review): if this is used as a FastAPI dependency, the generator form
    (``yield db`` inside ``try``/``finally: db.close()``) would give automatic
    cleanup, but it changes the return type, so it is not applied here.
    """
    return SessionLocal()
def create_session(user_id: str, name: str = "New Chat"):
    """Insert a new chat session for ``user_id`` and return it as a dict.

    Args:
        user_id: owner of the new session.
        name: display name; defaults to "New Chat".

    Returns:
        Dict with keys ``id``, ``user_id``, ``name`` and ``created_at``
        (ISO-8601 string, or ``None`` if no timestamp was stored).
    """
    db = SessionLocal()
    try:
        record = Session(id=str(uuid.uuid4()), user_id=user_id, name=name)
        db.add(record)
        db.commit()
        # Reload so server/default-populated columns such as created_at are present.
        db.refresh(record)
        stamp = record.created_at
        payload = {
            "id": record.id,
            "user_id": record.user_id,
            "name": record.name,
            "created_at": None if not stamp else stamp.isoformat(),
        }
        return payload
    finally:
        db.close()
def get_sessions(user_id: str):
    """Return every session owned by ``user_id`` as dicts, newest first.

    Each dict has keys ``id``, ``user_id``, ``name`` and ``created_at``
    (ISO-8601 string, or ``None``). An unknown user yields an empty list.
    """
    db = SessionLocal()
    try:
        newest_first = Session.created_at.desc()
        rows = (
            db.query(Session)
            .filter(Session.user_id == user_id)
            .order_by(newest_first)
            .all()
        )

        def as_dict(row):
            stamp = row.created_at
            return {
                "id": row.id,
                "user_id": row.user_id,
                "name": row.name,
                "created_at": stamp.isoformat() if stamp else None,
            }

        return [as_dict(row) for row in rows]
    finally:
        db.close()
def add_message(session_id: str, role: str, content: str):
    """Persist a single message in the session ``session_id``.

    Commits immediately and returns ``None``.
    """
    db = SessionLocal()
    try:
        db.add(Message(session_id=session_id, role=role, content=content))
        db.commit()
    finally:
        db.close()
def get_messages(session_id: str):
    """Return all messages of a chat session as dicts, oldest first.

    Args:
        session_id: id of the session whose messages are loaded.

    Returns:
        List of dicts with keys ``id``, ``session_id``, ``role``,
        ``content`` and ``created_at`` (ISO-8601 string, or ``None``).
        Empty when the session has no messages or does not exist.
    """
    db = SessionLocal()
    try:
        messages = (
            db.query(Message)
            .filter(Message.session_id == session_id)
            # created_at alone can tie when messages are inserted within the
            # clock's resolution; the autoincrement id breaks such ties so the
            # order is deterministic and normally matches insertion order.
            .order_by(Message.created_at.asc(), Message.id.asc())
            .all()
        )
        return [
            {
                "id": m.id,
                "session_id": m.session_id,
                "role": m.role,
                "content": m.content,
                "created_at": m.created_at.isoformat() if m.created_at else None,
            }
            for m in messages
        ]
    finally:
        db.close()
def delete_session(session_id: str):
    """Delete the session ``session_id`` together with its messages.

    Unknown ids are silently ignored. Messages are removed through the ORM
    cascade configured on ``Session.messages``.
    """
    db = SessionLocal()
    try:
        target = db.query(Session).filter(Session.id == session_id).first()
        if target is None:
            return
        db.delete(target)
        db.commit()
    finally:
        db.close()
|