ChemEagle / app.py
CYF200127's picture
Update app.py
e5bd3d6 verified
import gradio as gr
import json
from main import ChemEagle
from rdkit import Chem
from rdkit.Chem import rdChemReactions, Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D
import cairosvg
import re
import os
# Image shown in the "Schematic Diagram" panel and returned as the third
# output of the processing callbacks.
example_diagram = "examples/exp.png"
# Placeholder picture used when there is no SMILES to draw yet or when a
# molecule SMILES cannot be parsed.
rdkit_image = "examples/rdkit.png"
def _as_items(value):
    """Return *value* ready for iteration: ``None`` -> ``[]``, a bare string -> ``[value]``."""
    if value is None:
        return []
    if isinstance(value, str):
        # Iterating a string would yield single characters, never the intent here.
        return [value]
    return value


def _smiles_of(entries):
    """Return the ``smiles`` field of each entry, using "Unknown" when missing or null."""
    collected = []
    for entry in _as_items(entries):
        smiles = entry.get("smiles", "Unknown")
        collected.append("Unknown" if smiles is None else smiles)
    return collected


def _span(text, color):
    """Wrap *text* in an inline-coloured HTML ``<span>``."""
    return f"<span style='color:{color}'>{text}</span>"


def parse_reactions(output_json):
    """Format ChemEagle reaction output as HTML fragments and reaction SMILES.

    Args:
        output_json: A JSON string or an already-decoded dict. Reactions are
            read from its ``"reactions"`` key; each reaction may carry
            ``"reaction_id"``, ``"reactants"``, ``"conditions"`` (falling back
            to the singular ``"condition"``), ``"products"`` and
            ``"additional_info"``. Any of these may be absent or ``null``.

    Returns:
        A ``(detailed_output, smiles_output)`` tuple of parallel lists: one
        HTML fragment and one ``reactants>>products`` string per reaction.

    Raises:
        json.JSONDecodeError: If ``output_json`` is a string that is not
            valid JSON.
    """
    reactions_data = json.loads(output_json) if isinstance(output_json, str) else output_json

    detailed_output = []
    smiles_output = []
    for reaction in _as_items(reactions_data.get("reactions")):
        reaction_id = reaction.get("reaction_id", "Unknown ID")
        reactants = _smiles_of(reaction.get("reactants"))
        products = _smiles_of(reaction.get("products"))

        # Some outputs use the singular key; only consult it when the plural
        # one is missing or null.
        conds = reaction.get("conditions")
        if conds is None:
            conds = reaction.get("condition")
        condition_labels = [
            f"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]"
            for c in _as_items(conds)
        ]

        # Falsy entries (empty strings, None) are dropped from the extra info.
        additional_str = [str(x) for x in _as_items(reaction.get("additional_info")) if x]

        reactant_side = ".".join(reactants)
        tail_str = ", ".join([_span(label, "black") for label in condition_labels] + additional_str)
        black_products = ".".join(_span(p, "black") for p in products)
        full_reaction = _span(f"{reactant_side}>>{black_products} | {tail_str}", "black")

        detailed_output.append(
            "".join(
                [
                    f"<b>Reaction: </b> {reaction_id}<br>",
                    f" Reactants: {_span(', '.join(reactants), 'blue')}<br>",
                    f" Conditions: {', '.join(_span(label, 'red') for label in condition_labels)}<br>",
                    f" Products: {', '.join(_span(p, 'orange') for p in products)}<br>",
                    f" additional_info: {', '.join(additional_str)}<br>",
                    f" <b>Full Reaction:</b> {full_reaction}<br><br>",
                ]
            )
        )
        # Plain reaction SMILES (no markup) for downstream rendering.
        smiles_output.append(f"{reactant_side}>>{'.'.join(products)}")
    return detailed_output, smiles_output
def parse_mol(output_json):
    """Format ChemEagle molecule output as HTML fragments and SMILES strings.

    Accepts a JSON string or a decoded dict and reads its ``"molecules"``
    list. Returns ``(detailed_output, smiles_output)`` in the same shape as
    ``parse_reactions``: one HTML fragment and one SMILES per molecule.
    """
    mols_data = output_json
    if isinstance(output_json, str):
        mols_data = json.loads(output_json)

    detailed_output, smiles_output = [], []
    for position, entry in enumerate(mols_data.get("molecules", []), start=1):
        smiles = entry.get("smiles", "Unknown")
        # Unlabelled molecules are numbered from 1 in listing order.
        label = entry.get("label", f"Mol {position}")
        bbox = entry.get("bbox", [])
        fragment = "".join(
            (
                f"<b>Molecule:</b> {label}<br>",
                f" SMILES: <span style='color:blue'>{smiles}</span><br>",
                f" bbox: {bbox}<br><br>",
            )
        )
        detailed_output.append(fragment)
        smiles_output.append(smiles)
    return detailed_output, smiles_output
def process_chem_image(image):
    """Run ChemEagle on an uploaded image and build the values shown in the UI.

    Args:
        image: The uploaded picture. It must provide ``save(path)`` (the UI
            supplies a PIL image); ``None`` means nothing was uploaded.

    Returns:
        A 4-tuple ``(html, smiles, diagram_path, json_path)``: the joined HTML
        fragments, the list of SMILES strings, the schematic diagram path and
        the path of the JSON file written with the raw ChemEagle result.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # The image component yields None when empty (and after "Clear");
        # report that readably instead of failing on ``None.save``.
        raise gr.Error("Please upload a chemical graphic first.")

    image_path = "temp_image.png"
    image.save(image_path)
    chemeagle_result = ChemEagle(image_path)

    # The parsers accept JSON text as well as dicts. Decode it up front so the
    # membership test below inspects keys (not substrings) and the saved file
    # holds a JSON object rather than a quoted string.
    if isinstance(chemeagle_result, str):
        chemeagle_result = json.loads(chemeagle_result)

    parser = parse_mol if "molecules" in chemeagle_result else parse_reactions
    detailed, smiles = parser(chemeagle_result)

    json_path = "output.json"
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump(chemeagle_result, jf, indent=2)
    return "\n\n".join(detailed), smiles, example_diagram, json_path
def process_chem_image_api(image, api_key, endpoint):
    """Run the extraction using credentials typed into the UI.

    Exports ``api_key`` and ``endpoint`` as the ``API_KEY`` and
    ``AZURE_ENDPOINT`` environment variables, then delegates to
    ``process_chem_image`` so both entry points share one implementation.

    Args:
        image: The uploaded picture, forwarded to ``process_chem_image``.
        api_key: Azure API key; ``None`` is stored as an empty string.
        endpoint: Azure endpoint URL; may be empty or ``None``.

    Returns:
        The same ``(html, smiles, diagram_path, json_path)`` tuple as
        ``process_chem_image``.
    """
    # Set the environment variables dynamically for this request.
    # NOTE(review): os.environ is process-wide, so the most recent caller's
    # credentials are what any later ChemEagle call will see.
    os.environ["API_KEY"] = api_key or ""  # os.environ rejects None values
    os.environ["AZURE_ENDPOINT"] = endpoint or ""  # the endpoint may be left empty
    # ChemEagle is assumed to read os.getenv("API_KEY") / os.environ["API_KEY"]
    # internally — TODO confirm.
    return process_chem_image(image)
# Build the Gradio interface: an input column (image + Azure credentials), a
# column with the parsed HTML and a schematic picture, and a column with the
# per-item SMILES boxes, RDKit drawings and the downloadable JSON file.
with gr.Blocks() as demo:
    gr.Markdown(
"""
<center><h1>ChemEAGLE: A Multi-Agent System Enables Versatile Information Extraction from the Chemical Literature</h1></center>
Upload a chemical graphic to extract machine-readable chemical data.
"""
    )
    #####api
    # Inputs, including the credentials forwarded to process_chem_image_api.
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a chemical graphic")
            api_key_input = gr.Textbox(label="Azure API Key", type="password", placeholder="Enter your Azure API Key")
            endpoint_input = gr.Textbox(label="Azure Endpoint", placeholder="e.g. https://xxx.openai.azure.com/")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                # elem_id ties this button to the #submit-btn CSS rule below.
                run_btn = gr.Button("Run", elem_id="submit-btn")
        with gr.Column(scale=1):
            gr.Markdown("### Parsed Reactions")
            reaction_output = gr.HTML(label="Detailed Reaction Output")
            gr.Markdown("### Schematic Diagram")
            schematic_diagram = gr.Image(value=example_diagram, label="示意图")
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Output")
            # Hidden textbox that carries the SMILES output of the callbacks;
            # its value drives the dynamic rendering in show_split.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False
            )
            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                # Renders one SMILES textbox plus one RDKit picture per item
                # found in the hidden textbox's text.
                if not inputs or (isinstance(inputs, str) and inputs.strip() == ""):
                    # Nothing to show yet: render an empty box and the placeholder image.
                    return gr.Textbox(label="SMILES of Reaction or Molecule i"), gr.Image(value=rdkit_image, label="RDKit Image", height=100)
                # NOTE(review): the callbacks return a Python list for this
                # textbox; presumably it arrives here as its string form
                # (e.g. "['A>>B', 'C']") — confirm. The split and regex below
                # undo that: split on commas, then strip surrounding
                # whitespace, an optional leading "[" / "'" and an optional
                # trailing "'" / "']".
                smiles_list = inputs.split(",")
                smiles_list = [re.sub(r"^\s*\[?'?|']?\s*$", "", item) for item in smiles_list]
                components = []
                for i, smiles in enumerate(smiles_list):
                    # Drop any quote characters that survived the regex.
                    smiles_clean = smiles.replace('"', '').replace("'", "")
                    # Always add the SMILES text box.
                    components.append(gr.Textbox(value=smiles_clean, label=f"SMILES of Item {i}", show_copy_button=True, interactive=False))
                    try:
                        # First check whether this is a reaction SMILES.
                        rxn = rdChemReactions.ReactionFromSmarts(smiles_clean, useSmiles=True)
                        is_rxn = rxn is not None and rxn.GetNumProductTemplates() > 0
                    except Exception:
                        # Anything RDKit cannot read as a reaction is treated as a molecule.
                        is_rxn = False
                    if is_rxn:
                        try:
                            # Rebuild the reaction from mol-block round-trips of
                            # each reactant and product template.
                            new_rxn = AllChem.ChemicalReaction()
                            for mol in rxn.GetReactants():
                                mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
                                new_rxn.AddReactantTemplate(mol)
                            for mol in rxn.GetProducts():
                                mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
                                new_rxn.AddProductTemplate(mol)
                            cleaned_rxn = new_rxn
                            # Clear atom-map numbers on every atom before drawing.
                            for react in cleaned_rxn.GetReactants():
                                for atom in react.GetAtoms(): atom.SetAtomMapNum(0)
                            for prod in cleaned_rxn.GetProducts():
                                for atom in prod.GetAtoms(): atom.SetAtomMapNum(0)
                            # Pick a reference bond length from the first reactant
                            # that has bonds (only the first two are considered);
                            # fall back to 1.0 otherwise.
                            react0 = cleaned_rxn.GetReactantTemplate(0)
                            react1 = cleaned_rxn.GetReactantTemplate(1) if cleaned_rxn.GetNumReactantTemplates() > 1 else None
                            if react0.GetNumBonds() > 0:
                                bond_len = Draw.MeanBondLength(react0)
                            elif react1 and react1.GetNumBonds() > 0:
                                bond_len = Draw.MeanBondLength(react1)
                            else:
                                bond_len = 1.0
                            # Draw to SVG with ACS-1996 style options.
                            drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
                            dopts = drawer.drawOptions()
                            dopts.padding = 0.1
                            dopts.includeRadicals = True
                            Draw.SetACS1996Mode(dopts, bond_len * 0.55)
                            dopts.bondLineWidth = 1.5
                            drawer.DrawReaction(cleaned_rxn)
                            drawer.FinishDrawing()
                            svg = drawer.GetDrawingText()
                            # Written to the working directory, then converted to
                            # PNG for the image component.
                            svg_file = f"reaction_{i}.svg"
                            with open(svg_file, "w") as f: f.write(svg)
                            png_file = f"reaction_{i}.png"
                            cairosvg.svg2png(url=svg_file, write_to=png_file)
                            components.append(gr.Image(value=png_file, label=f"RDKit Image of Reaction {i}"))
                        except Exception as e:
                            # Best effort: a failed drawing is logged and no image is added.
                            print(f"Failed to draw reaction {i} for SMILES '{smiles_clean}': {e}")
                    else:
                        # Try drawing it as a molecule.
                        try:
                            mol = Chem.MolFromSmiles(smiles_clean)
                            if mol:
                                img = Draw.MolToImage(mol, size=(350, 150))
                                img_file = f"mol_{i}.png"
                                img.save(img_file)
                                components.append(gr.Image(value=img_file, label=f"RDKit Image of Molecule {i}"))
                            else:
                                # Unparseable SMILES: show the placeholder picture.
                                components.append(gr.Image(value=rdkit_image, label="Invalid Molecule"))
                        except Exception as e:
                            print(f"Failed to draw molecule {i} for SMILES '{smiles_clean}': {e}")
                            components.append(gr.Image(value=rdkit_image, label="Invalid Molecule"))
                return components
            download_json = gr.File(label="Download JSON File")
    # Clickable sample images; no function is attached and caching is off, so
    # selecting one only fills the image input.
    gr.Examples(
        examples=[
            ["examples/reaction0.png"],
            ["examples/reaction1.jpg"],
            ["examples/reaction2.png"],
            ["examples/reaction3.png"],
            ["examples/reaction4.png"],
            ["examples/reaction5.png"],
            ["examples/template1.png"],
            ["examples/molecules1.png"],
        ],
        inputs=[image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
        cache_examples=False,
        examples_per_page=10,
    )
    # Reset the image, parsed HTML, SMILES and download; the schematic diagram
    # is not among the outputs and therefore keeps its current value.
    clear_btn.click(
        lambda: (None, None, None, None),
        inputs=[],
        outputs=[image_input, reaction_output, smiles_output, download_json]
    )
    ######api
    # "Run" goes through the credential-aware callback.
    run_btn.click(
        process_chem_image_api,
        inputs=[image_input, api_key_input, endpoint_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json]
    )
# Styling for the "Run" button (matched through elem_id="submit-btn").
demo.css = """
#submit-btn {
background-color: #FF914D;
color: white;
font-weight: bold;
}
"""
demo.launch()