alaselababatunde commited on
Commit
730cf2c
·
1 Parent(s): 41a9cbc
Files changed (2) hide show
  1. app.py +86 -97
  2. requirements.txt +5 -1
app.py CHANGED
@@ -5,165 +5,154 @@ from fastapi.responses import JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
  from langchain.prompts import PromptTemplate
8
- from langchain_huggingface import HuggingFaceEndpoint
9
- from huggingface_hub.utils import HfHubHTTPError
10
- from langchain.schema import HumanMessage
11
- from vector import query_vector
12
 
13
  # ==============================
14
- # Setup Logging
15
  # ==============================
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger("AgriCopilot")
18
 
19
  # ==============================
20
- # App Init
21
- # ==============================
22
- app = FastAPI(title="AgriCopilot")
23
-
24
- @app.get("/")
25
- async def root():
26
- return {"status": "AgriCopilot AI Backend is working perfectly"}
27
-
28
- # ==============================
29
- # AUTH CONFIG
30
  # ==============================
31
  PROJECT_API_KEY = os.getenv("PROJECT_API_KEY")
32
 
33
  def check_auth(authorization: str | None):
34
- """Validate Bearer token against PROJECT_API_KEY"""
35
  if not authorization or not authorization.startswith("Bearer "):
36
  raise HTTPException(status_code=401, detail="Missing bearer token")
37
  token = authorization.split(" ", 1)[1]
38
  if token != PROJECT_API_KEY:
39
- raise HTTPException(status_code=403, detail="Invalid token")
40
 
41
  # ==============================
42
- # Global Exception Handler
43
  # ==============================
44
- @app.exception_handler(Exception)
45
- async def global_exception_handler(request: Request, exc: Exception):
46
- logger.error(f"Unhandled error: {exc}")
47
- return JSONResponse(
48
- status_code=500,
49
- content={"error": str(exc)},
50
- )
 
 
51
 
52
  # ==============================
53
- # Request Models
54
  # ==============================
55
- class CropRequest(BaseModel):
56
- symptoms: str
57
 
58
- class ChatRequest(BaseModel):
59
- query: str
 
 
 
 
60
 
61
- class DisasterRequest(BaseModel):
62
- report: str
 
 
 
 
63
 
64
- class MarketRequest(BaseModel):
65
- product: str
 
 
 
 
66
 
67
- class VectorRequest(BaseModel):
68
- query: str
 
 
 
 
69
 
70
  # ==============================
71
- # MODELS PER ENDPOINT (Meta Models, Conversational)
72
  # ==============================
73
-
74
- # 1. Crop Doctor
75
  crop_template = PromptTemplate(
76
  input_variables=["symptoms"],
77
- template="You are AgriCopilot, a multilingual AI assistant created to support farmers. Farmer reports: {symptoms}. Diagnose the most likely disease and suggest treatments in simple farmer-friendly language."
78
- )
79
- crop_llm = HuggingFaceEndpoint(
80
- repo_id="meta-llama/Llama-3.2-11B-Vision-Instruct",
81
- task="conversational", # ✅ FIXED
82
- temperature=0.3,
83
- top_p=0.9,
84
- do_sample=True,
85
- repetition_penalty=1.1,
86
- max_new_tokens=1024
87
  )
88
 
89
- # 2. Multilingual Chat
90
  chat_template = PromptTemplate(
91
  input_variables=["query"],
92
- template="You are AgriCopilot, a supportive multilingual AI guide built for farmers. Farmer says: {query}"
93
- )
94
- chat_llm = HuggingFaceEndpoint(
95
- repo_id="meta-llama/Llama-3.1-8B-Instruct",
96
- task="conversational", # ✅ FIXED
97
- temperature=0.3,
98
- top_p=0.9,
99
- do_sample=True,
100
- repetition_penalty=1.1,
101
- max_new_tokens=1024
102
  )
103
 
104
- # 3. Disaster Summarizer
105
  disaster_template = PromptTemplate(
106
  input_variables=["report"],
107
- template="You are AgriCopilot, an AI disaster-response assistant. Summarize in simple steps: {report}"
108
- )
109
- disaster_llm = HuggingFaceEndpoint(
110
- repo_id="meta-llama/Llama-3.1-8B-Instruct",
111
- task="conversational", # ✅ FIXED
112
- temperature=0.3,
113
- top_p=0.9,
114
- do_sample=True,
115
- repetition_penalty=1.1,
116
- max_new_tokens=1024
117
  )
118
 
119
- # 4. Marketplace Recommendation
120
  market_template = PromptTemplate(
121
  input_variables=["product"],
122
- template="You are AgriCopilot, an AI agricultural marketplace advisor. Farmer wants to sell or buy: {product}. Suggest best options and advice."
123
- )
124
- market_llm = HuggingFaceEndpoint(
125
- repo_id="meta-llama/Llama-3.1-8B-Instruct",
126
- task="conversational", # ✅ FIXED
127
- temperature=0.3,
128
- top_p=0.9,
129
- do_sample=True,
130
- repetition_penalty=1.1,
131
- max_new_tokens=1024
132
  )
133
 
134
  # ==============================
135
- # ENDPOINTS
136
  # ==============================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  @app.post("/crop-doctor")
138
  async def crop_doctor(req: CropRequest, authorization: str | None = Header(None)):
139
  check_auth(authorization)
140
  prompt = crop_template.format(symptoms=req.symptoms)
141
- response = crop_llm.invoke([HumanMessage(content=prompt)])
142
- return {"diagnosis": str(response)}
143
 
144
  @app.post("/multilingual-chat")
145
  async def multilingual_chat(req: ChatRequest, authorization: str | None = Header(None)):
146
  check_auth(authorization)
147
  prompt = chat_template.format(query=req.query)
148
- response = chat_llm.invoke([HumanMessage(content=prompt)])
149
- return {"reply": str(response)}
150
 
151
  @app.post("/disaster-summarizer")
152
  async def disaster_summarizer(req: DisasterRequest, authorization: str | None = Header(None)):
153
  check_auth(authorization)
154
  prompt = disaster_template.format(report=req.report)
155
- response = disaster_llm.invoke([HumanMessage(content=prompt)])
156
- return {"summary": str(response)}
157
 
158
  @app.post("/marketplace")
159
  async def marketplace(req: MarketRequest, authorization: str | None = Header(None)):
160
  check_auth(authorization)
161
  prompt = market_template.format(product=req.product)
162
- response = market_llm.invoke([HumanMessage(content=prompt)])
163
- return {"recommendation": str(response)}
164
-
165
- @app.post("/vector-search")
166
- async def vector_search(req: VectorRequest, authorization: str | None = Header(None)):
167
- check_auth(authorization)
168
- results = query_vector(req.query)
169
- return {"results": results}
 
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
  from langchain.prompts import PromptTemplate
8
+ from transformers import pipeline
 
 
 
9
 
10
  # ==============================
11
+ # Logging
12
  # ==============================
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger("AgriCopilot")
15
 
16
  # ==============================
17
+ # Auth
 
 
 
 
 
 
 
 
 
18
  # ==============================
19
  PROJECT_API_KEY = os.getenv("PROJECT_API_KEY")
20
 
21
  def check_auth(authorization: str | None):
 
22
  if not authorization or not authorization.startswith("Bearer "):
23
  raise HTTPException(status_code=401, detail="Missing bearer token")
24
  token = authorization.split(" ", 1)[1]
25
  if token != PROJECT_API_KEY:
26
+ raise HTTPException(status_code=403, detail="Invalid bearer token")
27
 
28
  # ==============================
29
+ # FastAPI Init
30
  # ==============================
31
+ app = FastAPI(title="AgriCopilot")
32
+
33
+ app.add_middleware(
34
+ CORSMiddleware,
35
+ allow_origins=["*"], # ✅ change to frontend URL in prod
36
+ allow_credentials=True,
37
+ allow_methods=["*"],
38
+ allow_headers=["*"],
39
+ )
40
 
41
  # ==============================
42
+ # Models via Transformers Pipelines
43
  # ==============================
44
+ logger.info("Loading models... this may take a while.")
 
45
 
46
+ crop_pipeline = pipeline(
47
+ "text-generation",
48
+ model="meta-llama/Llama-3.2-11B-Vision-Instruct",
49
+ torch_dtype="auto",
50
+ device_map="auto"
51
+ )
52
 
53
+ chat_pipeline = pipeline(
54
+ "text-generation",
55
+ model="meta-llama/Llama-3.1-8B-Instruct",
56
+ torch_dtype="auto",
57
+ device_map="auto"
58
+ )
59
 
60
+ disaster_pipeline = pipeline(
61
+ "text-generation",
62
+ model="meta-llama/Llama-3.1-8B-Instruct",
63
+ torch_dtype="auto",
64
+ device_map="auto"
65
+ )
66
 
67
+ market_pipeline = pipeline(
68
+ "text-generation",
69
+ model="meta-llama/Llama-3.1-8B-Instruct",
70
+ torch_dtype="auto",
71
+ device_map="auto"
72
+ )
73
 
74
  # ==============================
75
+ # Prompt Templates
76
  # ==============================
 
 
77
  crop_template = PromptTemplate(
78
  input_variables=["symptoms"],
79
+ template=(
80
+ "You are AgriCopilot, a multilingual AI crop doctor. "
81
+ "Farmer reports: {symptoms}. Diagnose the issue and suggest treatments "
82
+ "in simple farmer-friendly language."
83
+ )
 
 
 
 
 
84
  )
85
 
 
86
  chat_template = PromptTemplate(
87
  input_variables=["query"],
88
+ template=(
89
+ "You are AgriCopilot, a supportive multilingual assistant for farmers. "
90
+ "Respond in the same language. Farmer says: {query}"
91
+ )
 
 
 
 
 
 
92
  )
93
 
 
94
  disaster_template = PromptTemplate(
95
  input_variables=["report"],
96
+ template=(
97
+ "You are AgriCopilot, an AI disaster assistant. Summarize this report "
98
+ "into 3–5 clear steps farmers can follow. Report: {report}"
99
+ )
 
 
 
 
 
 
100
  )
101
 
 
102
  market_template = PromptTemplate(
103
  input_variables=["product"],
104
+ template=(
105
+ "You are AgriCopilot, an AI marketplace advisor. Farmer wants to sell or buy: {product}. "
106
+ "Suggest matches, advice, and safe practices."
107
+ )
 
 
 
 
 
 
108
  )
109
 
110
  # ==============================
111
+ # Request Models
112
  # ==============================
113
+ class CropRequest(BaseModel):
114
+ symptoms: str
115
+
116
+ class ChatRequest(BaseModel):
117
+ query: str
118
+
119
+ class DisasterRequest(BaseModel):
120
+ report: str
121
+
122
+ class MarketRequest(BaseModel):
123
+ product: str
124
+
125
+ # ==============================
126
+ # Routes
127
+ # ==============================
128
+ @app.get("/")
129
+ async def root():
130
+ return {"status": "AgriCopilot AI Backend running with Transformers"}
131
+
132
  @app.post("/crop-doctor")
133
  async def crop_doctor(req: CropRequest, authorization: str | None = Header(None)):
134
  check_auth(authorization)
135
  prompt = crop_template.format(symptoms=req.symptoms)
136
+ response = crop_pipeline(prompt, max_new_tokens=512, do_sample=True)
137
+ return {"diagnosis": response[0]["generated_text"]}
138
 
139
  @app.post("/multilingual-chat")
140
  async def multilingual_chat(req: ChatRequest, authorization: str | None = Header(None)):
141
  check_auth(authorization)
142
  prompt = chat_template.format(query=req.query)
143
+ response = chat_pipeline(prompt, max_new_tokens=512, do_sample=True)
144
+ return {"reply": response[0]["generated_text"]}
145
 
146
  @app.post("/disaster-summarizer")
147
  async def disaster_summarizer(req: DisasterRequest, authorization: str | None = Header(None)):
148
  check_auth(authorization)
149
  prompt = disaster_template.format(report=req.report)
150
+ response = disaster_pipeline(prompt, max_new_tokens=512, do_sample=True)
151
+ return {"summary": response[0]["generated_text"]}
152
 
153
  @app.post("/marketplace")
154
  async def marketplace(req: MarketRequest, authorization: str | None = Header(None)):
155
  check_auth(authorization)
156
  prompt = market_template.format(product=req.product)
157
+ response = market_pipeline(prompt, max_new_tokens=512, do_sample=True)
158
+ return {"recommendation": response[0]["generated_text"]}
 
 
 
 
 
 
requirements.txt CHANGED
@@ -9,4 +9,8 @@ langchain
9
  langchain-huggingface
10
  kagglehub
11
  pandas
12
- datasets
 
 
 
 
 
9
  langchain-huggingface
10
  kagglehub
11
  pandas
12
+ datasets
13
+ transformers
14
+ accelerate
15
+ torch
16
+ sentencepiece