|
|
import sys |
|
|
import torch |
|
|
import json |
|
|
from chemietoolkit import ChemIEToolkit |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
import json |
|
|
import sys |
|
|
import torch |
|
|
from rxnim import RxnScribe |
|
|
import json |
|
|
import sys |
|
|
import torch |
|
|
import json |
|
|
from molnextr.chemistry import _convert_graph_to_smiles |
|
|
import base64 |
|
|
import torch |
|
|
import json |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from chemietoolkit import ChemIEToolkit, utils |
|
|
from openai import AzureOpenAI |
|
|
import os |
|
|
import copy |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint file for the RxnScribe reaction parser (path relative to the CWD).
ckpt_path = "./pix2seq_reaction_full.ckpt"

# Every model in this module is placed on the CPU.
device = torch.device('cpu')

# Both models are constructed at import time.
model1 = RxnScribe(ckpt_path, device=device)
# ChemIE toolkit used by the molecule/coref extraction functions below.
model = ChemIEToolkit(device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_multi_molecular(image_path: str) -> str:
    """Return the molecule/coreference extraction results for an image as JSON.

    Runs ``model.extract_molecule_corefs_from_figures`` on the RGB version of
    the image and removes the following fields from every bbox before
    serializing: ``category``, ``molfile``, ``symbols``, ``atoms``, ``bonds``,
    ``category_id``, ``score`` and ``corefs``.

    Args:
        image_path: Path to the image file.

    Returns:
        A JSON string encoding the (stripped) extraction results.
    """
    drop_keys = ("category", "molfile", "symbols", 'atoms', "bonds", 'category_id', 'score', 'corefs')

    # Use Image.open as a context manager so the file handle is released once
    # the independent RGB copy exists.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])

    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in drop_keys:
                bbox.pop(key, None)

    return json.dumps(coref_results)
|
|
|
|
|
def get_multi_molecular_text_to_correct(image_path: str) -> str:
    """Return stripped molecule/coreference extraction results for an image as JSON.

    Runs ``model.extract_molecule_corefs_from_figures`` on the RGB version of
    the image and removes the following fields from every bbox before
    serializing: ``category``, ``bbox``, ``molfile``, ``symbols``, ``atoms``,
    ``bonds``, ``category_id``, ``score`` and ``corefs``. Unlike
    ``get_multi_molecular``, the bbox coordinates themselves are also dropped.

    Args:
        image_path: Path to the image file.

    Returns:
        A JSON string encoding the (stripped) extraction results.
    """
    drop_keys = ("category", "bbox", "molfile", "symbols", 'atoms', "bonds", 'category_id', 'score', 'corefs')

    # Use Image.open as a context manager so the file handle is released once
    # the independent RGB copy exists.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])

    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in drop_keys:
                bbox.pop(key, None)

    return json.dumps(coref_results)
|
|
|
|
|
def get_multi_molecular_text_to_correct_withatoms(image_path: str) -> str:
    """Return molecule/coreference extraction results (keeping symbols) as JSON.

    Used as the LLM tool in the ``process_reaction_image_*`` functions. Runs
    ``model.extract_molecule_corefs_from_figures`` on the RGB version of the
    image and removes the following fields from every bbox before
    serializing: ``coords``, ``edges``, ``molfile``, ``atoms``, ``bonds``,
    ``category_id``, ``score`` and ``corefs``. Every other field (for example
    ``symbols`` and ``bbox``, when present) is kept so the model can correct it.

    Args:
        image_path: Path to the image file.

    Returns:
        A JSON string encoding the (stripped) extraction results.
    """
    drop_keys = ("coords", "edges", "molfile", 'atoms', "bonds", 'category_id', 'score', 'corefs')

    # Use Image.open as a context manager so the file handle is released once
    # the independent RGB copy exists.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    coref_results = model.extract_molecule_corefs_from_figures([image])

    for item in coref_results:
        for bbox in item.get("bboxes", []):
            for key in drop_keys:
                bbox.pop(key, None)

    return json.dumps(coref_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_reaction_image_with_multiple_products_and_text(image_path: str) -> list:
    """Extract the molecules in an image and correct their atom symbols with GPT-4o.

    Steps:
      1. Send the prompt from ``./prompt/prompt_getmolecular.txt`` together
         with the base64-encoded image to ``gpt-4o`` on Azure OpenAI, offering
         the ``get_multi_molecular_text_to_correct_withatoms`` tool.
      2. Run every requested tool on *image_path*, send the tool results back,
         and parse the model's JSON reply.
      3. Re-run ``model.extract_molecule_corefs_from_figures`` on the image,
         copy the model's ``symbols`` onto the positionally matching bboxes
         (and onto each atom's ``atom_symbol``), then regenerate ``smiles``
         and ``molfile`` with ``_convert_graph_to_smiles``.

    Azure credentials are read from the ``API_KEY`` and ``AZURE_ENDPOINT``
    environment variables.

    Args:
        image_path: Path to the image file.

    Returns:
        The extraction results for the image with updated ``symbols``,
        ``atom_symbol``, ``smiles`` and ``molfile`` fields.

    Raises:
        ValueError: If the model calls a tool other than the one offered.
    """
    # Credentials come from the environment; there are no module-level
    # API_KEY / AZURE_ENDPOINT names to fall back on.
    client = AzureOpenAI(
        api_key=os.getenv("API_KEY"),
        api_version='2024-06-01',
        azure_endpoint=os.getenv("AZURE_ENDPOINT"),
    )

    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct_withatoms',
                'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]
    tool_map = {
        'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms,
    }

    with open('./prompt/prompt_getmolecular.txt', 'r', encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()

    # Conversation prefix shared by both requests.
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    # First round: the model may request the extraction tool.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=messages,
        tools=tools,
    )
    assistant_message = response.choices[0].message

    tool_messages = []
    # tool_calls is None when the model replies without calling any tool.
    for tool_call in assistant_message.tool_calls or []:
        tool_name = tool_call.function.name
        if tool_name not in tool_map:
            raise ValueError(f"Unknown tool called: {tool_name}")
        # The tool always runs on this function's own image_path; the
        # model-supplied arguments are not used.
        tool_result = tool_map[tool_name](image_path)
        tool_messages.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Second round: the model answers in JSON using the tool output.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[*messages, assistant_message, *tool_messages],
        response_format={'type': 'json_object'},
        temperature=0,
    )
    gpt_output = [json.loads(response.choices[0].message.content)]

    # Fresh, unstripped extraction results (with atoms/coords/edges) to patch.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    coref_results = model.extract_molecule_corefs_from_figures([image])

    def update_symbols_in_atoms(input1, input2):
        """Copy corrected ``symbols`` from *input1* onto the matching bboxes of *input2*.

        Items and bboxes are paired positionally (the two inputs are assumed
        to share the same structure). For each bbox pair, ``symbols`` from
        *input1* replaces ``symbols`` in *input2*; when the *input2* bbox has
        ``atoms`` of the same length, each atom's ``atom_symbol`` is updated
        too. Items whose bbox counts differ are skipped with a warning.
        *input2* is modified in place and returned.
        """
        for item1, item2 in zip(input1, input2):
            bboxes1 = item1.get('bboxes', [])
            bboxes2 = item2.get('bboxes', [])
            if len(bboxes1) != len(bboxes2):
                print("Warning: Mismatched number of bboxes!")
                continue
            for bbox1, bbox2 in zip(bboxes1, bboxes2):
                if 'symbols' not in bbox1:
                    continue
                symbols = bbox1['symbols']
                bbox2['symbols'] = symbols
                if 'atoms' not in bbox2:
                    continue
                atoms = bbox2['atoms']
                if len(symbols) != len(atoms):
                    # NOTE(review): 'symbols' was already replaced above, so this
                    # bbox keeps symbols that no longer line up with its atoms.
                    print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!")
                    continue
                for atom, symbol in zip(atoms, symbols):
                    atom['atom_symbol'] = symbol
        return input2

    def update_smiles_and_molfile(input_data, conversion_function):
        """Regenerate ``smiles`` and ``molfile`` for every bbox from its graph.

        Args:
            input_data: Iterable of items, each with an optional ``bboxes`` list.
            conversion_function: Callable taking ``(coords, symbols, edges)``
                and returning ``(smiles, molfile, _)``.

        Returns:
            *input_data*, modified in place. Bboxes lacking any of ``coords``,
            ``symbols`` or ``edges`` are left untouched.
        """
        for item in input_data:
            for bbox in item.get('bboxes', []):
                if all(key in bbox for key in ('coords', 'symbols', 'edges')):
                    new_smiles, new_molfile, _ = conversion_function(
                        bbox['coords'], bbox['symbols'], bbox['edges'])
                    bbox['smiles'] = new_smiles
                    bbox['molfile'] = new_molfile
        return input_data

    input2_updated = update_symbols_in_atoms(gpt_output, coref_results)
    updated_data = update_smiles_and_molfile(input2_updated, _convert_graph_to_smiles)
    return updated_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_reaction_image_with_multiple_products_and_text_correctR(image_path: str) -> list:
    """Extract the molecules in an image and correct their atom symbols (R-group prompt).

    Steps:
      1. Send the prompt from ``./prompt/prompt_getmolecular_correctR.txt``
         together with the base64-encoded image to ``gpt-4o`` on Azure
         OpenAI, offering the ``get_multi_molecular_text_to_correct_withatoms``
         tool.
      2. Run every requested tool on *image_path*, send the tool results back,
         and parse (and print) the model's JSON reply.
      3. Re-run ``model.extract_molecule_corefs_from_figures`` on the image,
         copy the model's ``symbols`` onto the positionally matching bboxes
         (and onto each atom's ``atom_symbol``), then regenerate ``smiles``
         and ``molfile`` with ``_convert_graph_to_smiles`` and print the result.

    Azure credentials are read from the ``API_KEY`` and ``AZURE_ENDPOINT``
    environment variables.

    Args:
        image_path: Path to the image file.

    Returns:
        The extraction results for the image with updated ``symbols``,
        ``atom_symbol``, ``smiles`` and ``molfile`` fields.

    Raises:
        ValueError: If the model calls a tool other than the one offered.
    """
    client = AzureOpenAI(
        api_key=os.getenv("API_KEY"),
        api_version='2024-06-01',
        azure_endpoint=os.getenv("AZURE_ENDPOINT"),
    )

    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct_withatoms',
                'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]
    tool_map = {
        'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms,
    }

    with open('./prompt/prompt_getmolecular_correctR.txt', 'r', encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()

    # Conversation prefix shared by both requests.
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    # First round: the model may request the extraction tool.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=messages,
        tools=tools,
    )
    assistant_message = response.choices[0].message

    tool_messages = []
    # tool_calls is None when the model replies without calling any tool.
    for tool_call in assistant_message.tool_calls or []:
        tool_name = tool_call.function.name
        if tool_name not in tool_map:
            raise ValueError(f"Unknown tool called: {tool_name}")
        # The tool always runs on this function's own image_path; the
        # model-supplied arguments are not used.
        tool_result = tool_map[tool_name](image_path)
        tool_messages.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Second round: the model answers in JSON using the tool output.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[*messages, assistant_message, *tool_messages],
        response_format={'type': 'json_object'},
        temperature=0,
    )
    gpt_output = [json.loads(response.choices[0].message.content)]
    print(f"gpt_output_mol:{gpt_output}")

    # Fresh, unstripped extraction results (with atoms/coords/edges) to patch.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    coref_results = model.extract_molecule_corefs_from_figures([image])

    def update_symbols_in_atoms(input1, input2):
        """Copy corrected ``symbols`` from *input1* onto the matching bboxes of *input2*.

        Items and bboxes are paired positionally (the two inputs are assumed
        to share the same structure). For each bbox pair, ``symbols`` from
        *input1* replaces ``symbols`` in *input2*; when the *input2* bbox has
        ``atoms`` of the same length, each atom's ``atom_symbol`` is updated
        too. Items whose bbox counts differ are skipped with a warning.
        *input2* is modified in place and returned.
        """
        for item1, item2 in zip(input1, input2):
            bboxes1 = item1.get('bboxes', [])
            bboxes2 = item2.get('bboxes', [])
            if len(bboxes1) != len(bboxes2):
                print("Warning: Mismatched number of bboxes!")
                continue
            for bbox1, bbox2 in zip(bboxes1, bboxes2):
                if 'symbols' not in bbox1:
                    continue
                symbols = bbox1['symbols']
                bbox2['symbols'] = symbols
                if 'atoms' not in bbox2:
                    continue
                atoms = bbox2['atoms']
                if len(symbols) != len(atoms):
                    # NOTE(review): 'symbols' was already replaced above, so this
                    # bbox keeps symbols that no longer line up with its atoms.
                    print(f"Warning: Mismatched symbols and atoms in bbox {bbox1.get('bbox')}!")
                    continue
                for atom, symbol in zip(atoms, symbols):
                    atom['atom_symbol'] = symbol
        return input2

    def update_smiles_and_molfile(input_data, conversion_function):
        """Regenerate ``smiles`` and ``molfile`` for every bbox from its graph.

        Args:
            input_data: Iterable of items, each with an optional ``bboxes`` list.
            conversion_function: Callable taking ``(coords, symbols, edges)``
                and returning ``(smiles, molfile, _)``.

        Returns:
            *input_data*, modified in place. Bboxes lacking any of ``coords``,
            ``symbols`` or ``edges`` are left untouched.
        """
        for item in input_data:
            for bbox in item.get('bboxes', []):
                if all(key in bbox for key in ('coords', 'symbols', 'edges')):
                    new_smiles, new_molfile, _ = conversion_function(
                        bbox['coords'], bbox['symbols'], bbox['edges'])
                    bbox['smiles'] = new_smiles
                    bbox['molfile'] = new_molfile
        return input_data

    input2_updated = update_symbols_in_atoms(gpt_output, coref_results)
    updated_data = update_smiles_and_molfile(input2_updated, _convert_graph_to_smiles)
    print(f"mol_agent_output:{updated_data}")
    return updated_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_reaction_image_with_multiple_products_and_text_correctmultiR(image_path: str) -> list:
    """Extract molecules from an image and let GPT-4o correct/expand them (multi-R prompt).

    Steps:
      1. Send the prompt from ``./prompt/prompt_getmolecular_correctmultiR.txt``
         together with the base64-encoded image to ``gpt-4o`` on Azure
         OpenAI, offering the ``get_multi_molecular_text_to_correct_withatoms``
         tool.
      2. Run every requested tool on *image_path*, send the tool results back,
         and parse (and print) the model's JSON reply.
      3. Re-run ``model.extract_molecule_corefs_from_figures`` on the image and
         rebuild its bbox list from the model's bboxes (matched to the
         original ones by exact ``bbox`` coordinates, so one original bbox may
         appear several times), remap the coreference links accordingly, and
         regenerate ``smiles``/``molfile`` with ``_convert_graph_to_smiles``.

    Azure credentials are read from the ``API_KEY`` and ``AZURE_ENDPOINT``
    environment variables.

    Args:
        image_path: Path to the image file.

    Returns:
        The rebuilt extraction results with updated ``bboxes``, ``corefs``,
        ``smiles`` and ``molfile`` fields.

    Raises:
        ValueError: If the model calls a tool other than the one offered, or
            returns a bbox whose coordinates match no original bbox.
    """
    client = AzureOpenAI(
        api_key=os.getenv("API_KEY"),
        api_version='2024-06-01',
        azure_endpoint=os.getenv("AZURE_ENDPOINT"),
    )

    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    tools = [
        {
            'type': 'function',
            'function': {
                'name': 'get_multi_molecular_text_to_correct_withatoms',
                'description': 'Extracts the SMILES string, the symbols set, and the text coref of all molecular images in a table-reaction image and ready to be correct.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'image_path': {
                            'type': 'string',
                            'description': 'The path to the reaction image.',
                        },
                    },
                    'required': ['image_path'],
                    'additionalProperties': False,
                },
            },
        },
    ]
    tool_map = {
        'get_multi_molecular_text_to_correct_withatoms': get_multi_molecular_text_to_correct_withatoms,
    }

    with open('./prompt/prompt_getmolecular_correctmultiR.txt', 'r', encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()

    # Conversation prefix shared by both requests.
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {
            'role': 'user',
            'content': [
                {'type': 'text', 'text': prompt},
                {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{base64_image}'}},
            ],
        },
    ]

    # First round: the model may request the extraction tool.
    response = client.chat.completions.create(
        model='gpt-4o',
        temperature=0,
        response_format={'type': 'json_object'},
        messages=messages,
        tools=tools,
    )
    assistant_message = response.choices[0].message

    tool_messages = []
    # tool_calls is None when the model replies without calling any tool.
    for tool_call in assistant_message.tool_calls or []:
        tool_name = tool_call.function.name
        if tool_name not in tool_map:
            raise ValueError(f"Unknown tool called: {tool_name}")
        # The tool always runs on this function's own image_path; the
        # model-supplied arguments are not used.
        tool_result = tool_map[tool_name](image_path)
        tool_messages.append({
            'role': 'tool',
            'content': json.dumps({
                'image_path': image_path,
                tool_name: tool_result,
            }),
            'tool_call_id': tool_call.id,
        })

    # Second round: the model answers in JSON using the tool output.
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[*messages, assistant_message, *tool_messages],
        response_format={'type': 'json_object'},
        temperature=0,
    )
    gpt_output = [json.loads(response.choices[0].message.content)]
    print(f"gpt_output_mol:{gpt_output}")

    # Fresh, unstripped extraction results (with atoms/coords/edges) to rebuild.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    coref_results = model.extract_molecule_corefs_from_figures([image])

    def update_symbols_and_corefs(gpt_outputs, coref_results):
        """Rebuild each item's bboxes from the model output and remap its corefs.

        For every model bbox, the original bbox with identical ``bbox``
        coordinates is deep-copied as a template, then its ``symbols`` (and
        each atom's ``atom_symbol``), ``text`` and ``bbox`` are overwritten
        from the model bbox. Each original coref group ``[mol..., label]`` is
        turned into ``[new_mol_idx, new_label_idx]`` pairs, linking every copy
        of each molecule to the last copy of the label. Links whose molecule
        or label is absent from the model output are dropped.

        Raises:
            ValueError: If a model bbox matches no original bbox coordinates.
        """
        results = []
        for item1, item2 in zip(gpt_outputs, coref_results):
            orig_bboxes = item2.get('bboxes', [])
            orig_corefs = item2.get('corefs', [])

            coord2idx = {tuple(bb['bbox']): i for i, bb in enumerate(orig_bboxes)}

            new_bboxes = []
            for bb1 in item1.get('bboxes', []):
                coord = tuple(bb1['bbox'])
                if coord not in coord2idx:
                    raise ValueError(f"扩展mol时未找到bbox {coord} 的原始模板!")
                bb_new = copy.deepcopy(orig_bboxes[coord2idx[coord]])
                if 'symbols' in bb1:
                    bb_new['symbols'] = bb1['symbols']
                    if 'atoms' in bb_new:
                        for atom, sym in zip(bb_new['atoms'], bb1['symbols']):
                            atom['atom_symbol'] = sym
                if 'text' in bb1:
                    bb_new['text'] = bb1['text']
                bb_new['bbox'] = bb1['bbox']
                new_bboxes.append(bb_new)

            # One original coordinate may now map to several expanded bboxes.
            coord2new_idxs = {}
            for idx, bb in enumerate(new_bboxes):
                coord2new_idxs.setdefault(tuple(bb['bbox']), []).append(idx)

            new_corefs = []
            for group in orig_corefs:
                if not group:
                    continue
                # The last index of a coref group is the label bbox.
                label_coord = tuple(orig_bboxes[group[-1]]['bbox'])
                label_idxs = coord2new_idxs.get(label_coord)
                if not label_idxs:
                    # The model output no longer contains this label; drop its links.
                    continue
                new_label_idx = label_idxs[-1]
                for mol_idx in group[:-1]:
                    mol_coord = tuple(orig_bboxes[mol_idx]['bbox'])
                    for new_mol_idx in coord2new_idxs.get(mol_coord, []):
                        new_corefs.append([new_mol_idx, new_label_idx])

            new_item = copy.deepcopy(item2)
            new_item['bboxes'] = new_bboxes
            new_item['corefs'] = new_corefs
            results.append(new_item)
        return results

    def update_smiles_and_molfile(input_data, conversion_function):
        """Regenerate ``smiles`` and ``molfile`` for every bbox from its graph.

        Args:
            input_data: Iterable of items, each with an optional ``bboxes`` list.
            conversion_function: Callable taking ``(coords, symbols, edges)``
                and returning ``(smiles, molfile, _)``.

        Returns:
            *input_data*, modified in place. Bboxes lacking any of ``coords``,
            ``symbols`` or ``edges`` are left untouched.
        """
        for item in input_data:
            for bbox in item.get('bboxes', []):
                if all(key in bbox for key in ('coords', 'symbols', 'edges')):
                    new_smiles, new_molfile, _ = conversion_function(
                        bbox['coords'], bbox['symbols'], bbox['edges'])
                    bbox['smiles'] = new_smiles
                    bbox['molfile'] = new_molfile
        return input_data

    input2_updated = update_symbols_and_corefs(gpt_output, coref_results)
    updated_data = update_smiles_and_molfile(input2_updated, _convert_graph_to_smiles)
    print(f"mol_agent_output:{updated_data}")
    return updated_data
|
|
|
|
|
|