Spaces:
Running
Running
| import dotenv | |
| dotenv.load_dotenv() | |
| import json | |
| import os | |
| import random | |
| import threading | |
| import time | |
| from toolformers.base import Tool, parameter_from_openai_api, StringParameter | |
| from toolformers.base import Toolformer | |
| from toolformers.camel import make_openai_toolformer | |
| from toolformers.langchain_agent import LangChainAnthropicToolformer | |
| from toolformers.sambanova import SambanovaToolformer | |
| from toolformers.gemini import GeminiToolformer | |
| from querier import Querier | |
| from responder import Responder | |
| from negotiator import SenderNegotiator, ReceiverNegotiator | |
| from programmer import SenderProgrammer, ReceiverProgrammer | |
| from executor import UnsafeExecutor | |
| from utils import compute_hash | |
def create_toolformer(model_name) -> Toolformer:
    """Instantiate the Toolformer backend that serves ``model_name``.

    Dispatch rules, checked in order:

    * ``'gpt-4o'`` or ``'gpt-4o-mini'``  -> ``make_openai_toolformer``
    * any name containing ``'claude'``   -> ``LangChainAnthropicToolformer``,
      given the ``ANTHROPIC_API_KEY`` environment variable (``None`` if unset)
    * ``'llama3-405b'``                  -> ``SambanovaToolformer``
    * ``'gemini-1.5-pro'``               -> ``GeminiToolformer``

    Raises:
        ValueError: if ``model_name`` matches none of the rules above.
    """
    if model_name in ('gpt-4o', 'gpt-4o-mini'):
        return make_openai_toolformer(model_name)

    if 'claude' in model_name:
        anthropic_key = os.environ.get('ANTHROPIC_API_KEY')
        return LangChainAnthropicToolformer(model_name, anthropic_key)

    if model_name == 'llama3-405b':
        return SambanovaToolformer(model_name)

    if model_name == 'gemini-1.5-pro':
        return GeminiToolformer(model_name)

    raise ValueError(f"Unknown model name: {model_name}")
def _make_dummy_tool(tool_schema):
    """Build one Bob-side ``Tool`` whose implementation returns a canned output.

    Kept as a separate function so that each generated ``tool_fn`` closes over
    its own ``tool_schema``. A closure defined directly inside a ``for`` loop
    would late-bind to whatever schema the loop variable held last.

    Args:
        tool_schema: one entry of ``schema['tools']``. Keys read here:
            ``name``, ``description``, ``input`` (its ``properties`` and
            ``required``), ``output`` and ``dummy_outputs``.

    Returns:
        A ``Tool`` whose function prints its call arguments and returns
        ``random.choice(tool_schema['dummy_outputs'])``.
    """
    parameters = [
        parameter_from_openai_api(
            param_name, param_schema, param_name in tool_schema['input']['required']
        )
        for param_name, param_schema in tool_schema['input']['properties'].items()
    ]

    def tool_fn(*args, **kwargs):
        print(f'Bob tool {tool_schema["name"]} called with args {args} and kwargs {kwargs}')
        return random.choice(tool_schema['dummy_outputs'])

    return Tool(tool_schema['name'], tool_schema['description'], parameters, tool_fn, tool_schema['output'])


def full_flow(schema, alice_model, bob_model):
    """Run one Alice -> Bob demo episode in a background thread, streaming progress.

    The worker (nested ``flow``) runs three phases:

    1. Alice sends a natural-language query for a random entry of
       ``schema['examples']``; Bob's responder answers it using the dummy tools.
    2. Alice and Bob negotiate a protocol document; its ``compute_hash`` value
       is stored under ``'hash'`` in the negotiation result.
    3. Each side writes a routine for the protocol. Alice's routine is run with
       a ``send_to_server`` tool that forwards each message to Bob's routine.

    Args:
        schema: task description dict. ``schema['tools']`` is turned into Bob's
            tools via ``_make_dummy_tool``, ``schema['examples']`` supplies the
            task data, and the whole dict is passed to the querier, the sender
            negotiator and the sender programmer.
        alice_model: model name for Alice (the querying side); see
            ``create_toolformer``.
        bob_model: model name for Bob (the responding side); see
            ``create_toolformer``.

    Yields:
        Tuples ``(nl_messages, negotiation_messages, structured_messages,
        protocol_document, implementation_sender, implementation_receiver)``
        roughly every 0.2 s while the worker thread is alive, then one final
        tuple after it exits. The three message lists are the same list
        objects the worker appends to. The three string fields are ``''``
        until the corresponding artifact has been produced.
    """
    import traceback

    nl_messages = []
    negotiation_messages = []
    structured_messages = []
    artifacts = {}

    toolformer_alice = create_toolformer(alice_model)
    toolformer_bob = create_toolformer(bob_model)

    querier = Querier(toolformer_alice)
    responder = Responder(toolformer_bob)

    # One independent closure per tool (see _make_dummy_tool).
    tools = [_make_dummy_tool(tool_schema) for tool_schema in schema['tools']]

    def nl_callback_fn(query):
        """Deliver Alice's natural-language query to Bob and record both sides."""
        print(query)
        nl_messages.append({
            'role': 'assistant',
            'body': query['body'],
            'protocolHash': None,
        })
        response = responder.reply_to_query(query['body'], query['protocolHash'], tools, '')
        nl_messages.append({
            'role': 'user',
            'status': 'success',
            'body': response['body'],
        })
        return response

    negotiator_sender = SenderNegotiator(toolformer_alice)
    negotiator_receiver = ReceiverNegotiator(toolformer_bob, tools, '')

    def negotiation_callback_fn(query):
        """Pass one negotiation message to Bob and record the exchange."""
        print(query)
        negotiation_messages.append({
            'role': 'assistant',
            'content': query,
        })
        response = negotiator_receiver.handle_negotiation(query)
        negotiation_messages.append({
            'role': 'user',
            'content': response,
        })
        return response

    def final_message_callback_fn(query):
        """Record Alice's closing negotiation message (no reply is requested)."""
        negotiation_messages.append({
            'role': 'assistant',
            'content': query,
        })

    sender_programmer = SenderProgrammer(toolformer_alice)
    receiver_programmer = ReceiverProgrammer(toolformer_bob)
    executor = UnsafeExecutor()

    def structured_callback_fn(query):
        """Run Bob's protocol routine on a message sent by Alice's routine.

        Returns Bob's response. If his routine raises, the traceback is
        printed, an error entry is logged, and the string ``'Error'`` is
        returned instead.
        """
        structured_messages.append({
            'role': 'assistant',
            'body': json.dumps(query) if isinstance(query, dict) else query,
            'protocolHash': artifacts['protocol']['hash'],
            'protocolSources': ['https://...'],
        })
        try:
            response = executor.run_routine(
                artifacts['protocol']['hash'], artifacts['implementation_receiver'], query, tools
            )
        except Exception as e:
            traceback.print_exc()
            structured_messages.append({
                'role': 'user',
                'status': 'error',
                'message': str(e),
            })
            return 'Error'
        structured_messages.append({
            'role': 'user',
            'status': 'success',
            'body': json.dumps(response) if isinstance(response, dict) else response,
        })
        return response

    def flow():
        """Worker-thread body: run the three phases of the episode in order."""
        task_data = random.choice(schema['examples'])

        # Phase 1: plain natural-language exchange, no protocol yet.
        querier.send_query_without_protocol(schema, task_data, nl_callback_fn)

        # Phase 2: negotiate a protocol; its hash identifies it from here on.
        res = negotiator_sender.negotiate_protocol_for_task(
            schema, negotiation_callback_fn, final_message_callback_fn=final_message_callback_fn
        )
        protocol_hash = compute_hash(res['protocol'])
        res['hash'] = protocol_hash
        artifacts['protocol'] = res
        protocol_document = res['protocol']

        # Phase 3: both sides implement the protocol as code.
        implementation_sender = sender_programmer.write_routine_for_task(schema, protocol_document)
        artifacts['implementation_sender'] = implementation_sender
        implementation_receiver = receiver_programmer.write_routine_for_tools(tools, protocol_document, '')
        artifacts['implementation_receiver'] = implementation_receiver

        # Tool takes a list of parameters, as in _make_dummy_tool.
        send_tool = Tool(
            'send_to_server',
            'Send to server',
            [StringParameter('query', 'The query', True)],
            structured_callback_fn,
        )
        try:
            executor.run_routine(protocol_hash, implementation_sender, task_data, [send_tool])
        except Exception as e:
            traceback.print_exc()
            structured_messages.append({
                'role': 'assistant',
                'status': 'error',
                'message': str(e),
            })

    def get_info():
        """Snapshot of the current progress, in the order documented above."""
        protocol = artifacts.get('protocol', {})
        return (
            nl_messages,
            negotiation_messages,
            structured_messages,
            protocol.get('protocol', ''),
            artifacts.get('implementation_sender', ''),
            artifacts.get('implementation_receiver', ''),
        )

    thread = threading.Thread(target=flow)
    thread.start()
    while thread.is_alive():
        yield get_info()
        time.sleep(0.2)
    # Final snapshot so the consumer sees everything the worker produced.
    yield get_info()