Spaces:
Build error
Build error
Update src/chat.py
Browse files- src/chat.py +30 -7
src/chat.py
CHANGED
|
@@ -15,9 +15,31 @@ class SchoolChatbot:
|
|
| 15 |
model_id = MY_MODEL if MY_MODEL else BASE_MODEL
|
| 16 |
self.client = InferenceClient(model=model_id, token=HF_TOKEN)
|
| 17 |
self.df = pd.read_csv("bps_data.csv")
|
| 18 |
-
with open("
|
| 19 |
self.keyword_map = json.load(f)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def format_prompt(self, user_input):
|
| 22 |
return (
|
| 23 |
"<|system|>You are a helpful assistant that specializes in Boston public school enrollment.<|end|>\n"
|
|
@@ -25,6 +47,12 @@ class SchoolChatbot:
|
|
| 25 |
"<|assistant|>"
|
| 26 |
)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def extract_context_with_keywords(self, prompt, school_name=None):
|
| 29 |
def extract_keywords(text):
|
| 30 |
tokens = re.findall(r'\b\w+\b', text.lower())
|
|
@@ -50,12 +78,7 @@ class SchoolChatbot:
|
|
| 50 |
return context_items
|
| 51 |
|
| 52 |
def get_response(self, user_input):
|
| 53 |
-
matched_school =
|
| 54 |
-
for name in self.df["BPS_School_Name"].dropna():
|
| 55 |
-
if name.lower() in user_input.lower():
|
| 56 |
-
matched_school = name
|
| 57 |
-
break
|
| 58 |
-
|
| 59 |
structured_facts = self.extract_context_with_keywords(user_input, matched_school)
|
| 60 |
|
| 61 |
if structured_facts:
|
|
|
|
| 15 |
model_id = MY_MODEL if MY_MODEL else BASE_MODEL
|
| 16 |
self.client = InferenceClient(model=model_id, token=HF_TOKEN)
|
| 17 |
self.df = pd.read_csv("bps_data.csv")
|
| 18 |
+
with open("cleaned_keyword_to_column_map.json") as f:
|
| 19 |
self.keyword_map = json.load(f)
|
| 20 |
|
| 21 |
+
# Create school name map with aliases
|
| 22 |
+
self.school_name_map = {}
|
| 23 |
+
for _, row in self.df.iterrows():
|
| 24 |
+
primary = row.get("BPS_School_Name")
|
| 25 |
+
hist = row.get("BPS_Historical_Name")
|
| 26 |
+
abbrev = row.get("SMMA_Abbreviated_Name")
|
| 27 |
+
if pd.notna(primary):
|
| 28 |
+
self.school_name_map[primary.lower()] = primary
|
| 29 |
+
if pd.notna(hist):
|
| 30 |
+
self.school_name_map[hist.lower()] = primary
|
| 31 |
+
if pd.notna(abbrev):
|
| 32 |
+
self.school_name_map[abbrev.lower()] = primary
|
| 33 |
+
|
| 34 |
+
# Add custom aliases
|
| 35 |
+
self.school_name_map.update({
|
| 36 |
+
"acc": "Another Course to College*",
|
| 37 |
+
"baldwin": "Baldwin Early Learning Pilot Academy",
|
| 38 |
+
"adams elementary": "Adams, Samuel Elementary",
|
| 39 |
+
"alighieri montessori": "Alighieri, Dante Montessori School",
|
| 40 |
+
"phineas bates": "Bates, Phineas Elementary"
|
| 41 |
+
})
|
| 42 |
+
|
| 43 |
def format_prompt(self, user_input):
|
| 44 |
return (
|
| 45 |
"<|system|>You are a helpful assistant that specializes in Boston public school enrollment.<|end|>\n"
|
|
|
|
| 47 |
"<|assistant|>"
|
| 48 |
)
|
| 49 |
|
| 50 |
+
def match_school_name(self, query):
    """Return the canonical school name mentioned in ``query``, if any.

    Looks up every alias in ``self.school_name_map`` (alias -> canonical
    name) and returns the canonical name of the best alias found in the
    query text.

    Matching rules:
      * case-insensitive (the query is lower-cased once; the aliases are
        expected to be stored lower-cased, as the constructor code that
        builds the map does);
      * whole-word only, so a short alias such as "acc" does not fire
        inside unrelated words like "accept" or "account";
      * the longest alias wins, so a specific name is preferred over a
        shorter alias that happens to be contained in it.

    Args:
        query: Free-text user input. ``None`` or an empty string yields
            ``None`` instead of raising.

    Returns:
        The canonical school name mapped to the matching alias, or
        ``None`` when no alias occurs in the query.
    """
    import re  # local import keeps this method self-contained

    if not query:
        return None

    text = query.lower()  # hoisted: lower-case the query only once

    # Longest aliases first so the most specific match is returned.
    # sorted() is stable, so equally long aliases keep insertion order.
    for alias in sorted(self.school_name_map, key=len, reverse=True):
        # Cheap substring pre-check before paying for a regex search;
        # empty aliases would match everything, so they are ignored.
        if not alias or alias not in text:
            continue
        # Look-arounds instead of \b because aliases may start or end
        # with non-word characters (e.g. a trailing "*").
        pattern = r"(?<!\w)" + re.escape(alias) + r"(?!\w)"
        if re.search(pattern, text):
            return self.school_name_map[alias]
    return None
|
| 55 |
+
|
| 56 |
def extract_context_with_keywords(self, prompt, school_name=None):
|
| 57 |
def extract_keywords(text):
|
| 58 |
tokens = re.findall(r'\b\w+\b', text.lower())
|
|
|
|
| 78 |
return context_items
|
| 79 |
|
| 80 |
def get_response(self, user_input):
|
| 81 |
+
matched_school = self.match_school_name(user_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
structured_facts = self.extract_context_with_keywords(user_input, matched_school)
|
| 83 |
|
| 84 |
if structured_facts:
|