alaselababatunde committed on
Commit
748afe9
·
1 Parent(s): 730cf2c
Files changed (1) hide show
  1. app.py +97 -86
app.py CHANGED
@@ -5,154 +5,165 @@ from fastapi.responses import JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
  from langchain.prompts import PromptTemplate
8
- from transformers import pipeline
 
 
 
9
 
10
  # ==============================
11
- # Logging
12
  # ==============================
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger("AgriCopilot")
15
 
16
  # ==============================
17
- # Auth
 
 
 
 
 
 
 
 
 
18
  # ==============================
19
  PROJECT_API_KEY = os.getenv("PROJECT_API_KEY")
20
 
21
  def check_auth(authorization: str | None):
 
22
  if not authorization or not authorization.startswith("Bearer "):
23
  raise HTTPException(status_code=401, detail="Missing bearer token")
24
  token = authorization.split(" ", 1)[1]
25
  if token != PROJECT_API_KEY:
26
- raise HTTPException(status_code=403, detail="Invalid bearer token")
27
 
28
  # ==============================
29
- # FastAPI Init
30
  # ==============================
31
- app = FastAPI(title="AgriCopilot")
32
-
33
- app.add_middleware(
34
- CORSMiddleware,
35
- allow_origins=["*"], # βœ… change to frontend URL in prod
36
- allow_credentials=True,
37
- allow_methods=["*"],
38
- allow_headers=["*"],
39
- )
40
 
41
  # ==============================
42
- # Models via Transformers Pipelines
43
  # ==============================
44
- logger.info("Loading models... this may take a while.")
 
45
 
46
- crop_pipeline = pipeline(
47
- "text-generation",
48
- model="meta-llama/Llama-3.2-11B-Vision-Instruct",
49
- torch_dtype="auto",
50
- device_map="auto"
51
- )
52
 
53
- chat_pipeline = pipeline(
54
- "text-generation",
55
- model="meta-llama/Llama-3.1-8B-Instruct",
56
- torch_dtype="auto",
57
- device_map="auto"
58
- )
59
 
60
- disaster_pipeline = pipeline(
61
- "text-generation",
62
- model="meta-llama/Llama-3.1-8B-Instruct",
63
- torch_dtype="auto",
64
- device_map="auto"
65
- )
66
 
67
- market_pipeline = pipeline(
68
- "text-generation",
69
- model="meta-llama/Llama-3.1-8B-Instruct",
70
- torch_dtype="auto",
71
- device_map="auto"
72
- )
73
 
74
  # ==============================
75
- # Prompt Templates
76
  # ==============================
 
 
77
  crop_template = PromptTemplate(
78
  input_variables=["symptoms"],
79
- template=(
80
- "You are AgriCopilot, a multilingual AI crop doctor. "
81
- "Farmer reports: {symptoms}. Diagnose the issue and suggest treatments "
82
- "in simple farmer-friendly language."
83
- )
 
 
 
 
 
84
  )
85
 
 
86
  chat_template = PromptTemplate(
87
  input_variables=["query"],
88
- template=(
89
- "You are AgriCopilot, a supportive multilingual assistant for farmers. "
90
- "Respond in the same language. Farmer says: {query}"
91
- )
 
 
 
 
 
 
92
  )
93
 
 
94
  disaster_template = PromptTemplate(
95
  input_variables=["report"],
96
- template=(
97
- "You are AgriCopilot, an AI disaster assistant. Summarize this report "
98
- "into 3–5 clear steps farmers can follow. Report: {report}"
99
- )
 
 
 
 
 
 
100
  )
101
 
 
102
  market_template = PromptTemplate(
103
  input_variables=["product"],
104
- template=(
105
- "You are AgriCopilot, an AI marketplace advisor. Farmer wants to sell or buy: {product}. "
106
- "Suggest matches, advice, and safe practices."
107
- )
 
 
 
 
 
 
108
  )
109
 
110
  # ==============================
111
- # Request Models
112
- # ==============================
113
- class CropRequest(BaseModel):
114
- symptoms: str
115
-
116
- class ChatRequest(BaseModel):
117
- query: str
118
-
119
- class DisasterRequest(BaseModel):
120
- report: str
121
-
122
- class MarketRequest(BaseModel):
123
- product: str
124
-
125
- # ==============================
126
- # Routes
127
  # ==============================
128
- @app.get("/")
129
- async def root():
130
- return {"status": "AgriCopilot AI Backend running with Transformers"}
131
-
132
  @app.post("/crop-doctor")
133
  async def crop_doctor(req: CropRequest, authorization: str | None = Header(None)):
134
  check_auth(authorization)
135
  prompt = crop_template.format(symptoms=req.symptoms)
136
- response = crop_pipeline(prompt, max_new_tokens=512, do_sample=True)
137
- return {"diagnosis": response[0]["generated_text"]}
138
 
139
  @app.post("/multilingual-chat")
140
  async def multilingual_chat(req: ChatRequest, authorization: str | None = Header(None)):
141
  check_auth(authorization)
142
  prompt = chat_template.format(query=req.query)
143
- response = chat_pipeline(prompt, max_new_tokens=512, do_sample=True)
144
- return {"reply": response[0]["generated_text"]}
145
 
146
  @app.post("/disaster-summarizer")
147
  async def disaster_summarizer(req: DisasterRequest, authorization: str | None = Header(None)):
148
  check_auth(authorization)
149
  prompt = disaster_template.format(report=req.report)
150
- response = disaster_pipeline(prompt, max_new_tokens=512, do_sample=True)
151
- return {"summary": response[0]["generated_text"]}
152
 
153
  @app.post("/marketplace")
154
  async def marketplace(req: MarketRequest, authorization: str | None = Header(None)):
155
  check_auth(authorization)
156
  prompt = market_template.format(product=req.product)
157
- response = market_pipeline(prompt, max_new_tokens=512, do_sample=True)
158
- return {"recommendation": response[0]["generated_text"]}
 
 
 
 
 
 
 
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
  from langchain.prompts import PromptTemplate
8
+ from langchain_huggingface import HuggingFaceEndpoint
9
+ from huggingface_hub.utils import HfHubHTTPError
10
+ from langchain.schema import HumanMessage
11
+ from vector import query_vector
12
 
13
  # ==============================
14
+ # Setup Logging
15
  # ==============================
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger("AgriCopilot")
18
 
19
  # ==============================
20
+ # App Init
21
+ # ==============================
22
+ app = FastAPI(title="AgriCopilot")
23
+
24
@app.get("/")
async def root():
    """Health-check endpoint: confirms the API process is up and serving."""
    status_message = "AgriCopilot AI Backend is working perfectly"
    return {"status": status_message}
27
+
28
+ # ==============================
29
+ # AUTH CONFIG
30
  # ==============================
31
  PROJECT_API_KEY = os.getenv("PROJECT_API_KEY")
32
 
33
def check_auth(authorization: str | None):
    """Validate a ``Bearer <token>`` Authorization header against PROJECT_API_KEY.

    Raises:
        HTTPException: 401 when the header is missing or not a Bearer scheme,
            503 when the server has no PROJECT_API_KEY configured,
            403 when the presented token does not match.
    """
    import hmac  # local import keeps this auth helper self-contained

    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing bearer token")
    token = authorization.split(" ", 1)[1]
    # Fail closed: if the deployment forgot to set PROJECT_API_KEY, reject
    # every request explicitly instead of comparing tokens against None.
    if PROJECT_API_KEY is None:
        raise HTTPException(status_code=503, detail="Server auth not configured")
    # Constant-time comparison avoids leaking key prefixes via timing.
    if not hmac.compare_digest(token, PROJECT_API_KEY):
        raise HTTPException(status_code=403, detail="Invalid token")
40
 
41
  # ==============================
42
+ # Global Exception Handler
43
  # ==============================
44
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the full traceback and return a JSON 500.

    The previous f-string log line recorded only ``str(exc)``; passing
    ``exc_info=exc`` preserves the stack trace for debugging.
    """
    logger.error("Unhandled error on %s %s", request.method, request.url.path, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={"error": str(exc)},
    )
 
 
51
 
52
  # ==============================
53
+ # Request Models
54
  # ==============================
55
class CropRequest(BaseModel):
    """Body for POST /crop-doctor."""

    # Free-text description of what the farmer observes on the crop.
    symptoms: str
57
 
58
class ChatRequest(BaseModel):
    """Body for POST /multilingual-chat."""

    # Free-text farmer message, in any language.
    query: str
 
 
 
 
60
 
61
class DisasterRequest(BaseModel):
    """Body for POST /disaster-summarizer."""

    # Raw disaster report text to be summarized into steps.
    report: str
 
 
 
 
63
 
64
class MarketRequest(BaseModel):
    """Body for POST /marketplace."""

    # Description of the product the farmer wants to sell or buy.
    product: str
 
 
 
 
66
 
67
class VectorRequest(BaseModel):
    """Body for POST /vector-search."""

    # Search string for the vector store lookup.
    query: str
 
 
 
 
69
 
70
  # ==============================
71
+ # MODELS PER ENDPOINT (Meta Models, Conversational)
72
  # ==============================
73
+
74
# 1. Crop Doctor — prompt plus hosted model used by the /crop-doctor endpoint.
crop_template = PromptTemplate(
    input_variables=["symptoms"],
    template="You are AgriCopilot, a multilingual AI assistant created to support farmers. Farmer reports: {symptoms}. Diagnose the most likely disease and suggest treatments in simple farmer-friendly language.",
)

# Low temperature for focused answers; nucleus sampling with a mild
# repetition penalty keeps the output readable.
crop_llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Llama-3.2-11B-Vision-Instruct",
    task="conversational",  # endpoint is served via the conversational task
    do_sample=True,
    temperature=0.3,
    top_p=0.9,
    repetition_penalty=1.1,
    max_new_tokens=1024,
)
88
 
89
# 2. Multilingual Chat — prompt plus hosted model used by /multilingual-chat.
chat_template = PromptTemplate(
    input_variables=["query"],
    template="You are AgriCopilot, a supportive multilingual AI guide built for farmers. Farmer says: {query}",
)

# Same sampling profile as the other endpoints: focused, lightly penalized.
chat_llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    task="conversational",  # endpoint is served via the conversational task
    do_sample=True,
    temperature=0.3,
    top_p=0.9,
    repetition_penalty=1.1,
    max_new_tokens=1024,
)
103
 
104
# 3. Disaster Summarizer — prompt plus hosted model used by /disaster-summarizer.
disaster_template = PromptTemplate(
    input_variables=["report"],
    template="You are AgriCopilot, an AI disaster-response assistant. Summarize in simple steps: {report}",
)

# Same sampling profile as the other endpoints: focused, lightly penalized.
disaster_llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    task="conversational",  # endpoint is served via the conversational task
    do_sample=True,
    temperature=0.3,
    top_p=0.9,
    repetition_penalty=1.1,
    max_new_tokens=1024,
)
118
 
119
# 4. Marketplace Recommendation — prompt plus hosted model used by /marketplace.
market_template = PromptTemplate(
    input_variables=["product"],
    template="You are AgriCopilot, an AI agricultural marketplace advisor. Farmer wants to sell or buy: {product}. Suggest best options and advice.",
)

# Same sampling profile as the other endpoints: focused, lightly penalized.
market_llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    task="conversational",  # endpoint is served via the conversational task
    do_sample=True,
    temperature=0.3,
    top_p=0.9,
    repetition_penalty=1.1,
    max_new_tokens=1024,
)
133
 
134
  # ==============================
135
+ # ENDPOINTS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  # ==============================
 
 
 
 
137
@app.post("/crop-doctor")
async def crop_doctor(req: CropRequest, authorization: str | None = Header(None)):
    """Diagnose crop symptoms via the crop-doctor LLM.

    Returns: ``{"diagnosis": <model text>}``.
    Raises: 401/403 on bad auth, 502 when the Hugging Face endpoint fails.
    """
    check_auth(authorization)
    prompt = crop_template.format(symptoms=req.symptoms)
    try:
        response = crop_llm.invoke([HumanMessage(content=prompt)])
    except HfHubHTTPError as exc:
        # HfHubHTTPError is imported but was never handled; surface upstream
        # model failures as a 502 instead of a generic 500 from the catch-all.
        logger.error("Crop-doctor model call failed: %s", exc)
        raise HTTPException(status_code=502, detail="Model backend unavailable") from exc
    return {"diagnosis": str(response)}
143
 
144
@app.post("/multilingual-chat")
async def multilingual_chat(req: ChatRequest, authorization: str | None = Header(None)):
    """Answer a farmer's free-form question via the chat LLM.

    Returns: ``{"reply": <model text>}``.
    Raises: 401/403 on bad auth, 502 when the Hugging Face endpoint fails.
    """
    check_auth(authorization)
    prompt = chat_template.format(query=req.query)
    try:
        response = chat_llm.invoke([HumanMessage(content=prompt)])
    except HfHubHTTPError as exc:
        # Surface upstream model failures as a 502 instead of a generic 500.
        logger.error("Multilingual-chat model call failed: %s", exc)
        raise HTTPException(status_code=502, detail="Model backend unavailable") from exc
    return {"reply": str(response)}
150
 
151
@app.post("/disaster-summarizer")
async def disaster_summarizer(req: DisasterRequest, authorization: str | None = Header(None)):
    """Summarize a disaster report into simple steps via the disaster LLM.

    Returns: ``{"summary": <model text>}``.
    Raises: 401/403 on bad auth, 502 when the Hugging Face endpoint fails.
    """
    check_auth(authorization)
    prompt = disaster_template.format(report=req.report)
    try:
        response = disaster_llm.invoke([HumanMessage(content=prompt)])
    except HfHubHTTPError as exc:
        # Surface upstream model failures as a 502 instead of a generic 500.
        logger.error("Disaster-summarizer model call failed: %s", exc)
        raise HTTPException(status_code=502, detail="Model backend unavailable") from exc
    return {"summary": str(response)}
157
 
158
@app.post("/marketplace")
async def marketplace(req: MarketRequest, authorization: str | None = Header(None)):
    """Recommend marketplace options for a product via the market LLM.

    Returns: ``{"recommendation": <model text>}``.
    Raises: 401/403 on bad auth, 502 when the Hugging Face endpoint fails.
    """
    check_auth(authorization)
    prompt = market_template.format(product=req.product)
    try:
        response = market_llm.invoke([HumanMessage(content=prompt)])
    except HfHubHTTPError as exc:
        # Surface upstream model failures as a 502 instead of a generic 500.
        logger.error("Marketplace model call failed: %s", exc)
        raise HTTPException(status_code=502, detail="Model backend unavailable") from exc
    return {"recommendation": str(response)}
164
+
165
@app.post("/vector-search")
async def vector_search(req: VectorRequest, authorization: str | None = Header(None)):
    """Run a semantic lookup over the project vector store.

    Thin wrapper around ``vector.query_vector``; returns its result
    unchanged under the ``results`` key.
    """
    check_auth(authorization)
    matches = query_vector(req.query)
    return {"results": matches}