Nolwenn
committed on
Commit
·
dbb98e1
1
Parent(s):
7aa20fe
Change to docker space
Browse files- Dockerfile +19 -0
- README.md +3 -4
- crs_arena/arena.py +6 -17
- crs_arena/utils.py +6 -49
- download_external_data.py +56 -0
- requirements.txt +1 -1
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile to run CRS Arena
|
| 2 |
+
FROM python:3.9-bullseye
|
| 3 |
+
|
| 4 |
+
COPY . .
|
| 5 |
+
|
| 6 |
+
# Install requirements
|
| 7 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 8 |
+
|
| 9 |
+
# Expose Hugging Face Space secrets to environment variables
|
| 10 |
+
RUN --mount=type=secret,id=models_folder_url,mode=0444,required=true echo "MODELS_FOLDER_URL=$(cat /run/secrets/models_folder_url)" >> .env
|
| 11 |
+
RUN --mount=type=secret,id=item_embeddings_url,mode=0444,required=true echo "ITEM_EMBEDDINGS_URL=$(cat /run/secrets/item_embeddings_url)" >> .env
|
| 12 |
+
|
| 13 |
+
# Download external data
|
| 14 |
+
RUN python download_external_data.py
|
| 15 |
+
|
| 16 |
+
EXPOSE 7860
|
| 17 |
+
|
| 18 |
+
# Run Streamlit app
|
| 19 |
+
CMD ["python", "-m", "streamlit", "run", "crs_arena.arena", "--server.port", "7860"]
|
README.md
CHANGED
|
@@ -3,11 +3,10 @@ title: CRSArena
|
|
| 3 |
emoji: 🐠
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: yellow
|
| 6 |
-
sdk:
|
| 7 |
-
|
| 8 |
-
app_file: crs_arena/arena.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
-
Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
|
|
|
|
| 3 |
emoji: 🐠
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: yellow
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
|
|
|
| 8 |
pinned: false
|
| 9 |
license: mit
|
| 10 |
---
|
| 11 |
|
| 12 |
+
Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
|
crs_arena/arena.py
CHANGED
|
@@ -22,8 +22,6 @@ import asyncio
|
|
| 22 |
import json
|
| 23 |
import logging
|
| 24 |
import os
|
| 25 |
-
import threading
|
| 26 |
-
import time
|
| 27 |
from copy import deepcopy
|
| 28 |
from datetime import datetime
|
| 29 |
from typing import Dict, List
|
|
@@ -37,12 +35,7 @@ from battle_manager import (
|
|
| 37 |
)
|
| 38 |
from crs_fighter import CRSFighter
|
| 39 |
from streamlit_lottie import st_lottie_spinner
|
| 40 |
-
from utils import
|
| 41 |
-
download_and_extract_item_embeddings,
|
| 42 |
-
download_and_extract_models,
|
| 43 |
-
upload_conversation_logs_to_hf,
|
| 44 |
-
upload_feedback_to_gsheet,
|
| 45 |
-
)
|
| 46 |
|
| 47 |
from src.model.crb_crs.recommender import *
|
| 48 |
|
|
@@ -56,14 +49,6 @@ logging.basicConfig(
|
|
| 56 |
logger = logging.getLogger(__name__)
|
| 57 |
logger.setLevel(logging.INFO)
|
| 58 |
|
| 59 |
-
# Download models and data externally stored if not already downloaded
|
| 60 |
-
if not os.path.exists("data/models"):
|
| 61 |
-
logger.info("Downloading models...")
|
| 62 |
-
download_and_extract_models()
|
| 63 |
-
if not os.path.exists("data/embed_items"):
|
| 64 |
-
logger.info("Downloading item embeddings...")
|
| 65 |
-
download_and_extract_item_embeddings()
|
| 66 |
-
|
| 67 |
# Create the conversation logs directory
|
| 68 |
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
|
| 69 |
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)
|
|
@@ -82,7 +67,9 @@ def record_vote(vote: str) -> None:
|
|
| 82 |
crs1_model: CRSFighter = st.session_state["crs1"]
|
| 83 |
crs2_model: CRSFighter = st.session_state["crs2"]
|
| 84 |
last_row_id = str(datetime.now())
|
| 85 |
-
logger.info(
|
|
|
|
|
|
|
| 86 |
asyncio.run(
|
| 87 |
upload_feedback_to_gsheet(
|
| 88 |
{
|
|
@@ -189,6 +176,7 @@ def get_crs_response(crs: CRSFighter, message: str) -> str:
|
|
| 189 |
# time.sleep(0.05)
|
| 190 |
return response
|
| 191 |
|
|
|
|
| 192 |
@st.dialog("Your vote has been submitted! Thank you!")
|
| 193 |
def feedback_dialog(row_id: int) -> None:
|
| 194 |
"""Pop-up dialog to provide feedback after voting.
|
|
@@ -208,6 +196,7 @@ def feedback_dialog(row_id: int) -> None:
|
|
| 208 |
st.session_state.clear()
|
| 209 |
st.rerun()
|
| 210 |
|
|
|
|
| 211 |
@st.fragment
|
| 212 |
def chat_col(crs_id: int, color: str):
|
| 213 |
"""Chat column for the CRS model.
|
|
|
|
| 22 |
import json
|
| 23 |
import logging
|
| 24 |
import os
|
|
|
|
|
|
|
| 25 |
from copy import deepcopy
|
| 26 |
from datetime import datetime
|
| 27 |
from typing import Dict, List
|
|
|
|
| 35 |
)
|
| 36 |
from crs_fighter import CRSFighter
|
| 37 |
from streamlit_lottie import st_lottie_spinner
|
| 38 |
+
from utils import upload_conversation_logs_to_hf, upload_feedback_to_gsheet
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
from src.model.crb_crs.recommender import *
|
| 41 |
|
|
|
|
| 49 |
logger = logging.getLogger(__name__)
|
| 50 |
logger.setLevel(logging.INFO)
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
# Create the conversation logs directory
|
| 53 |
CONVERSATION_LOG_DIR = "data/arena/conversation_logs/"
|
| 54 |
os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)
|
|
|
|
| 67 |
crs1_model: CRSFighter = st.session_state["crs1"]
|
| 68 |
crs2_model: CRSFighter = st.session_state["crs2"]
|
| 69 |
last_row_id = str(datetime.now())
|
| 70 |
+
logger.info(
|
| 71 |
+
f"Vote: {last_row_id}, {user_id}, {crs1_model.name}, {crs2_model.name}, {vote}"
|
| 72 |
+
)
|
| 73 |
asyncio.run(
|
| 74 |
upload_feedback_to_gsheet(
|
| 75 |
{
|
|
|
|
| 176 |
# time.sleep(0.05)
|
| 177 |
return response
|
| 178 |
|
| 179 |
+
|
| 180 |
@st.dialog("Your vote has been submitted! Thank you!")
|
| 181 |
def feedback_dialog(row_id: int) -> None:
|
| 182 |
"""Pop-up dialog to provide feedback after voting.
|
|
|
|
| 196 |
st.session_state.clear()
|
| 197 |
st.rerun()
|
| 198 |
|
| 199 |
+
|
| 200 |
@st.fragment
|
| 201 |
def chat_col(crs_id: int, color: str):
|
| 202 |
"""Chat column for the CRS model.
|
crs_arena/utils.py
CHANGED
|
@@ -4,16 +4,13 @@ import ast
|
|
| 4 |
import asyncio
|
| 5 |
import logging
|
| 6 |
import os
|
| 7 |
-
import sqlite3
|
| 8 |
import sys
|
| 9 |
-
import tarfile
|
| 10 |
from datetime import timedelta
|
| 11 |
-
from typing import
|
| 12 |
|
| 13 |
import openai
|
| 14 |
import pandas as pd
|
| 15 |
import streamlit as st
|
| 16 |
-
import wget
|
| 17 |
import yaml
|
| 18 |
from huggingface_hub import HfApi
|
| 19 |
from streamlit_gsheets.gsheets_connection import GSheetsServiceAccountClient
|
|
@@ -23,7 +20,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|
| 23 |
from src.model.crs_model import CRSModel
|
| 24 |
|
| 25 |
# Initialize Hugging Face API
|
| 26 |
-
HF_API = HfApi(token=
|
| 27 |
|
| 28 |
|
| 29 |
@st.cache_resource(
|
|
@@ -52,52 +49,12 @@ def get_crs_model(model_name: str, model_config_file: str) -> CRSModel:
|
|
| 52 |
model_args = yaml.safe_load(open(model_config_file, "r"))
|
| 53 |
|
| 54 |
if "chatgpt" in model_name:
|
| 55 |
-
openai.api_key =
|
| 56 |
|
| 57 |
# Extract crs model from name
|
| 58 |
name = model_name.split("_")[0]
|
| 59 |
|
| 60 |
-
return CRSModel(name, **model_args)
|
| 61 |
-
)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def download_and_extract_models() -> None:
|
| 65 |
-
"""Downloads the models folder from the server and extracts it."""
|
| 66 |
-
logging.debug("Downloading models folder.")
|
| 67 |
-
models_url = st.secrets["models_folder_url"]
|
| 68 |
-
models_targz = "models.tar.gz"
|
| 69 |
-
models_folder = "data/models/"
|
| 70 |
-
try:
|
| 71 |
-
wget.download(models_url, models_targz)
|
| 72 |
-
|
| 73 |
-
logging.debug("Extracting models folder.")
|
| 74 |
-
with tarfile.open(models_targz, "r:gz") as tar:
|
| 75 |
-
tar.extractall(models_folder)
|
| 76 |
-
|
| 77 |
-
os.remove(models_targz)
|
| 78 |
-
logging.debug("Models folder downloaded and extracted.")
|
| 79 |
-
except Exception as e:
|
| 80 |
-
logging.error(f"Error downloading models folder: {e}")
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def download_and_extract_item_embeddings() -> None:
|
| 84 |
-
"""Downloads the item embeddings folder from the server and extracts it."""
|
| 85 |
-
logging.debug("Downloading item embeddings folder.")
|
| 86 |
-
item_embeddings_url = st.secrets["item_embeddings_url"]
|
| 87 |
-
item_embeddings_tarbz = "item_embeddings.tar.bz2"
|
| 88 |
-
item_embeddings_folder = "data/"
|
| 89 |
-
|
| 90 |
-
try:
|
| 91 |
-
wget.download(item_embeddings_url, item_embeddings_tarbz)
|
| 92 |
-
|
| 93 |
-
logging.debug("Extracting item embeddings folder.")
|
| 94 |
-
with tarfile.open(item_embeddings_tarbz, "r:bz2") as tar:
|
| 95 |
-
tar.extractall(item_embeddings_folder)
|
| 96 |
-
|
| 97 |
-
os.remove(item_embeddings_tarbz)
|
| 98 |
-
logging.debug("Item embeddings folder downloaded and extracted.")
|
| 99 |
-
except Exception as e:
|
| 100 |
-
logging.error(f"Error downloading item embeddings folder: {e}")
|
| 101 |
|
| 102 |
|
| 103 |
async def upload_conversation_logs_to_hf(
|
|
@@ -122,7 +79,7 @@ async def upload_conversation_logs_to_hf(
|
|
| 122 |
lambda: HF_API.upload_file(
|
| 123 |
path_or_fileobj=conversation_log_file_path,
|
| 124 |
path_in_repo=repo_filename,
|
| 125 |
-
repo_id=
|
| 126 |
repo_type="dataset",
|
| 127 |
),
|
| 128 |
)
|
|
@@ -164,7 +121,7 @@ def _upload_feedback_to_gsheet_sync(
|
|
| 164 |
worksheet: Name of the worksheet to upload the feedback to.
|
| 165 |
"""
|
| 166 |
gs_connection = GSheetsServiceAccountClient(
|
| 167 |
-
ast.literal_eval(
|
| 168 |
)
|
| 169 |
df = gs_connection.read(worksheet=worksheet)
|
| 170 |
if df[df["id"] == row["id"]].empty:
|
|
|
|
| 4 |
import asyncio
|
| 5 |
import logging
|
| 6 |
import os
|
|
|
|
| 7 |
import sys
|
|
|
|
| 8 |
from datetime import timedelta
|
| 9 |
+
from typing import Dict
|
| 10 |
|
| 11 |
import openai
|
| 12 |
import pandas as pd
|
| 13 |
import streamlit as st
|
|
|
|
| 14 |
import yaml
|
| 15 |
from huggingface_hub import HfApi
|
| 16 |
from streamlit_gsheets.gsheets_connection import GSheetsServiceAccountClient
|
|
|
|
| 20 |
from src.model.crs_model import CRSModel
|
| 21 |
|
| 22 |
# Initialize Hugging Face API
|
| 23 |
+
HF_API = HfApi(token=os.environ.get("hf_token"))
|
| 24 |
|
| 25 |
|
| 26 |
@st.cache_resource(
|
|
|
|
| 49 |
model_args = yaml.safe_load(open(model_config_file, "r"))
|
| 50 |
|
| 51 |
if "chatgpt" in model_name:
|
| 52 |
+
openai.api_key = os.environ.get("openai_api_key")
|
| 53 |
|
| 54 |
# Extract crs model from name
|
| 55 |
name = model_name.split("_")[0]
|
| 56 |
|
| 57 |
+
return CRSModel(name, **model_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
async def upload_conversation_logs_to_hf(
|
|
|
|
| 79 |
lambda: HF_API.upload_file(
|
| 80 |
path_or_fileobj=conversation_log_file_path,
|
| 81 |
path_in_repo=repo_filename,
|
| 82 |
+
repo_id=os.environ.get("dataset_repo"),
|
| 83 |
repo_type="dataset",
|
| 84 |
),
|
| 85 |
)
|
|
|
|
| 121 |
worksheet: Name of the worksheet to upload the feedback to.
|
| 122 |
"""
|
| 123 |
gs_connection = GSheetsServiceAccountClient(
|
| 124 |
+
ast.literal_eval(os.environ.get("gsheet"))
|
| 125 |
)
|
| 126 |
df = gs_connection.read(worksheet=worksheet)
|
| 127 |
if df[df["id"] == row["id"]].empty:
|
download_external_data.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Script to download external data for the project at build time."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import tarfile
|
| 6 |
+
|
| 7 |
+
import wget
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def download_and_extract_models() -> None:
|
| 11 |
+
"""Downloads the models folder from the server and extracts it."""
|
| 12 |
+
logging.debug("Downloading models folder.")
|
| 13 |
+
models_url = os.environ.get("MODELS_FOLDER_URL")
|
| 14 |
+
models_targz = "models.tar.gz"
|
| 15 |
+
models_folder = "data/models/"
|
| 16 |
+
try:
|
| 17 |
+
wget.download(models_url, models_targz)
|
| 18 |
+
|
| 19 |
+
logging.debug("Extracting models folder.")
|
| 20 |
+
with tarfile.open(models_targz, "r:gz") as tar:
|
| 21 |
+
tar.extractall(models_folder)
|
| 22 |
+
|
| 23 |
+
os.remove(models_targz)
|
| 24 |
+
logging.debug("Models folder downloaded and extracted.")
|
| 25 |
+
except Exception as e:
|
| 26 |
+
logging.error(f"Error downloading models folder: {e}")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def download_and_extract_item_embeddings() -> None:
|
| 30 |
+
"""Downloads the item embeddings folder from the server and extracts it."""
|
| 31 |
+
logging.debug("Downloading item embeddings folder.")
|
| 32 |
+
item_embeddings_url = os.environ.get("ITEM_EMBEDDINGS_URL")
|
| 33 |
+
item_embeddings_tarbz = "item_embeddings.tar.bz2"
|
| 34 |
+
item_embeddings_folder = "data/"
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
wget.download(item_embeddings_url, item_embeddings_tarbz)
|
| 38 |
+
|
| 39 |
+
logging.debug("Extracting item embeddings folder.")
|
| 40 |
+
with tarfile.open(item_embeddings_tarbz, "r:bz2") as tar:
|
| 41 |
+
tar.extractall(item_embeddings_folder)
|
| 42 |
+
|
| 43 |
+
os.remove(item_embeddings_tarbz)
|
| 44 |
+
logging.debug("Item embeddings folder downloaded and extracted.")
|
| 45 |
+
except Exception as e:
|
| 46 |
+
logging.error(f"Error downloading item embeddings folder: {e}")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
if not os.path.exists("data/models"):
|
| 51 |
+
logging.info("Downloading models...")
|
| 52 |
+
download_and_extract_models()
|
| 53 |
+
|
| 54 |
+
if not os.path.exists("data/embed_items"):
|
| 55 |
+
logging.info("Downloading item embeddings...")
|
| 56 |
+
download_and_extract_item_embeddings()
|
requirements.txt
CHANGED
|
@@ -10,7 +10,7 @@ tiktoken==0.7.0
|
|
| 10 |
tenacity<9.0.0
|
| 11 |
thefuzz==0.22.1
|
| 12 |
numpy<2
|
| 13 |
-
streamlit==1.
|
| 14 |
SQLAlchemy==1.4.0
|
| 15 |
sent2vec==0.3.0
|
| 16 |
wget==3.2
|
|
|
|
| 10 |
tenacity<9.0.0
|
| 11 |
thefuzz==0.22.1
|
| 12 |
numpy<2
|
| 13 |
+
streamlit==1.38.0
|
| 14 |
SQLAlchemy==1.4.0
|
| 15 |
sent2vec==0.3.0
|
| 16 |
wget==3.2
|