Spaces:
Sleeping
Sleeping
deploy at 2024-08-24 17:59:30.546351
Browse files
main.py
CHANGED
|
@@ -58,6 +58,7 @@ import tempfile
|
|
| 58 |
from enum import Enum
|
| 59 |
from typing import Tuple as T
|
| 60 |
from urllib.parse import quote
|
|
|
|
| 61 |
|
| 62 |
DEV_MODE = False
|
| 63 |
|
|
@@ -165,24 +166,28 @@ class XFrameOptionsMiddleware(BaseHTTPMiddleware):
|
|
| 165 |
response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
|
| 166 |
return response
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
middlewares = [
|
| 170 |
Middleware(
|
| 171 |
SessionMiddleware,
|
| 172 |
secret_key=get_key(fname=sess_key_path),
|
| 173 |
max_age=3600,
|
| 174 |
-
same_site='None',
|
| 175 |
-
https_only=True,
|
| 176 |
-
domain=".hf.space"
|
| 177 |
),
|
| 178 |
Middleware(XFrameOptionsMiddleware),
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
allow_origins=["*"], # Allows all origins
|
| 182 |
-
allow_credentials=True,
|
| 183 |
-
allow_methods=["*"], # Allows all methods
|
| 184 |
-
allow_headers=["*"], # Allows all headers
|
| 185 |
-
),
|
| 186 |
]
|
| 187 |
bware = Beforeware(
|
| 188 |
user_auth_before,
|
|
@@ -306,7 +311,6 @@ def get(sess):
|
|
| 306 |
queries = [
|
| 307 |
"Breast Cancer Cells Feed on Cholesterol",
|
| 308 |
"Treating Asthma With Plants vs. Pills",
|
| 309 |
-
"Alkylphenol Endocrine Disruptors",
|
| 310 |
"Testing Turmeric on Smokers",
|
| 311 |
"The Role of Pesticides in Parkinson's Disease",
|
| 312 |
]
|
|
@@ -435,9 +439,10 @@ def post(login: Login, sess):
|
|
| 435 |
if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
|
| 436 |
# Incorrect password - add error message
|
| 437 |
return RedirectResponse("/login?error=True", status_code=303)
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
|
|
|
| 441 |
|
| 442 |
|
| 443 |
@app.get("/logout")
|
|
@@ -463,9 +468,26 @@ def replace_hi_with_strong(text):
|
|
| 463 |
|
| 464 |
|
| 465 |
def log_query_to_db(query, ranking, sess):
|
| 466 |
-
|
| 467 |
Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
|
| 468 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
|
| 470 |
|
| 471 |
def parse_results(records):
|
|
@@ -555,12 +577,7 @@ def get_yql(ranking: RankProfile, userquery: str) -> T[str, dict]:
|
|
| 555 |
@app.get("/search")
|
| 556 |
async def search(userquery: str, ranking: str, sess):
|
| 557 |
print(sess)
|
| 558 |
-
if "queries" not in sess:
|
| 559 |
-
sess["queries"] = []
|
| 560 |
quoted = quote(userquery) + "&ranking=" + ranking
|
| 561 |
-
sess["queries"].append(quoted)
|
| 562 |
-
print(f"Searching for: {userquery}")
|
| 563 |
-
print(f"Ranking: {ranking}")
|
| 564 |
log_query_to_db(userquery, ranking, sess)
|
| 565 |
yql, body = get_yql(ranking, userquery)
|
| 566 |
async with vespa_app.asyncio() as session:
|
|
@@ -817,12 +834,13 @@ def get_document(docid: str, sess):
|
|
| 817 |
resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
|
| 818 |
doc = resp.json
|
| 819 |
# Link with Back to search results at top of page
|
|
|
|
| 820 |
return Main(
|
| 821 |
Div(
|
| 822 |
A(
|
| 823 |
I(cls="fa fa-arrow-left"),
|
| 824 |
"Back to search results",
|
| 825 |
-
hx_get=f"/search?userquery={
|
| 826 |
hx_target="#results",
|
| 827 |
style="margin: 10px;",
|
| 828 |
),
|
|
|
|
| 58 |
from enum import Enum
|
| 59 |
from typing import Tuple as T
|
| 60 |
from urllib.parse import quote
|
| 61 |
+
import uuid
|
| 62 |
|
| 63 |
DEV_MODE = False
|
| 64 |
|
|
|
|
| 166 |
response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
|
| 167 |
return response
|
| 168 |
|
| 169 |
+
class SessionLoggingMiddleware(BaseHTTPMiddleware):
|
| 170 |
+
async def dispatch(self, request, call_next):
|
| 171 |
+
print(f"Before request: Session data: {request.session}")
|
| 172 |
+
response = await call_next(request)
|
| 173 |
+
print(f"After request: Session data: {request.session}")
|
| 174 |
+
return response
|
| 175 |
+
|
| 176 |
+
class DebugSessionMiddleware(SessionMiddleware):
|
| 177 |
+
async def __call__(self, scope, receive, send):
|
| 178 |
+
print(f"DebugSessionMiddleware: Before processing - Scope: {scope}")
|
| 179 |
+
await super().__call__(scope, receive, send)
|
| 180 |
+
print(f"DebugSessionMiddleware: After processing - Scope: {scope}")
|
| 181 |
|
| 182 |
middlewares = [
|
| 183 |
Middleware(
|
| 184 |
SessionMiddleware,
|
| 185 |
secret_key=get_key(fname=sess_key_path),
|
| 186 |
max_age=3600,
|
|
|
|
|
|
|
|
|
|
| 187 |
),
|
| 188 |
Middleware(XFrameOptionsMiddleware),
|
| 189 |
+
#Middleware(SessionLoggingMiddleware),
|
| 190 |
+
#Middleware(DebugSessionMiddleware, secret_key=get_key(fname=sess_key_path)),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
]
|
| 192 |
bware = Beforeware(
|
| 193 |
user_auth_before,
|
|
|
|
| 311 |
queries = [
|
| 312 |
"Breast Cancer Cells Feed on Cholesterol",
|
| 313 |
"Treating Asthma With Plants vs. Pills",
|
|
|
|
| 314 |
"Testing Turmeric on Smokers",
|
| 315 |
"The Role of Pesticides in Parkinson's Disease",
|
| 316 |
]
|
|
|
|
| 439 |
if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
|
| 440 |
# Incorrect password - add error message
|
| 441 |
return RedirectResponse("/login?error=True", status_code=303)
|
| 442 |
+
print(f"Session after setting auth: {sess}")
|
| 443 |
+
response = RedirectResponse("/admin", status_code=303)
|
| 444 |
+
print(f"Cookies being set: {response.headers.get('Set-Cookie')}")
|
| 445 |
+
return response
|
| 446 |
|
| 447 |
|
| 448 |
@app.get("/logout")
|
|
|
|
| 468 |
|
| 469 |
|
| 470 |
def log_query_to_db(query, ranking, sess):
|
| 471 |
+
queries.insert(
|
| 472 |
Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
|
| 473 |
)
|
| 474 |
+
if 'user_id' not in sess:
|
| 475 |
+
sess['user_id'] = str(uuid.uuid4())
|
| 476 |
+
|
| 477 |
+
if 'queries' not in sess:
|
| 478 |
+
sess['queries'] = []
|
| 479 |
+
|
| 480 |
+
query_data = {
|
| 481 |
+
'query': query,
|
| 482 |
+
'ranking': ranking,
|
| 483 |
+
'timestamp': int(time.time())
|
| 484 |
+
}
|
| 485 |
+
sess['queries'].append(query_data)
|
| 486 |
+
|
| 487 |
+
# Limit the number of queries stored in the session to prevent it from growing too large
|
| 488 |
+
sess['queries'] = sess['queries'][-100:] # Keep only the last 100 queries
|
| 489 |
+
|
| 490 |
+
return query_data
|
| 491 |
|
| 492 |
|
| 493 |
def parse_results(records):
|
|
|
|
| 577 |
@app.get("/search")
|
| 578 |
async def search(userquery: str, ranking: str, sess):
|
| 579 |
print(sess)
|
|
|
|
|
|
|
| 580 |
quoted = quote(userquery) + "&ranking=" + ranking
|
|
|
|
|
|
|
|
|
|
| 581 |
log_query_to_db(userquery, ranking, sess)
|
| 582 |
yql, body = get_yql(ranking, userquery)
|
| 583 |
async with vespa_app.asyncio() as session:
|
|
|
|
| 834 |
resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
|
| 835 |
doc = resp.json
|
| 836 |
# Link with Back to search results at top of page
|
| 837 |
+
last_query = sess.get('queries', [{}])[-1].get('query', '')
|
| 838 |
return Main(
|
| 839 |
Div(
|
| 840 |
A(
|
| 841 |
I(cls="fa fa-arrow-left"),
|
| 842 |
"Back to search results",
|
| 843 |
+
hx_get=f"/search?userquery={last_query}",
|
| 844 |
hx_target="#results",
|
| 845 |
style="margin: 10px;",
|
| 846 |
),
|