|
|
import sys |
|
|
import torch |
|
|
import json |
|
|
from chemietoolkit import ChemIEToolkit |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
import json |
|
|
import sys |
|
|
|
|
|
import torch |
|
|
from rxnim import RxnScribe |
|
|
import json |
|
|
from molnextr.chemistry import _convert_graph_to_smiles |
|
|
|
|
|
from openai import AzureOpenAI |
|
|
import base64 |
|
|
import numpy as np |
|
|
from chemietoolkit import utils |
|
|
from PIL import Image |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
ckpt_path = "./pix2seq_reaction_full.ckpt" |
|
|
model1 = RxnScribe(ckpt_path, device=torch.device('cpu')) |
|
|
device = torch.device('cpu') |
|
|
model = ChemIEToolkit(device=torch.device('cpu')) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_reaction(image_path: str) -> dict: |
|
|
''' |
|
|
Returns a structured dictionary of reactions extracted from the image, |
|
|
including reactants, conditions, and products, with their smiles, text, and bbox. |
|
|
''' |
|
|
image_file = image_path |
|
|
raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True) |
|
|
|
|
|
|
|
|
structured_output = {} |
|
|
for section_key in ['reactants', 'conditions', 'products']: |
|
|
if section_key in raw_prediction[0]: |
|
|
structured_output[section_key] = [] |
|
|
for item in raw_prediction[0][section_key]: |
|
|
if section_key in ['reactants', 'products']: |
|
|
|
|
|
structured_output[section_key].append({ |
|
|
"smiles": item.get("smiles", ""), |
|
|
"bbox": item.get("bbox", []), |
|
|
"symbols": item.get("symbols", []) |
|
|
}) |
|
|
elif section_key == 'conditions': |
|
|
|
|
|
condition_data = {"bbox": item.get("bbox", [])} |
|
|
if "smiles" in item: |
|
|
condition_data["smiles"] = item.get("smiles", "") |
|
|
if "text" in item: |
|
|
condition_data["text"] = item.get("text", []) |
|
|
structured_output[section_key].append(condition_data) |
|
|
|
|
|
|
|
|
return structured_output |
|
|
|
|
|
|
|
|
|
|
|
def get_full_reaction(image_path: str) -> dict: |
|
|
''' |
|
|
Returns a structured dictionary of reactions extracted from the image, |
|
|
including reactants, conditions, and products, with their smiles, text, and bbox. |
|
|
''' |
|
|
image_file = image_path |
|
|
raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True) |
|
|
for reaction in raw_prediction: |
|
|
for section in ("reactants", "products", "conditions"): |
|
|
for entry in reaction.get(section, []): |
|
|
|
|
|
coords = entry.get("coords") |
|
|
if isinstance(coords, list): |
|
|
entry["coords"] = [ |
|
|
[round(val, 3) for val in point] |
|
|
for point in coords |
|
|
] |
|
|
|
|
|
for key in ("molfile", "atoms", "bonds"): |
|
|
entry.pop(key, None) |
|
|
|
|
|
raw_prediction =json.dumps(raw_prediction) |
|
|
return raw_prediction |
|
|
|
|
|
|
|
|
|
|
|
def get_reaction_withatoms(image_path: str) -> dict: |
|
|
""" |
|
|
输入化学反应图像路径,通过 GPT 模型和 OpenChemIE 提取反应信息并返回整理后的反应数据。 |
|
|
|
|
|
Args: |
|
|
image_path (str): 图像文件路径。 |
|
|
|
|
|
Returns: |
|
|
dict: 整理后的反应数据,包括反应物、产物和反应模板。 |
|
|
""" |
|
|
|
|
|
API_KEY = os.getenv("API_KEY") |
|
|
AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT") |
|
|
client = AzureOpenAI( |
|
|
api_key=API_KEY, |
|
|
api_version='2024-06-01', |
|
|
azure_endpoint=AZURE_ENDPOINT |
|
|
) |
|
|
|
|
|
|
|
|
def encode_image(image_path: str): |
|
|
with open(image_path, "rb") as image_file: |
|
|
return base64.b64encode(image_file.read()).decode('utf-8') |
|
|
|
|
|
base64_image = encode_image(image_path) |
|
|
|
|
|
|
|
|
tools = [ |
|
|
{ |
|
|
'type': 'function', |
|
|
'function': { |
|
|
'name': 'get_reaction', |
|
|
'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.', |
|
|
'parameters': { |
|
|
'type': 'object', |
|
|
'properties': { |
|
|
'image_path': { |
|
|
'type': 'string', |
|
|
'description': 'The path to the reaction image.', |
|
|
}, |
|
|
}, |
|
|
'required': ['image_path'], |
|
|
'additionalProperties': False, |
|
|
}, |
|
|
}, |
|
|
}, |
|
|
] |
|
|
|
|
|
|
|
|
with open('./prompt/prompt_getreaction.txt', 'r') as prompt_file: |
|
|
prompt = prompt_file.read() |
|
|
messages = [ |
|
|
{'role': 'system', 'content': 'You are a helpful assistant.'}, |
|
|
{ |
|
|
'role': 'user', |
|
|
'content': [ |
|
|
{'type': 'text', 'text': prompt}, |
|
|
{'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}} |
|
|
] |
|
|
} |
|
|
] |
|
|
|
|
|
|
|
|
response = client.chat.completions.create( |
|
|
model = 'gpt-4o', |
|
|
temperature = 0, |
|
|
response_format={ 'type': 'json_object' }, |
|
|
messages = [ |
|
|
{'role': 'system', 'content': 'You are a helpful assistant.'}, |
|
|
{ |
|
|
'role': 'user', |
|
|
'content': [ |
|
|
{ |
|
|
'type': 'text', |
|
|
'text': prompt |
|
|
}, |
|
|
{ |
|
|
'type': 'image_url', |
|
|
'image_url': { |
|
|
'url': f'data:image/png;base64,{base64_image}' |
|
|
} |
|
|
} |
|
|
]}, |
|
|
], |
|
|
tools = tools) |
|
|
|
|
|
|
|
|
TOOL_MAP = { |
|
|
'get_reaction': get_reaction, |
|
|
} |
|
|
|
|
|
|
|
|
tool_calls = response.choices[0].message.tool_calls |
|
|
results = [] |
|
|
|
|
|
|
|
|
for tool_call in tool_calls: |
|
|
tool_name = tool_call.function.name |
|
|
tool_arguments = tool_call.function.arguments |
|
|
tool_call_id = tool_call.id |
|
|
|
|
|
tool_args = json.loads(tool_arguments) |
|
|
|
|
|
if tool_name in TOOL_MAP: |
|
|
|
|
|
tool_result = TOOL_MAP[tool_name](image_path) |
|
|
else: |
|
|
raise ValueError(f"Unknown tool called: {tool_name}") |
|
|
|
|
|
|
|
|
results.append({ |
|
|
'role': 'tool', |
|
|
'content': json.dumps({ |
|
|
'image_path': image_path, |
|
|
f'{tool_name}':(tool_result), |
|
|
}), |
|
|
'tool_call_id': tool_call_id, |
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
completion_payload = { |
|
|
'model': 'gpt-4o', |
|
|
'messages': [ |
|
|
{'role': 'system', 'content': 'You are a helpful assistant.'}, |
|
|
{ |
|
|
'role': 'user', |
|
|
'content': [ |
|
|
{ |
|
|
'type': 'text', |
|
|
'text': prompt |
|
|
}, |
|
|
{ |
|
|
'type': 'image_url', |
|
|
'image_url': { |
|
|
'url': f'data:image/png;base64,{base64_image}' |
|
|
} |
|
|
} |
|
|
] |
|
|
}, |
|
|
response.choices[0].message, |
|
|
*results |
|
|
], |
|
|
} |
|
|
|
|
|
|
|
|
response = client.chat.completions.create( |
|
|
model=completion_payload["model"], |
|
|
messages=completion_payload["messages"], |
|
|
response_format={ 'type': 'json_object' }, |
|
|
temperature=0 |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gpt_output = json.loads(response.choices[0].message.content) |
|
|
|
|
|
|
|
|
|
|
|
def get_reaction_full(image_path: str) -> dict: |
|
|
''' |
|
|
Returns a structured dictionary of reactions extracted from the image, |
|
|
including reactants, conditions, and products, with their smiles, text, and bbox. |
|
|
''' |
|
|
image_file = image_path |
|
|
raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True) |
|
|
return raw_prediction |
|
|
|
|
|
input2 = get_reaction_full(image_path) |
|
|
|
|
|
|
|
|
|
|
|
def update_input_with_symbols(input1, input2, conversion_function): |
|
|
symbol_mapping = {} |
|
|
for key in ['reactants', 'products']: |
|
|
for item in input1.get(key, []): |
|
|
bbox = tuple(item['bbox']) |
|
|
symbol_mapping[bbox] = item['symbols'] |
|
|
|
|
|
for key in ['reactants', 'products']: |
|
|
for item in input2.get(key, []): |
|
|
bbox = tuple(item['bbox']) |
|
|
|
|
|
|
|
|
if bbox in symbol_mapping: |
|
|
updated_symbols = symbol_mapping[bbox] |
|
|
item['symbols'] = updated_symbols |
|
|
|
|
|
|
|
|
if 'atoms' in item: |
|
|
atoms = item['atoms'] |
|
|
if len(atoms) != len(updated_symbols): |
|
|
print(f"Warning: Mismatched symbols and atoms in bbox {bbox}") |
|
|
else: |
|
|
for atom, symbol in zip(atoms, updated_symbols): |
|
|
atom['atom_symbol'] = symbol |
|
|
|
|
|
|
|
|
if 'coords' in item and 'edges' in item: |
|
|
coords = item['coords'] |
|
|
edges = item['edges'] |
|
|
new_smiles, new_molfile, _ = conversion_function(coords, updated_symbols, edges) |
|
|
|
|
|
|
|
|
item['smiles'] = new_smiles |
|
|
item['molfile'] = new_molfile |
|
|
|
|
|
return input2 |
|
|
|
|
|
updated_data = [update_input_with_symbols(gpt_output, input2[0], _convert_graph_to_smiles)] |
|
|
|
|
|
return updated_data |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_reaction_withatoms_correctR(image_path: str) -> dict: |
|
|
""" |
|
|
输入化学反应图像路径,通过 GPT 模型和 OpenChemIE 提取反应信息并返回整理后的反应数据。 |
|
|
|
|
|
Args: |
|
|
image_path (str): 图像文件路径。 |
|
|
|
|
|
Returns: |
|
|
dict: 整理后的反应数据,包括反应物、产物和反应模板。 |
|
|
""" |
|
|
|
|
|
API_KEY = os.getenv("API_KEY") |
|
|
AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT") |
|
|
|
|
|
client = AzureOpenAI( |
|
|
api_key=API_KEY, |
|
|
api_version='2024-06-01', |
|
|
azure_endpoint=AZURE_ENDPOINT |
|
|
) |
|
|
|
|
|
|
|
|
def encode_image(image_path: str): |
|
|
with open(image_path, "rb") as image_file: |
|
|
return base64.b64encode(image_file.read()).decode('utf-8') |
|
|
|
|
|
base64_image = encode_image(image_path) |
|
|
|
|
|
|
|
|
tools = [ |
|
|
{ |
|
|
'type': 'function', |
|
|
'function': { |
|
|
'name': 'get_reaction', |
|
|
'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.', |
|
|
'parameters': { |
|
|
'type': 'object', |
|
|
'properties': { |
|
|
'image_path': { |
|
|
'type': 'string', |
|
|
'description': 'The path to the reaction image.', |
|
|
}, |
|
|
}, |
|
|
'required': ['image_path'], |
|
|
'additionalProperties': False, |
|
|
}, |
|
|
}, |
|
|
}, |
|
|
] |
|
|
|
|
|
|
|
|
with open('./prompt/prompt_getreaction_correctR.txt', 'r') as prompt_file: |
|
|
prompt = prompt_file.read() |
|
|
messages = [ |
|
|
{'role': 'system', 'content': 'You are a helpful assistant.'}, |
|
|
{ |
|
|
'role': 'user', |
|
|
'content': [ |
|
|
{'type': 'text', 'text': prompt}, |
|
|
{'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}} |
|
|
] |
|
|
} |
|
|
] |
|
|
|
|
|
|
|
|
response = client.chat.completions.create( |
|
|
model = 'gpt-4o', |
|
|
temperature = 0, |
|
|
response_format={ 'type': 'json_object' }, |
|
|
messages = [ |
|
|
{'role': 'system', 'content': 'You are a helpful assistant.'}, |
|
|
{ |
|
|
'role': 'user', |
|
|
'content': [ |
|
|
{ |
|
|
'type': 'text', |
|
|
'text': prompt |
|
|
}, |
|
|
{ |
|
|
'type': 'image_url', |
|
|
'image_url': { |
|
|
'url': f'data:image/png;base64,{base64_image}' |
|
|
} |
|
|
} |
|
|
]}, |
|
|
], |
|
|
tools = tools) |
|
|
|
|
|
|
|
|
TOOL_MAP = { |
|
|
'get_reaction': get_reaction, |
|
|
} |
|
|
|
|
|
|
|
|
tool_calls = response.choices[0].message.tool_calls |
|
|
results = [] |
|
|
|
|
|
|
|
|
for tool_call in tool_calls: |
|
|
tool_name = tool_call.function.name |
|
|
tool_arguments = tool_call.function.arguments |
|
|
tool_call_id = tool_call.id |
|
|
|
|
|
tool_args = json.loads(tool_arguments) |
|
|
|
|
|
if tool_name in TOOL_MAP: |
|
|
|
|
|
tool_result = TOOL_MAP[tool_name](image_path) |
|
|
else: |
|
|
raise ValueError(f"Unknown tool called: {tool_name}") |
|
|
|
|
|
|
|
|
results.append({ |
|
|
'role': 'tool', |
|
|
'content': json.dumps({ |
|
|
'image_path': image_path, |
|
|
f'{tool_name}':(tool_result), |
|
|
}), |
|
|
'tool_call_id': tool_call_id, |
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
completion_payload = { |
|
|
'model': 'gpt-4o', |
|
|
'messages': [ |
|
|
{'role': 'system', 'content': 'You are a helpful assistant.'}, |
|
|
{ |
|
|
'role': 'user', |
|
|
'content': [ |
|
|
{ |
|
|
'type': 'text', |
|
|
'text': prompt |
|
|
}, |
|
|
{ |
|
|
'type': 'image_url', |
|
|
'image_url': { |
|
|
'url': f'data:image/png;base64,{base64_image}' |
|
|
} |
|
|
} |
|
|
] |
|
|
}, |
|
|
response.choices[0].message, |
|
|
*results |
|
|
], |
|
|
} |
|
|
|
|
|
|
|
|
response = client.chat.completions.create( |
|
|
model=completion_payload["model"], |
|
|
messages=completion_payload["messages"], |
|
|
response_format={ 'type': 'json_object' }, |
|
|
temperature=0 |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gpt_output = json.loads(response.choices[0].message.content) |
|
|
|
|
|
|
|
|
|
|
|
def get_reaction_full(image_path: str) -> dict: |
|
|
''' |
|
|
Returns a structured dictionary of reactions extracted from the image, |
|
|
including reactants, conditions, and products, with their smiles, text, and bbox. |
|
|
''' |
|
|
image_file = image_path |
|
|
raw_prediction = model1.predict_image_file(image_file, molscribe=True, ocr=True) |
|
|
return raw_prediction |
|
|
|
|
|
input2 = get_reaction_full(image_path) |
|
|
|
|
|
|
|
|
|
|
|
def update_input_with_symbols(input1, input2, conversion_function): |
|
|
symbol_mapping = {} |
|
|
for key in ['reactants', 'products']: |
|
|
for item in input1.get(key, []): |
|
|
bbox = tuple(item['bbox']) |
|
|
symbol_mapping[bbox] = item['symbols'] |
|
|
|
|
|
for key in ['reactants', 'products']: |
|
|
for item in input2.get(key, []): |
|
|
bbox = tuple(item['bbox']) |
|
|
|
|
|
|
|
|
if bbox in symbol_mapping: |
|
|
updated_symbols = symbol_mapping[bbox] |
|
|
item['symbols'] = updated_symbols |
|
|
|
|
|
|
|
|
if 'atoms' in item: |
|
|
atoms = item['atoms'] |
|
|
if len(atoms) != len(updated_symbols): |
|
|
print(f"Warning: Mismatched symbols and atoms in bbox {bbox}") |
|
|
else: |
|
|
for atom, symbol in zip(atoms, updated_symbols): |
|
|
atom['atom_symbol'] = symbol |
|
|
|
|
|
|
|
|
if 'coords' in item and 'edges' in item: |
|
|
coords = item['coords'] |
|
|
edges = item['edges'] |
|
|
new_smiles, new_molfile, _ = conversion_function(coords, updated_symbols, edges) |
|
|
|
|
|
|
|
|
item['smiles'] = new_smiles |
|
|
item['molfile'] = new_molfile |
|
|
|
|
|
return input2 |
|
|
|
|
|
updated_data = [update_input_with_symbols(gpt_output, input2[0], _convert_graph_to_smiles)] |
|
|
print(f"rxn_agent_output:{updated_data}") |
|
|
|
|
|
return updated_data |