|
|
import sys |
|
|
import torch |
|
|
import json |
|
|
from chemietoolkit import ChemIEToolkit,utils |
|
|
import cv2 |
|
|
from openai import AzureOpenAI |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import json |
|
|
from get_molecular_agent import process_reaction_image_with_multiple_products_and_text_correctR |
|
|
from get_reaction_agent import get_reaction_withatoms_correctR |
|
|
from get_R_group_sub_agent import process_reaction_image_with_table_R_group, process_reaction_image_with_product_variant_R_group,get_full_reaction,get_multi_molecular_full |
|
|
import os |
|
|
import sys |
|
|
from rxnim import RxnScribe |
|
|
import json |
|
|
import base64 |
|
|
# Both models below are placed on the CPU; build the device object once.
device = torch.device('cpu')

# Checkpoint handed to RxnScribe, given relative to the working directory.
ckpt_path = "./pix2seq_reaction_full.ckpt"

# Instantiated at import time, in the same order as before.
model = ChemIEToolkit(device=device)
model1 = RxnScribe(ckpt_path, device=device)
|
|
|
|
|
|
|
|
def _image_tool_spec(name: str, description: str) -> dict:
    """Build the function-tool schema for a tool whose only argument is ``image_path``."""
    return {
        'type': 'function',
        'function': {
            'name': name,
            'description': description,
            'parameters': {
                'type': 'object',
                'properties': {
                    'image_path': {
                        'type': 'string',
                        'description': 'The path to the reaction image.',
                    },
                },
                'required': ['image_path'],
                'additionalProperties': False,
            },
        },
    }


def _user_message(text: str, base64_image: str) -> dict:
    """Build a user chat message holding a text prompt and the inline image."""
    # NOTE(review): the data URL always declares image/png, whatever the
    # file's real format is; confirm this is acceptable for non-PNG inputs.
    return {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': text},
            {
                'type': 'image_url',
                'image_url': {'url': f'data:image/png;base64,{base64_image}'},
            },
        ],
    }


def _read_prompt(path: str) -> str:
    """Return the full text of a prompt file, decoded as UTF-8."""
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as prompt_file:
        return prompt_file.read()


def ChemEagle(image_path: str) -> dict:
    """Extract reaction data from a chemistry image using GPT-4o and local tools.

    The function runs two chat-completion requests against Azure OpenAI:

    1. A planning request (prompt from ``./prompt/prompt_plan.txt``) that
       offers the model four image-analysis tools.
    2. A final request (prompt from ``./prompt/prompt_final_simple_version.txt``)
       that carries the model's planning reply and the output of every tool
       it asked for, and requests a JSON object.

    Credentials are read from the ``API_KEY`` and ``AZURE_ENDPOINT``
    environment variables. Both prompt paths are relative to the current
    working directory.

    Args:
        image_path: Path to the image file. The same path is passed to every
            tool the model selects.

    Returns:
        The JSON object decoded from the model's final reply. Its schema is
        set by the prompt files (the original docstring describes it as
        reactants, products and reaction templates); verify against the
        prompts.

    Raises:
        ValueError: If the model requests a tool that is not in the local
            tool map.
    """
    API_KEY = os.getenv("API_KEY")
    AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT")

    client = AzureOpenAI(
        api_key=API_KEY,
        api_version='2024-06-01',
        azure_endpoint=AZURE_ENDPOINT,
    )

    # The image is sent inline as a base64 data URL in both requests.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    tools = [
        _image_tool_spec(
            'process_reaction_image_with_product_variant_R_group',
            'get the reaction data of the reaction diagram and get SMILES strings of every detailed reaction in reaction diagram and the set of product variants, and the original molecular list.',
        ),
        _image_tool_spec(
            'process_reaction_image_with_table_R_group',
            'get the reaction data of the reaction diagram and get SMILES strings of every detailed reaction in reaction diagram and the R-group table',
        ),
        _image_tool_spec(
            'get_full_reaction',
            'After you carefully check the image, if this is a reaction image that contains only a text-based table and does not involve any R-group replacement, or this is a reaction image does not contain any tables or sets of product variants, then just call this simplified tool.',
        ),
        _image_tool_spec(
            'get_multi_molecular_full',
            'After you carefully check the image, if this is a single molecule image or a multiple molecules image, then need to call this molecular recognition tool.',
        ),
    ]

    prompt = _read_prompt('./prompt/prompt_final_simple_version.txt')
    prompt_plan = _read_prompt('./prompt/prompt_plan.txt')

    system_message = {'role': 'system', 'content': 'You are a helpful assistant.'}

    # Request 1: let the model decide which tool(s) fit this image.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=[system_message, _user_message(prompt_plan, base64_image)],
        tools=tools,
    )

    TOOL_MAP = {
        'process_reaction_image_with_product_variant_R_group': process_reaction_image_with_product_variant_R_group,
        'process_reaction_image_with_table_R_group': process_reaction_image_with_table_R_group,
        'get_full_reaction': get_full_reaction,
        'get_multi_molecular_full': get_multi_molecular_full,
    }

    assistant_message = response.choices[0].message
    print(f"tool_calls:{assistant_message.tool_calls}")

    # `tool_calls` is None when the model replies without requesting a tool.
    # Treat that as "no tools" so the final request still runs instead of
    # failing with a TypeError when iterating over None.
    tool_calls = assistant_message.tool_calls or []
    results = []

    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name not in TOOL_MAP:
            raise ValueError(f"Unknown tool called: {tool_name}")

        # The model-supplied arguments are deliberately not used: every tool
        # is run on the caller's local `image_path`.
        tool_result = TOOL_MAP[tool_name](image_path)

        results.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Request 2: final prompt + image, followed by the assistant's planning
    # reply and one tool message per executed call.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[
            system_message,
            _user_message(prompt, base64_image),
            assistant_message,
            *results,
        ],
        response_format={'type': 'json_object'},
        temperature=0,
    )

    gpt_output = json.loads(response.choices[0].message.content)
    print(gpt_output)
    return gpt_output
|
|
|
|
|
if __name__ == "__main__":
    # ChemEagle requires an image path. The previous call passed none, so it
    # raised TypeError, and it stored the result in `model`, overwriting the
    # module-level toolkit. Take the path from the command line instead.
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    # ChemEagle prints the parsed result itself before returning it.
    result = ChemEagle(sys.argv[1])