Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,7 @@ from requests.exceptions import HTTPError
|
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
from io import BytesIO
|
| 14 |
import base64
|
|
|
|
| 15 |
|
| 16 |
# ๋ก๊น
์ค์
|
| 17 |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s:%(message)s', handlers=[logging.StreamHandler()])
|
|
@@ -24,7 +25,8 @@ intents.guilds = True
|
|
| 24 |
intents.guild_messages = True
|
| 25 |
|
| 26 |
# ์ถ๋ก API ํด๋ผ์ด์ธํธ ์ค์
|
| 27 |
-
|
|
|
|
| 28 |
|
| 29 |
# ์ํ ์ ๋ฌธ LLM ํ์ดํ๋ผ์ธ ์ค์
|
| 30 |
math_pipe = pipeline("text-generation", model="AI-MO/NuminaMath-7B-TIR")
|
|
@@ -35,20 +37,19 @@ SPECIFIC_CHANNEL_ID = int(os.getenv("DISCORD_CHANNEL_ID"))
|
|
| 35 |
# ๋ํ ํ์คํ ๋ฆฌ๋ฅผ ์ ์ฅํ ์ ์ญ ๋ณ์
|
| 36 |
conversation_history = []
|
| 37 |
|
| 38 |
-
|
| 39 |
def latex_to_image(latex_string):
|
| 40 |
plt.figure(figsize=(10, 1))
|
| 41 |
plt.axis('off')
|
| 42 |
-
plt.text(0.5, 0.5,
|
| 43 |
|
| 44 |
buffer = BytesIO()
|
| 45 |
plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0.1, transparent=True, facecolor='black')
|
| 46 |
buffer.seek(0)
|
| 47 |
|
|
|
|
| 48 |
plt.close()
|
| 49 |
|
| 50 |
-
return
|
| 51 |
-
|
| 52 |
|
| 53 |
def process_and_convert_latex(text):
|
| 54 |
# ๋จ์ผ $ ๋๋ ์ด์ค $$ ๋ก ๋๋ฌ์ธ์ธ LaTeX ์์์ ์ฐพ์ต๋๋ค.
|
|
@@ -71,6 +72,7 @@ class MyClient(discord.Client):
|
|
| 71 |
super().__init__(*args, **kwargs)
|
| 72 |
self.is_processing = False
|
| 73 |
self.math_pipe = math_pipe
|
|
|
|
| 74 |
|
| 75 |
async def on_ready(self):
|
| 76 |
logging.info(f'{self.user}๋ก ๋ก๊ทธ์ธ๋์์ต๋๋ค!')
|
|
@@ -87,11 +89,8 @@ class MyClient(discord.Client):
|
|
| 87 |
|
| 88 |
self.is_processing = True
|
| 89 |
try:
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
else:
|
| 93 |
-
thread = await message.channel.create_thread(name=f"์ง๋ฌธ: {message.author.name}", message=message)
|
| 94 |
-
|
| 95 |
if self.is_math_question(message.content):
|
| 96 |
text_response = await self.handle_math_question(message.content)
|
| 97 |
await self.send_message_with_latex(thread, text_response)
|
|
@@ -100,7 +99,7 @@ class MyClient(discord.Client):
|
|
| 100 |
await self.send_message_with_latex(thread, response)
|
| 101 |
finally:
|
| 102 |
self.is_processing = False
|
| 103 |
-
|
| 104 |
def is_message_in_specific_channel(self, message):
|
| 105 |
return message.channel.id == SPECIFIC_CHANNEL_ID or (
|
| 106 |
isinstance(message.channel, discord.Thread) and message.channel.parent_id == SPECIFIC_CHANNEL_ID
|
|
@@ -119,10 +118,9 @@ class MyClient(discord.Client):
|
|
| 119 |
|
| 120 |
try:
|
| 121 |
# Cohere ๋ชจ๋ธ์๊ฒ AI-MO/NuminaMath-7B-TIR ๋ชจ๋ธ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฒ์ญํ๋๋ก ์์ฒญ
|
| 122 |
-
|
| 123 |
[{"role": "system", "content": "๋ค์ ํ
์คํธ๋ฅผ ํ๊ธ๋ก ๋ฒ์ญํ์ญ์์ค: "}, {"role": "user", "content": math_result}], max_tokens=1000))
|
| 124 |
|
| 125 |
-
cohere_response = await cohere_response_future
|
| 126 |
cohere_result = ''.join([part.choices[0].delta.content for part in cohere_response if part.choices and part.choices[0].delta and part.choices[0].delta.content])
|
| 127 |
|
| 128 |
combined_response = f"์ํ ์ ์๋ ๋ต๋ณ: ```{cohere_result}```"
|
|
@@ -150,7 +148,7 @@ class MyClient(discord.Client):
|
|
| 150 |
messages = [{"role": "system", "content": f"{system_prefix}"}] + conversation_history
|
| 151 |
|
| 152 |
try:
|
| 153 |
-
response = await
|
| 154 |
messages, max_tokens=1000, stream=True, temperature=0.7, top_p=0.85))
|
| 155 |
full_response = ''.join([part.choices[0].delta.content for part in response if part.choices and part.choices[0].delta and part.choices[0].delta.content])
|
| 156 |
conversation_history.append({"role": "assistant", "content": full_response})
|
|
@@ -162,24 +160,25 @@ class MyClient(discord.Client):
|
|
| 162 |
|
| 163 |
async def send_message_with_latex(self, channel, message):
|
| 164 |
try:
|
| 165 |
-
|
| 166 |
text_parts = re.split(r'(\$\$.*?\$\$|\$.*?\$)', message, flags=re.DOTALL)
|
| 167 |
|
| 168 |
for part in text_parts:
|
| 169 |
if part.startswith('$'):
|
| 170 |
-
|
| 171 |
latex_content = part.strip('$')
|
| 172 |
-
|
| 173 |
-
|
|
|
|
| 174 |
else:
|
| 175 |
-
|
| 176 |
if part.strip():
|
| 177 |
await self.send_long_message(channel, part.strip())
|
| 178 |
|
| 179 |
except Exception as e:
|
| 180 |
logging.error(f"Error in send_message_with_latex: {str(e)}")
|
| 181 |
await channel.send("An error occurred while processing the message.")
|
| 182 |
-
|
| 183 |
async def send_long_message(self, channel, message):
|
| 184 |
if len(message) <= 2000:
|
| 185 |
await channel.send(message)
|
|
@@ -188,19 +187,22 @@ class MyClient(discord.Client):
|
|
| 188 |
for part in parts:
|
| 189 |
await channel.send(part)
|
| 190 |
|
| 191 |
-
|
| 192 |
async def retry_request(self, func, retries=5, delay=2):
|
| 193 |
for i in range(retries):
|
| 194 |
try:
|
| 195 |
return await func()
|
| 196 |
except HTTPError as e:
|
| 197 |
-
if e.response.status_code == 503
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
else:
|
| 201 |
raise
|
| 202 |
|
| 203 |
-
|
| 204 |
if __name__ == "__main__":
|
| 205 |
discord_client = MyClient(intents=intents)
|
| 206 |
discord_client.run(os.getenv('DISCORD_TOKEN'))
|
|
|
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
from io import BytesIO
|
| 14 |
import base64
|
| 15 |
+
import time
|
| 16 |
|
| 17 |
# ๋ก๊น
์ค์
|
| 18 |
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s:%(message)s', handlers=[logging.StreamHandler()])
|
|
|
|
| 25 |
intents.guild_messages = True
|
| 26 |
|
| 27 |
# ์ถ๋ก API ํด๋ผ์ด์ธํธ ์ค์
|
| 28 |
+
hf_client_primary = InferenceClient("CohereForAI/c4ai-command-r-plus", token=os.getenv("HF_TOKEN"))
|
| 29 |
+
hf_client_secondary = InferenceClient("CohereForAI/aya-23-35B", token=os.getenv("HF_TOKEN"))
|
| 30 |
|
| 31 |
# ์ํ ์ ๋ฌธ LLM ํ์ดํ๋ผ์ธ ์ค์
|
| 32 |
math_pipe = pipeline("text-generation", model="AI-MO/NuminaMath-7B-TIR")
|
|
|
|
| 37 |
# ๋ํ ํ์คํ ๋ฆฌ๋ฅผ ์ ์ฅํ ์ ์ญ ๋ณ์
|
| 38 |
conversation_history = []
|
| 39 |
|
|
|
|
| 40 |
def latex_to_image(latex_string):
|
| 41 |
plt.figure(figsize=(10, 1))
|
| 42 |
plt.axis('off')
|
| 43 |
+
plt.text(0.5, 0.5, latex_string, size=20, ha='center', va='center', color='white')
|
| 44 |
|
| 45 |
buffer = BytesIO()
|
| 46 |
plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0.1, transparent=True, facecolor='black')
|
| 47 |
buffer.seek(0)
|
| 48 |
|
| 49 |
+
image_base64 = base64.b64encode(buffer.getvalue()).decode()
|
| 50 |
plt.close()
|
| 51 |
|
| 52 |
+
return image_base64
|
|
|
|
| 53 |
|
| 54 |
def process_and_convert_latex(text):
|
| 55 |
# ๋จ์ผ $ ๋๋ ์ด์ค $$ ๋ก ๋๋ฌ์ธ์ธ LaTeX ์์์ ์ฐพ์ต๋๋ค.
|
|
|
|
| 72 |
super().__init__(*args, **kwargs)
|
| 73 |
self.is_processing = False
|
| 74 |
self.math_pipe = math_pipe
|
| 75 |
+
self.hf_client = hf_client_primary # ์ด๊ธฐ ํด๋ผ์ด์ธํธ ์ค์
|
| 76 |
|
| 77 |
async def on_ready(self):
|
| 78 |
logging.info(f'{self.user}๋ก ๋ก๊ทธ์ธ๋์์ต๋๋ค!')
|
|
|
|
| 89 |
|
| 90 |
self.is_processing = True
|
| 91 |
try:
|
| 92 |
+
# ์๋ก์ด ์ค๋ ๋ ์์ฑ
|
| 93 |
+
thread = await message.channel.create_thread(name=f"์ง๋ฌธ: {message.author.name}", message=message)
|
|
|
|
|
|
|
|
|
|
| 94 |
if self.is_math_question(message.content):
|
| 95 |
text_response = await self.handle_math_question(message.content)
|
| 96 |
await self.send_message_with_latex(thread, text_response)
|
|
|
|
| 99 |
await self.send_message_with_latex(thread, response)
|
| 100 |
finally:
|
| 101 |
self.is_processing = False
|
| 102 |
+
|
| 103 |
def is_message_in_specific_channel(self, message):
|
| 104 |
return message.channel.id == SPECIFIC_CHANNEL_ID or (
|
| 105 |
isinstance(message.channel, discord.Thread) and message.channel.parent_id == SPECIFIC_CHANNEL_ID
|
|
|
|
| 118 |
|
| 119 |
try:
|
| 120 |
# Cohere ๋ชจ๋ธ์๊ฒ AI-MO/NuminaMath-7B-TIR ๋ชจ๋ธ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฒ์ญํ๋๋ก ์์ฒญ
|
| 121 |
+
cohere_response = await self.retry_request(lambda: self.hf_client.chat_completion(
|
| 122 |
[{"role": "system", "content": "๋ค์ ํ
์คํธ๋ฅผ ํ๊ธ๋ก ๋ฒ์ญํ์ญ์์ค: "}, {"role": "user", "content": math_result}], max_tokens=1000))
|
| 123 |
|
|
|
|
| 124 |
cohere_result = ''.join([part.choices[0].delta.content for part in cohere_response if part.choices and part.choices[0].delta and part.choices[0].delta.content])
|
| 125 |
|
| 126 |
combined_response = f"์ํ ์ ์๋ ๋ต๋ณ: ```{cohere_result}```"
|
|
|
|
| 148 |
messages = [{"role": "system", "content": f"{system_prefix}"}] + conversation_history
|
| 149 |
|
| 150 |
try:
|
| 151 |
+
response = await self.retry_request(lambda: self.hf_client.chat_completion(
|
| 152 |
messages, max_tokens=1000, stream=True, temperature=0.7, top_p=0.85))
|
| 153 |
full_response = ''.join([part.choices[0].delta.content for part in response if part.choices and part.choices[0].delta and part.choices[0].delta.content])
|
| 154 |
conversation_history.append({"role": "assistant", "content": full_response})
|
|
|
|
| 160 |
|
| 161 |
async def send_message_with_latex(self, channel, message):
|
| 162 |
try:
|
| 163 |
+
# ํ
์คํธ์ LaTeX ์์ ๋ถ๋ฆฌ
|
| 164 |
text_parts = re.split(r'(\$\$.*?\$\$|\$.*?\$)', message, flags=re.DOTALL)
|
| 165 |
|
| 166 |
for part in text_parts:
|
| 167 |
if part.startswith('$'):
|
| 168 |
+
# LaTeX ์์ ์ฒ๋ฆฌ ๋ฐ ์ด๋ฏธ์ง๋ก ์ถ๋ ฅ
|
| 169 |
latex_content = part.strip('$')
|
| 170 |
+
image_base64 = latex_to_image(latex_content)
|
| 171 |
+
image_binary = base64.b64decode(image_base64)
|
| 172 |
+
await channel.send(file=discord.File(BytesIO(image_binary), 'equation.png'))
|
| 173 |
else:
|
| 174 |
+
# ํ
์คํธ ์ถ๋ ฅ
|
| 175 |
if part.strip():
|
| 176 |
await self.send_long_message(channel, part.strip())
|
| 177 |
|
| 178 |
except Exception as e:
|
| 179 |
logging.error(f"Error in send_message_with_latex: {str(e)}")
|
| 180 |
await channel.send("An error occurred while processing the message.")
|
| 181 |
+
|
| 182 |
async def send_long_message(self, channel, message):
|
| 183 |
if len(message) <= 2000:
|
| 184 |
await channel.send(message)
|
|
|
|
| 187 |
for part in parts:
|
| 188 |
await channel.send(part)
|
| 189 |
|
|
|
|
| 190 |
async def retry_request(self, func, retries=5, delay=2):
|
| 191 |
for i in range(retries):
|
| 192 |
try:
|
| 193 |
return await func()
|
| 194 |
except HTTPError as e:
|
| 195 |
+
if e.response.status_code == 503:
|
| 196 |
+
if i < retries - 1:
|
| 197 |
+
logging.warning(f"503 error encountered. Retrying in {delay} seconds...")
|
| 198 |
+
await asyncio.sleep(delay)
|
| 199 |
+
else:
|
| 200 |
+
logging.warning("Switching to secondary model due to repeated 503 errors.")
|
| 201 |
+
self.hf_client = hf_client_secondary
|
| 202 |
+
await asyncio.sleep(delay)
|
| 203 |
else:
|
| 204 |
raise
|
| 205 |
|
|
|
|
| 206 |
if __name__ == "__main__":
|
| 207 |
discord_client = MyClient(intents=intents)
|
| 208 |
discord_client.run(os.getenv('DISCORD_TOKEN'))
|