|
|
import ast
import json
import os
import re

import cairosvg
import gradio as gr
from rdkit import Chem
from rdkit.Chem import rdChemReactions, Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D

from main import ChemEagle
|
|
|
|
|
# Static image shown in the "Schematic Diagram" panel and returned by the
# processing callbacks.
example_diagram = "examples/exp.png"

# Placeholder image used wherever no RDKit rendering is available.
rdkit_image = "examples/rdkit.png"
|
|
|
|
|
|
|
|
def parse_reactions(output_json):
    """Turn reaction-extraction output into HTML summaries and reaction SMILES.

    Args:
        output_json: A JSON string or an already-decoded mapping with a
            top-level ``"reactions"`` list. Each reaction is a mapping that may
            carry ``reaction_id``, ``reactants``, ``products``, ``conditions``
            (or the singular ``condition``) and ``additional_info``.

    Returns:
        A ``(detailed_output, smiles_output)`` tuple. ``detailed_output`` holds
        one HTML fragment per reaction; ``smiles_output`` holds one
        ``"reactants>>products"`` string per reaction, with components joined
        by ``"."``.

    Fields that are missing *or* explicitly ``null`` are treated as empty
    lists, and a ``null`` SMILES on a reactant/product is shown as
    ``"Unknown"``, so a partially filled record no longer raises.
    """

    def smiles_of(entry):
        # A present-but-null SMILES would otherwise break str.join below.
        value = entry.get("smiles", "Unknown")
        return "Unknown" if value is None else value

    def colored(color, text):
        return f"<span style='color:{color}'>{text}</span>"

    reactions_data = json.loads(output_json) if isinstance(output_json, str) else output_json
    reactions_list = reactions_data.get("reactions") or []

    detailed_output = []
    smiles_output = []

    for reaction in reactions_list:
        reaction_id = reaction.get("reaction_id", "Unknown ID")
        reactants = [smiles_of(r) for r in reaction.get("reactants") or []]
        product_smiles = [smiles_of(p) for p in reaction.get("products") or []]

        # Prefer the plural key; fall back to the singular spelling only when
        # the plural one is absent or null (an explicit empty list is kept).
        conds = reaction.get("conditions")
        if conds is None:
            conds = reaction.get("condition") or []
        condition_labels = [
            f"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]"
            for c in conds
        ]

        additional_str = [str(x) for x in reaction.get("additional_info") or [] if x]

        # Compact one-line form: reactants>>products | conditions, extra info.
        tail_str = ", ".join(
            [colored("black", label) for label in condition_labels] + additional_str
        )
        products_black = ".".join(colored("black", p) for p in product_smiles)
        full_reaction = colored(
            "black", f"{'.'.join(reactants)}>>{products_black} | {tail_str}"
        )

        conditions_html = ", ".join(colored("red", label) for label in condition_labels)
        products_html = ", ".join(colored("orange", p) for p in product_smiles)

        reaction_output = f"<b>Reaction: </b> {reaction_id}<br>"
        reaction_output += f" Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>"
        reaction_output += f" Conditions: {conditions_html}<br>"
        reaction_output += f" Products: {products_html}<br>"
        reaction_output += f" additional_info: {', '.join(additional_str)}<br>"
        reaction_output += f" <b>Full Reaction:</b> {full_reaction}<br><br>"
        detailed_output.append(reaction_output)

        smiles_output.append(f"{'.'.join(reactants)}>>{'.'.join(product_smiles)}")

    return detailed_output, smiles_output
|
|
|
|
|
|
|
|
def parse_mol(output_json):
    """Parse single- or multi-molecule ChemEagle output.

    Accepts a JSON string or a decoded mapping with a ``"molecules"`` list and
    returns ``(detailed_output, smiles_output)`` in the same shape as
    ``parse_reactions``: one HTML fragment and one SMILES string per molecule.
    Molecules without a ``label`` are named ``"Mol <n>"`` (1-based).
    """
    payload = output_json
    if isinstance(payload, str):
        payload = json.loads(payload)

    detailed_output, smiles_output = [], []
    for position, entry in enumerate(payload.get("molecules", []), start=1):
        smiles = entry.get("smiles", "Unknown")
        label = entry.get("label", f"Mol {position}")
        bbox = entry.get("bbox", [])

        detailed_output.append(
            f"<b>Molecule:</b> {label}<br>"
            f" SMILES: <span style='color:blue'>{smiles}</span><br>"
            f" bbox: {bbox}<br><br>"
        )
        smiles_output.append(smiles)

    return detailed_output, smiles_output
|
|
|
|
|
|
|
|
|
|
|
def process_chem_image(image):
    """Run ChemEagle on an uploaded image and build the values for the UI.

    The image is saved to ``temp_image.png``, passed to ``ChemEagle``, and the
    raw result is dumped to ``output.json``. Results containing a
    ``"molecules"`` key go through ``parse_mol``; everything else goes through
    ``parse_reactions``.

    Args:
        image: Object exposing ``save(path)`` (the UI supplies a PIL image).

    Returns:
        ``(html, smiles, example_diagram, json_path)`` where ``html`` is the
        per-item HTML joined by blank lines and ``smiles`` is the list of
        SMILES strings.
    """
    saved_image = "temp_image.png"
    image.save(saved_image)

    result = ChemEagle(saved_image)
    parser = parse_mol if "molecules" in result else parse_reactions
    html_parts, smiles = parser(result)

    json_path = "output.json"
    with open(json_path, 'w') as handle:
        json.dump(result, handle, indent=2)

    html = "\n\n".join(html_parts)
    return html, smiles, example_diagram, json_path
|
|
|
|
|
|
|
|
|
|
|
def process_chem_image_api(image, api_key, endpoint):
    """Process an image after exporting the user's credentials.

    Args:
        image: Uploaded image, forwarded unchanged to ``process_chem_image``.
        api_key: Value stored in the ``API_KEY`` environment variable.
        endpoint: Value stored in the ``AZURE_ENDPOINT`` environment variable.

    Returns:
        The same ``(html, smiles, example_diagram, json_path)`` tuple as
        ``process_chem_image``.
    """
    # os.environ only accepts strings; treat a missing value as empty for both
    # fields instead of raising TypeError for one of them.
    # NOTE(review): os.environ is process-wide, so simultaneous users overwrite
    # each other's credentials — confirm this is acceptable for the deployment.
    os.environ["API_KEY"] = api_key or ""
    os.environ["AZURE_ENDPOINT"] = endpoint or ""

    # The rest of the pipeline is identical to the credential-less entry
    # point, so reuse it rather than keeping a second copy in sync.
    return process_chem_image(image)
|
|
|
|
|
def _split_smiles_field(raw):
    """Recover the individual SMILES strings from the hidden textbox value.

    The textbox is fed a Python list by the processing callback, so its value
    is expected to look like ``"['A>>B', 'C']"`` (assumption based on the
    legacy parsing below — confirm against the Gradio version in use).
    Evaluating it as a literal undoes repr escaping, e.g. the doubled
    backslashes of stereo bonds, and maps an empty list to no items. Anything
    that is not a list literal falls back to the original comma/regex parsing.
    """
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None

    if isinstance(parsed, (list, tuple)):
        items = [str(item) for item in parsed]
    else:
        items = [re.sub(r"^\s*\[?'?|']?\s*$", "", item) for item in raw.split(",")]
    return [item.replace('"', '').replace("'", "") for item in items]


def _parse_reaction(smiles):
    """Return an RDKit reaction for *smiles*, or None if it is not a reaction."""
    try:
        rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)
        if rxn is not None and rxn.GetNumProductTemplates() > 0:
            return rxn
    except Exception:
        # Deliberate best effort: whatever fails here is drawn as a molecule.
        return None
    return None


def _clear_atom_maps(mols):
    """Reset the atom-map number of every atom in *mols* to 0."""
    for mol in mols:
        for atom in mol.GetAtoms():
            atom.SetAtomMapNum(0)


def _draw_reaction_png(rxn, index):
    """Render *rxn* to ``reaction_<index>.png`` and return that path."""
    # Rebuild the reaction from mol-block round-tripped copies of each template.
    cleaned_rxn = AllChem.ChemicalReaction()
    for mol in rxn.GetReactants():
        cleaned_rxn.AddReactantTemplate(Chem.MolFromMolBlock(Chem.MolToMolBlock(mol)))
    for mol in rxn.GetProducts():
        cleaned_rxn.AddProductTemplate(Chem.MolFromMolBlock(Chem.MolToMolBlock(mol)))

    _clear_atom_maps(cleaned_rxn.GetReactants())
    _clear_atom_maps(cleaned_rxn.GetProducts())

    # Scale the drawing from the first reactant that actually has bonds.
    react0 = cleaned_rxn.GetReactantTemplate(0)
    react1 = cleaned_rxn.GetReactantTemplate(1) if cleaned_rxn.GetNumReactantTemplates() > 1 else None
    if react0.GetNumBonds() > 0:
        bond_len = Draw.MeanBondLength(react0)
    elif react1 and react1.GetNumBonds() > 0:
        bond_len = Draw.MeanBondLength(react1)
    else:
        bond_len = 1.0

    drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
    dopts = drawer.drawOptions()
    dopts.padding = 0.1
    dopts.includeRadicals = True
    Draw.SetACS1996Mode(dopts, bond_len * 0.55)
    dopts.bondLineWidth = 1.5
    drawer.DrawReaction(cleaned_rxn)
    drawer.FinishDrawing()

    svg_file = f"reaction_{index}.svg"
    with open(svg_file, "w") as f:
        f.write(drawer.GetDrawingText())
    png_file = f"reaction_{index}.png"
    cairosvg.svg2png(url=svg_file, write_to=png_file)
    return png_file


def _draw_molecule_png(smiles, index):
    """Render *smiles* to ``mol_<index>.png``; return the path or None if unparsable."""
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return None
    img_file = f"mol_{index}.png"
    Draw.MolToImage(mol, size=(350, 150)).save(img_file)
    return img_file


# Page layout: inputs on the left, parsed HTML in the middle, per-item SMILES
# with RDKit renderings and the JSON download on the right.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        <center><h1>ChemEAGLE: A Multi-Agent System Enables Versatile Information Extraction from the Chemical Literature</h1></center>
        Upload a chemical graphic to extract machine-readable chemical data.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a chemical graphic")
            api_key_input = gr.Textbox(label="Azure API Key", type="password", placeholder="Enter your Azure API Key")
            endpoint_input = gr.Textbox(label="Azure Endpoint", placeholder="e.g. https://xxx.openai.azure.com/")

            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Run", elem_id="submit-btn")

        with gr.Column(scale=1):
            gr.Markdown("### Parsed Reactions")
            reaction_output = gr.HTML(label="Detailed Reaction Output")
            gr.Markdown("### Schematic Diagram")
            schematic_diagram = gr.Image(value=example_diagram, label="示意图")

        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Output")
            # Hidden carrier for the SMILES list; show_split below re-renders
            # from its value.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False
            )

            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                """Build one SMILES textbox plus one RDKit image per extracted item."""
                if not inputs or (isinstance(inputs, str) and inputs.strip() == ""):
                    # Nothing extracted yet: show an empty box and the placeholder image.
                    empty_box = gr.Textbox(label="SMILES of Reaction or Molecule i")
                    empty_img = gr.Image(value=rdkit_image, label="RDKit Image", height=100)
                    return empty_box, empty_img

                components = []
                for i, smiles_clean in enumerate(_split_smiles_field(inputs)):
                    components.append(gr.Textbox(value=smiles_clean, label=f"SMILES of Item {i}", show_copy_button=True, interactive=False))

                    rxn = _parse_reaction(smiles_clean)
                    if rxn is not None:
                        try:
                            png_file = _draw_reaction_png(rxn, i)
                            components.append(gr.Image(value=png_file, label=f"RDKit Image of Reaction {i}"))
                        except Exception as e:
                            print(f"Failed to draw reaction {i} for SMILES '{smiles_clean}': {e}")
                            # Keep one image per textbox, as the molecule branch does.
                            components.append(gr.Image(value=rdkit_image, label="Invalid Reaction"))
                    else:
                        try:
                            img_file = _draw_molecule_png(smiles_clean, i)
                            if img_file:
                                components.append(gr.Image(value=img_file, label=f"RDKit Image of Molecule {i}"))
                            else:
                                components.append(gr.Image(value=rdkit_image, label="Invalid Molecule"))
                        except Exception as e:
                            print(f"Failed to draw molecule {i} for SMILES '{smiles_clean}': {e}")
                            components.append(gr.Image(value=rdkit_image, label="Invalid Molecule"))
                return components

            download_json = gr.File(label="Download JSON File")

    gr.Examples(
        examples=[
            ["examples/reaction0.png"],
            ["examples/reaction1.jpg"],
            ["examples/reaction2.png"],
            ["examples/reaction3.png"],
            ["examples/reaction4.png"],
            ["examples/reaction5.png"],
            ["examples/template1.png"],
            ["examples/molecules1.png"],
        ],
        inputs=[image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
        cache_examples=False,
        examples_per_page=10,
    )

    # Clear resets the upload and every result widget except the schematic.
    clear_btn.click(
        lambda: (None, None, None, None),
        inputs=[],
        outputs=[image_input, reaction_output, smiles_output, download_json]
    )

    run_btn.click(
        process_chem_image_api,
        inputs=[image_input, api_key_input, endpoint_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json]
    )
|
|
|
|
|
|
|
|
|
|
|
# Styling for the "Run" button, which is created with elem_id="submit-btn".
_SUBMIT_BUTTON_CSS = """
#submit-btn {
    background-color: #FF914D;
    color: white;
    font-weight: bold;
}
"""
demo.css = _SUBMIT_BUTTON_CSS

# Start the Gradio server with default settings.
demo.launch()
|
|
|