Spaces:
Build error
Build error
Update src/chat.py
Browse files- src/chat.py +39 -7
src/chat.py
CHANGED
|
@@ -7,18 +7,16 @@ import re
|
|
| 7 |
from difflib import get_close_matches
|
| 8 |
|
| 9 |
class SchoolChatbot:
|
| 10 |
-
"""
|
| 11 |
-
A chatbot that integrates structured school data and language generation to assist with Boston Public School queries.
|
| 12 |
-
"""
|
| 13 |
|
| 14 |
def __init__(self):
|
| 15 |
model_id = MY_MODEL if MY_MODEL else BASE_MODEL
|
| 16 |
self.client = InferenceClient(model=model_id, token=HF_TOKEN)
|
| 17 |
self.df = pd.read_csv("bps_data.csv")
|
| 18 |
-
with open("
|
| 19 |
self.keyword_map = json.load(f)
|
| 20 |
|
| 21 |
-
#
|
| 22 |
self.school_name_map = {}
|
| 23 |
for _, row in self.df.iterrows():
|
| 24 |
primary = row.get("BPS_School_Name")
|
|
@@ -31,13 +29,12 @@ class SchoolChatbot:
|
|
| 31 |
if pd.notna(abbrev):
|
| 32 |
self.school_name_map[abbrev.lower()] = primary
|
| 33 |
|
| 34 |
-
# Add custom aliases
|
| 35 |
self.school_name_map.update({
|
| 36 |
"acc": "Another Course to College*",
|
| 37 |
"baldwin": "Baldwin Early Learning Pilot Academy",
|
| 38 |
"adams elementary": "Adams, Samuel Elementary",
|
| 39 |
"alighieri montessori": "Alighieri, Dante Montessori School",
|
| 40 |
-
"phineas bates": "Bates, Phineas Elementary"
|
| 41 |
})
|
| 42 |
|
| 43 |
def format_prompt(self, user_input):
|
|
@@ -77,7 +74,42 @@ class SchoolChatbot:
|
|
| 77 |
context_items.append(f"The school's {kw} is {val.lower()}.")
|
| 78 |
return context_items
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
def get_response(self, user_input):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
matched_school = self.match_school_name(user_input)
|
| 82 |
structured_facts = self.extract_context_with_keywords(user_input, matched_school)
|
| 83 |
|
|
|
|
| 7 |
from difflib import get_close_matches
|
| 8 |
|
| 9 |
class SchoolChatbot:
|
| 10 |
+
"""Boston School Chatbot integrating structured data, vector context, and model completion."""
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def __init__(self):
|
| 13 |
model_id = MY_MODEL if MY_MODEL else BASE_MODEL
|
| 14 |
self.client = InferenceClient(model=model_id, token=HF_TOKEN)
|
| 15 |
self.df = pd.read_csv("bps_data.csv")
|
| 16 |
+
with open("cleaned_keyword_to_column_map.json") as f:
|
| 17 |
self.keyword_map = json.load(f)
|
| 18 |
|
| 19 |
+
# Build name variants for school matching
|
| 20 |
self.school_name_map = {}
|
| 21 |
for _, row in self.df.iterrows():
|
| 22 |
primary = row.get("BPS_School_Name")
|
|
|
|
| 29 |
if pd.notna(abbrev):
|
| 30 |
self.school_name_map[abbrev.lower()] = primary
|
| 31 |
|
|
|
|
| 32 |
self.school_name_map.update({
|
| 33 |
"acc": "Another Course to College*",
|
| 34 |
"baldwin": "Baldwin Early Learning Pilot Academy",
|
| 35 |
"adams elementary": "Adams, Samuel Elementary",
|
| 36 |
"alighieri montessori": "Alighieri, Dante Montessori School",
|
| 37 |
+
"phineas bates": "Bates, Phineas Elementary",
|
| 38 |
})
|
| 39 |
|
| 40 |
def format_prompt(self, user_input):
|
|
|
|
| 74 |
context_items.append(f"The school's {kw} is {val.lower()}.")
|
| 75 |
return context_items
|
| 76 |
|
| 77 |
+
def query_schools_by_feature(self, query):
    """Return a bulleted list of schools whose data matches a feature in *query*.

    Each word of the query is fuzzy-matched (difflib, cutoff 0.85) against the
    keys of ``self.keyword_map``; every matched keyword is resolved to a column
    of ``self.df``. Rows whose value in that column reads as positive are
    selected, or the complement of those rows when the query contains a
    negation word.

    Args:
        query: Free-text user question.

    Returns:
        A string starting with "The following schools match your criteria:"
        followed by one "- <name>" line per school, sorted alphabetically, or
        ``None`` when no keyword matched or no school satisfied the filter.
    """
    tokens = re.findall(r'\b\w+\b', query.lower())

    # Materialize the candidate keywords once instead of once per token.
    known_keywords = list(self.keyword_map)
    matched_keywords = set()
    for token in tokens:
        matched_keywords.update(get_close_matches(token, known_keywords, cutoff=0.85))
    if not matched_keywords:
        return None

    # Word boundaries stop "inaccessible" / "inadequate" from matching the
    # positive terms "accessible" / "adequate" as substrings.
    positive_terms = r"\b(?:yes|accessible|adequate|good|excellent|present)\b"
    # Used to veto values such as "not accessible", which still contain a
    # positive word after the boundary check.
    negative_terms = r"\b(?:no|not accessible|inadequate|poor|bad|limited)\b"

    # Compare whole tokens, not substrings of the query: "another" contains
    # "not" and "badge" contains "bad", which must not flip the filter.
    negation_words = {"not", "inaccessible", "bad", "poor", "lacking"}
    inverse = any(token in negation_words for token in tokens)

    matching_schools = set()
    for keyword in matched_keywords:
        col = self.keyword_map.get(keyword)
        if not col or col not in self.df.columns:
            continue
        values = self.df[col].astype(str).str.lower()
        is_positive = (
            values.str.contains(positive_terms, na=False)
            & ~values.str.contains(negative_terms, na=False)
        )
        # For a negated query, keep every row that is not positive; this
        # includes rows with missing data, as the original filter did.
        subset = self.df[~is_positive] if inverse else self.df[is_positive]
        matching_schools.update(subset["BPS_School_Name"].dropna().unique().tolist())

    if not matching_schools:
        return None
    return (
        "The following schools match your criteria:\n" +
        "\n".join(f"- {s}" for s in sorted(matching_schools))
    )
|
| 105 |
+
|
| 106 |
def get_response(self, user_input):
|
| 107 |
+
# School-wide filter query
|
| 108 |
+
school_filter = self.query_schools_by_feature(user_input)
|
| 109 |
+
if school_filter:
|
| 110 |
+
return school_filter
|
| 111 |
+
|
| 112 |
+
# Per-school context query
|
| 113 |
matched_school = self.match_school_name(user_input)
|
| 114 |
structured_facts = self.extract_context_with_keywords(user_input, matched_school)
|
| 115 |
|