Spaces:
Sleeping
Sleeping
Ilyas KHIAT
commited on
Commit
·
b66e2f4
1
Parent(s):
6c65306
changing base prompt
Browse files
main.py
CHANGED
|
@@ -10,6 +10,8 @@ from rag import *
|
|
| 10 |
from fastapi.responses import StreamingResponse
|
| 11 |
import json
|
| 12 |
from prompts import *
|
|
|
|
|
|
|
| 13 |
|
| 14 |
load_dotenv()
|
| 15 |
|
|
@@ -42,12 +44,16 @@ class StyleWriter(BaseModel):
|
|
| 42 |
style: str
|
| 43 |
tonality: str
|
| 44 |
|
|
|
|
|
|
|
| 45 |
class UserInput(BaseModel):
|
| 46 |
prompt: str
|
| 47 |
enterprise_id: str
|
| 48 |
stream: Optional[bool] = False
|
| 49 |
messages: Optional[list[dict]] = []
|
| 50 |
style_tonality: Optional[StyleWriter] = None
|
|
|
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
class EnterpriseData(BaseModel):
|
|
@@ -175,11 +181,11 @@ def generate_answer(user_input: UserInput):
|
|
| 175 |
context = ""
|
| 176 |
|
| 177 |
if user_input.style_tonality is None:
|
| 178 |
-
prompt_formated = prompt_reformatting(template_prompt,context,prompt)
|
| 179 |
-
answer = generate_response_via_langchain(prompt, model="gpt-4o",stream=user_input.stream,context = context , messages=user_input.messages,template=template_prompt)
|
| 180 |
else:
|
| 181 |
-
prompt_formated = prompt_reformatting(template_prompt,context,prompt,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality)
|
| 182 |
-
answer = generate_response_via_langchain(prompt,
|
| 183 |
|
| 184 |
if user_input.stream:
|
| 185 |
return StreamingResponse(stream_generator(answer,prompt_formated), media_type="application/json")
|
|
@@ -192,6 +198,10 @@ def generate_answer(user_input: UserInput):
|
|
| 192 |
|
| 193 |
except Exception as e:
|
| 194 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
|
| 197 |
|
|
|
|
| 10 |
from fastapi.responses import StreamingResponse
|
| 11 |
import json
|
| 12 |
from prompts import *
|
| 13 |
+
from typing import Literal
|
| 14 |
+
from models import *
|
| 15 |
|
| 16 |
load_dotenv()
|
| 17 |
|
|
|
|
| 44 |
style: str
|
| 45 |
tonality: str
|
| 46 |
|
| 47 |
+
models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
|
| 48 |
+
|
| 49 |
class UserInput(BaseModel):
|
| 50 |
prompt: str
|
| 51 |
enterprise_id: str
|
| 52 |
stream: Optional[bool] = False
|
| 53 |
messages: Optional[list[dict]] = []
|
| 54 |
style_tonality: Optional[StyleWriter] = None
|
| 55 |
+
marque: Optional[str] = None
|
| 56 |
+
model: Literal["gpt-4o","gpt-4o-mini","mistral-large-latest"] = "gpt-4o"
|
| 57 |
|
| 58 |
|
| 59 |
class EnterpriseData(BaseModel):
|
|
|
|
| 181 |
context = ""
|
| 182 |
|
| 183 |
if user_input.style_tonality is None:
|
| 184 |
+
prompt_formated = prompt_reformatting(template_prompt,context,prompt,enterprise_name=getattr(user_input,"marque",""))
|
| 185 |
+
answer = generate_response_via_langchain(prompt, model=getattr(user_input,"model","gpt-4o"),stream=user_input.stream,context = context , messages=user_input.messages,template=template_prompt,enterprise_name=getattr(user_input,"marque",""))
|
| 186 |
else:
|
| 187 |
+
prompt_formated = prompt_reformatting(template_prompt,context,prompt,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,enterprise_name=getattr(user_input,"marque",""))
|
| 188 |
+
answer = generate_response_via_langchain(prompt,model=getattr(user_input,"model","gpt-4o"),stream=user_input.stream,context = context , messages=user_input.messages,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,template=template_prompt,enterprise_name=getattr(user_input,"marque",""))
|
| 189 |
|
| 190 |
if user_input.stream:
|
| 191 |
return StreamingResponse(stream_generator(answer,prompt_formated), media_type="application/json")
|
|
|
|
| 198 |
|
| 199 |
except Exception as e:
|
| 200 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
| 201 |
+
|
| 202 |
+
@app.get("/models")
|
| 203 |
+
def get_models():
|
| 204 |
+
return {"models": models}
|
| 205 |
|
| 206 |
|
| 207 |
|
models.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
|
rag.py
CHANGED
|
@@ -7,6 +7,7 @@ from langchain_core.documents import Document
|
|
| 7 |
from langchain_openai import ChatOpenAI
|
| 8 |
from langchain_core.output_parsers import StrOutputParser
|
| 9 |
from langchain_core.prompts import PromptTemplate
|
|
|
|
| 10 |
from uuid import uuid4
|
| 11 |
|
| 12 |
import unicodedata
|
|
@@ -105,9 +106,16 @@ def generate_response_via_langchain(query: str, stream: bool = False, model: str
|
|
| 105 |
|
| 106 |
|
| 107 |
prompt = PromptTemplate.from_template(template)
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
# Initialize the OpenAI LLM with the specified model
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
# Create an LLM chain with the prompt and the LLM
|
| 113 |
llm_chain = prompt | llm | StrOutputParser()
|
|
|
|
| 7 |
from langchain_openai import ChatOpenAI
|
| 8 |
from langchain_core.output_parsers import StrOutputParser
|
| 9 |
from langchain_core.prompts import PromptTemplate
|
| 10 |
+
from langchain_mistralai import ChatMistralAI
|
| 11 |
from uuid import uuid4
|
| 12 |
|
| 13 |
import unicodedata
|
|
|
|
| 106 |
|
| 107 |
|
| 108 |
prompt = PromptTemplate.from_template(template)
|
| 109 |
+
|
| 110 |
+
print(f"model: {model}")
|
| 111 |
+
print(f"marque: {enterprise_name}")
|
| 112 |
|
| 113 |
# Initialize the OpenAI LLM with the specified model
|
| 114 |
+
if model.startswith("gpt"):
|
| 115 |
+
llm = ChatOpenAI(model=model,temperature=0)
|
| 116 |
+
if model.startswith("mistral"):
|
| 117 |
+
llm = ChatMistralAI(model=model,temperature=0)
|
| 118 |
+
|
| 119 |
|
| 120 |
# Create an LLM chain with the prompt and the LLM
|
| 121 |
llm_chain = prompt | llm | StrOutputParser()
|
requirements.txt
CHANGED
|
@@ -13,4 +13,4 @@ langchain
|
|
| 13 |
langchain-openai
|
| 14 |
langchain-community
|
| 15 |
langchain-pinecone
|
| 16 |
-
|
|
|
|
| 13 |
langchain-openai
|
| 14 |
langchain-community
|
| 15 |
langchain-pinecone
|
| 16 |
+
langchain_mistralai
|