"""Gradio demo comparing token usage between a standard (load-all) MCP server
and a progressive-disclosure (lazy-load) MCP server."""

import asyncio
import contextlib
import json
import os
from typing import Any, Dict, List

import gradio as gr
import tiktoken
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import AsyncOpenAI

# o200k_base is the tokenizer family used by gpt-4o, so counts match the model below.
enc = tiktoken.get_encoding("o200k_base")


def count_tokens(text: str) -> int:
    return len(enc.encode(text))

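# DemoSession keeps per-run state for one demo: the MCP client session, the
# OpenAI client, the chat history, and the token counters shown in the UI.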
class DemoSession:
    def __init__(self, mode: str):
        self.mode = mode
        self.server_process = None
        self.session = None
        self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        self.history = []
        self.initial_tokens = 0
        self.runtime_tokens = 0
        self.tool_calls_count = 0
        self.tools_list = []
        self.exit_stack = None

    async def start(self):
        server_params = StdioServerParameters(
            command="python3",
            args=["app/server/main.py"],
            env={**os.environ, "MCP_MODE": self.mode}
        )
        # Spawn the MCP server over stdio and open a client session, keeping both
        # alive on an exit stack (same pattern as PersistentAgent.initialize below).
        self.exit_stack = contextlib.AsyncExitStack()
        read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
        await self.session.initialize()
        self.tools_list = (await self.session.list_tools()).tools

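# PersistentAgent is the wrapper actually used by the UI below: it owns the MCP
# server process for the lifetime of the conversation via an AsyncExitStack.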
class PersistentAgent:
    def __init__(self, mode, api_key):
        self.mode = mode
        self.api_key = api_key
        self.stack = contextlib.AsyncExitStack()
        self.session = None
        self.tools = None
        self.metrics = {"initial": 0, "runtime": 0}
        self.messages = []
        if not api_key:
            raise ValueError("API Key is required")
        self.openai = AsyncOpenAI(api_key=api_key)

    async def initialize(self):
        server_params = StdioServerParameters(
            command="python3",
            args=["app/server/main.py"],
            env={**os.environ, "MCP_MODE": self.mode, "PYTHONUNBUFFERED": "1"}
        )

        self.read, self.write = await self.stack.enter_async_context(stdio_client(server_params))
        self.session = await self.stack.enter_async_context(ClientSession(self.read, self.write))
        await self.session.initialize()

        tools_result = await self.session.list_tools()
        self.tools = tools_result.tools

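        # Measure the one-time context cost: the tool schemas, resource listing,
        # and base system prompt the model must carry before the first user turn.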
        try:
            resources_result = await self.session.list_resources()
            resources_json = json.dumps([r.model_dump() for r in resources_result.resources], indent=2)
        except Exception:
            resources_json = ""

        tools_json = json.dumps([t.model_dump() for t in self.tools], indent=2)
        system_base = "You are a helpful assistant..."

        self.metrics["initial"] = count_tokens(tools_json) + count_tokens(resources_json) + count_tokens(system_base)

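        # Build the runtime system prompt: one-line summaries of every tool, plus
        # the progressive-disclosure workflow instructions when that mode is on.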
        tool_desc_str = "\n".join([f" - {t.name}: {t.description}" for t in self.tools])
        self.system_prompt = f"{system_base}\n\nYou have access to these tools:\n\n{tool_desc_str}\n\nUse tools when needed to answer questions."

        if self.mode == 'progressive':
            self.system_prompt += """

IMPORTANT - Tool Usage Workflow:
This server uses progressive disclosure for tools. Follow this exact workflow:

1. PICK the right tool based on the descriptions above (they tell you WHAT each tool does)
2. FETCH the full tool description using read_resource with the specific tool name
   Example: read_resource(uri="resource:///tool_descriptions?tools=TOOL_NAME")
3. CALL the tool using the parameters you just learned

DO NOT try to fetch tool_descriptions without specifying which tool you want (?tools=TOOL_NAME).
The tool descriptions above are sufficient for choosing which tool you need.
You fetch the full description to learn the parameters and authorize the tool."""

        self.messages = [{"role": "system", "content": self.system_prompt}]

    async def chat(self, user_message):
        self.messages.append({"role": "user", "content": user_message})

        logs = []

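        # Agent loop: keep calling the model, executing any tool calls it makes,
        # until it responds without requesting a tool.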
        while True:
            openai_tools = []
            for t in self.tools:
                schema = t.inputSchema
                # Some MCP servers emit a bare {"type": "object"} schema; the
                # OpenAI API expects an explicit (possibly empty) properties map.
                if schema.get("type") == "object" and "properties" not in schema:
                    schema = {"type": "object", "properties": {}}

                openai_tools.append({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": schema
                    }
                })

            if self.mode == 'progressive':
                # Progressive mode exposes one extra function the model can use
                # to lazy-load the full description of a tool before calling it.
                openai_tools.append({
                    "type": "function",
                    "function": {
                        "name": "read_resource",
                        "description": "Read a resource by URI, e.g. the full description of a tool.",
                        "parameters": {
                            "type": "object",
                            "properties": {"uri": {"type": "string"}},
                            "required": ["uri"]
                        }
                    }
                })

            response = await self.openai.chat.completions.create(
                model="gpt-4o",
                messages=self.messages,
                tools=openai_tools,
                tool_choice="auto"
            )

            msg = response.choices[0].message
            self.messages.append(msg)

            if msg.content:
                yield msg.content, logs, self.metrics

            if not msg.tool_calls:
                break

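            # Execute each requested tool call, streaming log updates to the UI
            # after every step so the user can watch the workflow unfold.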
            for tool_call in msg.tool_calls:
                fn_name = tool_call.function.name
                fn_args = json.loads(tool_call.function.arguments)

                logs.append(f"🛠️ **Tool Call:** `{fn_name}`")
                yield None, logs, self.metrics

                result_content = ""

                if fn_name == "read_resource":
                    uri = fn_args.get("uri")
                    uri_str = str(uri) if uri else ""
                    logs.append(f"📥 **Fetching:** `{uri_str}`")
                    yield None, logs, self.metrics

                    try:
                        res = await self.session.read_resource(uri_str)
                        content = res.contents[0].text
                        tokens = count_tokens(content)
                        # Lazily loaded descriptions are the runtime cost of progressive mode.
                        self.metrics["runtime"] += tokens
                        logs.append(f"📄 **Loaded:** {tokens} tokens")
                        result_content = content
                    except Exception as e:
                        result_content = json.dumps({"error": str(e)})
                        logs.append(f"❌ **Error:** {e}")

                else:
                    try:
                        res = await self.session.call_tool(fn_name, fn_args)
                        content = res.content[0].text
                        if "TOOL_DESCRIPTION_REQUIRED" in content:
                            logs.append("⚠️ **Auth Error:** tool description must be fetched first")
                        else:
                            logs.append("✅ **Success**")
                        # Return the result either way: on TOOL_DESCRIPTION_REQUIRED
                        # the model must see the error so it can fetch the description.
                        result_content = content
                    except Exception as e:
                        result_content = str(e)
                        logs.append(f"❌ **Error:** {e}")

                self.messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": result_content
                })
                yield None, logs, self.metrics

    async def close(self):
        await self.stack.aclose()

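# A single module-level agent backs the UI; restarting replaces it in place.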
current_agent = None


async def start_agent(mode, api_key):
    global current_agent

    if current_agent:
        try:
            await current_agent.close()
        except RuntimeError:
            # Closing the exit stack from a different task than the one that
            # opened it can raise RuntimeError; safe to ignore on restart.
            pass
        except Exception as e:
            print(f"Warning closing old agent: {e}")

    try:
        current_agent = PersistentAgent(mode, api_key)
        await current_agent.initialize()
        return f"Started in {mode.upper()} mode.", [], current_agent.metrics
    except Exception as e:
        return f"Error starting agent: {str(e)}", [], {}

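# Bridges the UI to the agent: accumulates streamed content and re-yields it
# together with the activity logs and token metrics.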
async def process_message(message, history):
    if not current_agent:
        yield "Please enter your API key and click Start Server first.", [], {}
        return

    full_response = ""
    async for content, logs, metrics in current_agent.chat(message):
        if content:
            full_response += content
        yield full_response, logs, metrics

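# Gradio UI: server controls and metrics on the left, the chat on the right.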
with gr.Blocks(title="MCP Progressive Disclosure Demo") as demo:
    gr.Markdown("# MCP Progressive Disclosure Demo 🔍")
    gr.Markdown("Compare the token usage between Standard (Load All) and Progressive (Lazy Load) MCP servers.")

    with gr.Row():
        with gr.Column(scale=1):
            api_key_input = gr.Textbox(
                label="OpenAI API Key",
                placeholder="sk-...",
                type="password",
                value=os.environ.get("OPENAI_API_KEY", "")
            )
            mode_radio = gr.Radio(["standard", "progressive"], label="Mode", value="standard")
            start_btn = gr.Button("Start/Restart Server")
            status_output = gr.Markdown("")

            metrics_json = gr.JSON(label="Token Metrics")
            logs_box = gr.JSON(label="Activity Logs")

        with gr.Column(scale=2):
            # type="messages" so the chatbot accepts the role/content dicts
            # produced by on_message below.
            chatbot = gr.Chatbot(label="Conversation", type="messages")
            msg = gr.Textbox(label="Your Message")
            clear = gr.Button("Clear")

    clear.click(lambda: [], None, chatbot, queue=False)

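    # Event handlers: wire the button and textbox to the async agent above.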
    async def on_start(mode, api_key):
        status, _, metrics = await start_agent(mode, api_key)
        return status, metrics, []

    async def on_message(message, history):
        history = history or []
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": ""})

        async for content, logs, metrics in process_message(message, history):
            history[-1]["content"] = content
            yield "", history, metrics, logs

    start_btn.click(on_start, inputs=[mode_radio, api_key_input], outputs=[status_output, metrics_json, logs_box])
    msg.submit(on_message, inputs=[msg, chatbot], outputs=[msg, chatbot, metrics_json, logs_box])

if __name__ == "__main__":
    demo.queue()
    demo.launch(inbrowser=False)