Spaces:
Running
Running
| import os, sys, argparse, tempfile, shutil, base64, io | |
| from flask import Flask, request, render_template_string | |
| from werkzeug.utils import secure_filename | |
| from torch.utils.data import DataLoader | |
| import selfies | |
| from rdkit import Chem | |
| import torch | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from matplotlib import cm | |
| from typing import Optional | |
| from utils.drug_tokenizer import DrugTokenizer | |
| from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel | |
| from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI | |
| from utils.foldseek_util import get_struc_seq | |
| # ───── Biopython fallback ─────────────────────────────────────── | |
| from Bio.PDB import PDBParser, MMCIFParser | |
| from Bio.Data import IUPACData | |
# Three-letter (upper-case) residue name -> one-letter code.
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
# Non-standard amino acids mapped to their closest standard residue.
three2one.update({"SEC": "C", "PYL": "K", "MSE": "M"})


def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter amino-acid sequence of the first chain in *path*.

    Fallback used when Foldseek is unavailable or fails.  The parser is chosen
    from the file extension, case-insensitively (``.cif`` -> mmCIF, anything
    else -> PDB).

    * Water residues are skipped.
    * Other hetero residues (ligands, ions) are skipped unless their name is
      a known amino acid in ``three2one`` (e.g. MSE, SEC, PYL).
    * Standard residues whose name is not in ``three2one`` become ``X``.

    Raises
    ------
    ValueError
        If the structure contains no chains.
    """
    if path.lower().endswith(".cif"):
        parser = MMCIFParser(QUIET=True)
    else:
        parser = PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)
    chain = next(structure.get_chains(), None)
    if chain is None:
        raise ValueError(f"No chains found in structure file: {path}")

    letters = []
    for res in chain:
        hetflag = res.id[0]  # " " standard, "W" water, "H_xxx" other hetero
        name = res.get_resname().upper()
        if hetflag == "W":
            continue
        if hetflag.strip() and name not in three2one:
            # Ligand / ion attached to the chain, not part of the polymer.
            continue
        letters.append(three2one.get(name, "X"))
    return "".join(letters)
| # ───── global paths / args ────────────────────────────────────── | |
# Presumably set to silence Hugging Face tokenizers' parallelism/fork
# warning; must be the string "false".
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Absolute path of the Foldseek executable, or None when it is not on PATH.
# index() calls get_struc_seq with it inside a try/except and falls back to
# simple_seq_from_structure on any failure.
FOLDSEEK_BIN = shutil.which("foldseek")

# NOTE(review): this executes *after* the ``utils.*`` imports at the top of
# the file, so it cannot affect how those are resolved — confirm it is still
# needed.
sys.path.append("..")
def parse_config(argv=None):
    """Parse the visualiser's command-line options.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]`` (useful for tests and
        notebooks).

    Returns
    -------
    argparse.Namespace
        The recognised options.  Unrecognised arguments are reported on
        stderr and otherwise ignored: this runs at import time, and a host
        process (notebook kernel, WSGI server, ...) may put its own flags in
        ``sys.argv``, which would otherwise abort start-up.
    """
    p = argparse.ArgumentParser()
    # Presumably for Jupyter, whose kernels pass ``-f <connection file>``.
    p.add_argument("-f")
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="BindingDB",
                   help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')")
    known, unknown = p.parse_known_args(argv)
    if unknown:
        print(f"parse_config: ignoring unrecognised arguments: {unknown}", file=sys.stderr)
    return known
def _load_encoders(cfg):
    """Load the protein/drug tokenisers and backbones named in *cfg*.

    Returns ``(prot_tokenizer, prot_model, drug_tokenizer, drug_model,
    encoding)`` where *encoding* is the ``Pre_encoded`` wrapper around both
    backbones, moved to ``cfg.device``.
    """
    # Protein side: ESM tokenizer + masked-LM backbone from the same path.
    p_tok = EsmTokenizer.from_pretrained(cfg.prot_encoder_path)
    p_net = EsmForMaskedLM.from_pretrained(cfg.prot_encoder_path)
    # Drug side: project SELFIES tokenizer + Hugging Face AutoModel backbone.
    d_tok = DrugTokenizer()
    d_net = AutoModel.from_pretrained(cfg.drug_encoder_path)
    wrapper = Pre_encoded(p_net, d_net, cfg).to(cfg.device)
    return p_tok, p_net, d_tok, d_net, wrapper


args = parse_config()
DEVICE = args.device
# ───── tokenisers & encoders (module-level singletons) ─────────────
prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding = _load_encoders(args)
| # ─── collate fn ──────────────────────────────────────────────── | |
def collate_fn(batch):
    """Collate ``(protein, drug, score)`` triples into model-ready tensors.

    Both sequences are tokenised to a fixed length of 512 (padded to
    ``max_length``, truncated, special tokens added) with the module-level
    ``prot_tokenizer`` / ``drug_tokenizer``.

    Returns
    -------
    tuple
        ``(prot_input_ids, prot_mask, drug_input_ids, drug_mask, scores)``;
        both masks are boolean tensors and *scores* is ``torch.tensor`` of
        the third element of every triple.
    """
    proteins, drugs, labels = zip(*batch)
    shared_kwargs = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    prot_enc = prot_tokenizer.batch_encode_plus(list(proteins), **shared_kwargs)
    drug_enc = drug_tokenizer.batch_encode_plus(list(drugs), **shared_kwargs)
    return (
        prot_enc["input_ids"],
        prot_enc["attention_mask"].bool(),
        drug_enc["input_ids"],
        drug_enc["attention_mask"].bool(),
        torch.tensor(list(labels)),
    )
| # def collate_fn_batch_encoding(batch): | |
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES.

    RDKit serves only as a validity gate: if ``Chem.MolFromSmiles`` cannot
    parse *smiles*, no conversion is attempted.  The SELFIES string is
    encoded from the original *smiles* text, not from an RDKit-canonicalised
    form.

    Returns
    -------
    str or None
        The SELFIES string, or None if RDKit rejects the input or any step
        raises.
    """
    try:
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        # Deliberately best-effort: the caller shows a generic error banner.
        return None
| # ───── single-case embedding ─────────────────────────────────── | |
def get_case_feature(model, loader):
    """Embed the first batch yielded by *loader* with *model*, without gradients.

    Only the first batch is used; the rest of *loader* is never consumed.

    Parameters
    ----------
    model
        Object with ``eval()`` and ``encoding(p_ids, p_mask, d_ids, d_mask)``
        returning ``(p_emb, d_emb)`` (the module-level ``Pre_encoded`` wrapper
        in this app).
    loader
        Iterable of 5-tuples ``(p_ids, p_mask, d_ids, d_mask, score)`` such as
        a DataLoader using :func:`collate_fn`.

    Returns
    -------
    list or None
        ``[(p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, None)]`` with every
        tensor moved to CPU, or None when *loader* yields nothing.
    """
    model.eval()
    with torch.no_grad():
        missing = object()
        first = next(iter(loader), missing)
        if first is missing:
            return None
        p_ids, p_mask, d_ids, d_mask, _ = first
        p_ids, p_mask, d_ids, d_mask = (t.to(DEVICE) for t in (p_ids, p_mask, d_ids, d_mask))
        p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask)
        on_cpu = [t.cpu() for t in (p_emb, d_emb, p_ids, d_ids, p_mask, d_mask)]
        return [(*on_cpu, None)]
# ───── helper: drop special tokens ──────────────────────────────
def clean_tokens(ids, tokenizer):
    """Decode token ids and remove the tokenizer's special tokens.

    Parameters
    ----------
    ids : torch.Tensor or sequence of int
        1-D token ids; anything with ``.tolist()`` (tensor, ndarray) or a
        plain sequence of ints.
    tokenizer
        Object exposing ``convert_ids_to_tokens`` and ``all_special_tokens``
        (e.g. a Hugging Face tokenizer).

    Returns
    -------
    list of str
        Decoded tokens, in their original order, without special tokens.
    """
    id_list = ids.tolist() if hasattr(ids, "tolist") else list(ids)
    # On Hugging Face tokenizers ``all_special_tokens`` is a property that
    # rebuilds its list on every access; snapshot it once as a set.
    special = set(tokenizer.all_special_tokens)
    return [tok for tok in tokenizer.convert_ids_to_tokens(id_list) if tok not in special]
| # ───── visualisation ─────────────────────────────────────────── | |
def _top_residue_table(attn, p_tokens, d_tokens, drug_idx, k=20):
    """Return HTML listing the *k* protein tokens that attend most to drug token *drug_idx*.

    *attn* is the clipped but **unpruned** ``(n_p, n_d)`` matrix, so any drug
    token can be queried, including ones hidden from the heat-map by
    pruning.  *drug_idx* is 0-based; positions shown are 1-based.
    """
    from html import escape

    n_d = attn.size(1)
    if not 0 <= drug_idx < n_d:
        return (
            "<p style='color:red'><strong>"
            f"Drug token #{drug_idx + 1} is out of range (valid: 1–{n_d})."
            "</strong></p>"
        )
    col_vec = attn[:, drug_idx]
    top = torch.topk(col_vec, k=min(k, col_vec.numel())).indices.tolist()
    rank_hdr = "".join(f"<th>{r + 1}</th>" for r in range(len(top)))
    res_row = "".join(f"<td>{escape(str(p_tokens[i]))}</td>" for i in top)
    pos_row = "".join(f"<td>{i + 1}</td>" for i in top)
    return (
        "<h4 style='margin-bottom:6px'>"
        f"Drug token #{drug_idx + 1} <code>{escape(str(d_tokens[drug_idx]))}</code> "
        f"→ Top-{k} Protein residues</h4>"
        "<table class='tg' style='margin-bottom:8px'>"
        f"<tr><th>Rank</th>{rank_hdr}</tr>"
        f"<tr><td>Residue</td>{res_row}</tr>"
        f"<tr><td>Position</td>{pos_row}</tr>"
        "</table>"
    )


def _render_heatmap_html(matrix, x_labels, y_labels):
    """Draw *matrix* as a labelled heat-map.

    Returns HTML with a click-to-enlarge PNG preview and a "Download PDF"
    link, both embedded as base64 data URIs.
    """
    fig_w = min(22, max(8, len(x_labels) * 0.6))  # ~0.6 in per column
    fig_h = min(24, max(6, len(y_labels) * 0.8))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        im = ax.imshow(matrix.numpy(), aspect="auto",
                       cmap=cm.viridis, interpolation="nearest")
        ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)
        # Drug tokens along the top edge, rotated vertically.
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
                           ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False,
                       labeltop=True, labelbottom=False, pad=27)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        # Only ``pad`` matters for the y axis (top/bottom keys are x-only).
        ax.tick_params(axis="y", pad=10)
        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)  # raster preview
        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")  # vector download
    finally:
        plt.close(fig)

    png_b64 = base64.b64encode(buf_png.getvalue()).decode()
    pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
    return (
        f"<a href='data:image/png;base64,{png_b64}' target='_blank' "
        f"title='Click to enlarge'>"
        f"<img src='data:image/png;base64,{png_b64}' "
        f"style='max-width:100%;height:auto;cursor:zoom-in' /></a>"
        f"<div style='margin-top:6px'>"
        f"<a href='data:application/pdf;base64,{pdf_b64}' "
        f"download='attention_heatmap.pdf'>Download PDF</a></div>"
    )


def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """
    Render a Protein → Drug cross-attention heat-map and, optionally, a
    Top-20 protein-residue table for a chosen drug-token index.

    Token positions (x/y labels and *drug_idx*) refer to the sequence after
    tokenisation and special-token removal, but before sparsity pruning or
    the 150-column cap.  They are 1-based in labels and 0-based for
    *drug_idx*.  The table is computed from the unpruned matrix, so any
    in-range drug token can be queried.  An out-of-range index produces a
    short notice instead of a table.

    Parameters
    ----------
    model
        Called as ``model(p_emb, d_emb, p_mask, d_mask)``.  Its second output
        is used as the Protein → Drug attention, expected ``(B, n_p, n_d)``
        with B == 1.
    feats
        Output of :func:`get_case_feature`; only ``feats[0]`` is used.
    drug_idx
        Optional 0-based drug-token position for the Top-20 table.

    Returns
    -------
    html : str
        Optional Top-20 table, followed by a base64-embedded PNG heat-map
        and a PDF download link.
    """
    max_cols = 150  # column cap for readability

    model.eval()
    with torch.no_grad():
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
    # float() so .numpy() also works for half/bfloat16 outputs.
    attn = att_pd.squeeze(0).float().cpu()  # (n_p, n_d)

    # NOTE(review): tokens are decoded with special tokens removed, but rows/
    # columns of ``attn`` are taken positionally from index 0.  If the model's
    # attention still contains a leading special token (e.g. <cls>), labels are
    # shifted by one relative to the matrix — confirm against FusionDTI.
    p_tokens = clean_tokens(p_ids[0], prot_tokenizer)
    d_tokens = clean_tokens(d_ids[0], drug_tokenizer)

    # Clip tokens and matrix to their common size.
    n_p = min(len(p_tokens), attn.size(0))
    n_d = min(len(d_tokens), attn.size(1))
    p_tokens, d_tokens = p_tokens[:n_p], d_tokens[:n_d]
    attn = attn[:n_p, :n_d]
    if attn.numel() == 0:
        return "<p style='color:red'><strong>No non-special tokens to visualise.</strong></p>"

    table_html = "" if drug_idx is None else _top_residue_table(attn, p_tokens, d_tokens, drug_idx)

    # Adaptive sparsity pruning: drop rows/columns whose peak is below 5 % of
    # the global peak, but always keep everything if fewer than 3 survive.
    thr = attn.max().item() * 0.05
    row_keep = attn.max(dim=1).values > thr
    col_keep = attn.max(dim=0).values > thr
    if row_keep.sum() < 3:
        row_keep[:] = True
    if col_keep.sum() < 3:
        col_keep[:] = True
    rows = row_keep.nonzero(as_tuple=True)[0].tolist()
    cols = col_keep.nonzero(as_tuple=True)[0].tolist()
    view = attn[rows][:, cols]

    # Keep the max_cols strongest columns, restored to sequence order so the
    # x-axis stays positional.
    if len(cols) > max_cols:
        keep = torch.topk(view.sum(0), k=max_cols).indices.sort().values.tolist()
        cols = [cols[i] for i in keep]
        view = view[:, keep]

    x_labels = [f"{c + 1}:{d_tokens[c]}" for c in cols]
    y_labels = [f"{r + 1}:{p_tokens[r]}" for r in rows]
    return table_html + _render_heatmap_html(view, x_labels, y_labels)
| # ───── Flask app ─────────────────────────────────────────────── | |
app = Flask(__name__)

# Uploaded structures are stored here under server-generated names.  The path
# round-trips through a hidden form field, so it is validated against this
# directory before being handed to Foldseek / Biopython.
_UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "fusiondti_uploads")

_SMILES_ERROR_HTML = (
    "<p style='color:red'><strong>Failed to convert SMILES to SELFIES. "
    "Please check the input string.</strong></p>"
)


def _save_upload(storage) -> str:
    """Save an uploaded structure under a unique, server-chosen name; return its path.

    The sanitised, lower-cased extension is kept because the Biopython
    fallback picks its parser from it.
    """
    os.makedirs(_UPLOAD_DIR, exist_ok=True)
    suffix = os.path.splitext(secure_filename(storage.filename))[1].lower()
    fd, path = tempfile.mkstemp(suffix=suffix, dir=_UPLOAD_DIR)
    os.close(fd)
    storage.save(path)
    return path


def _validated_upload_path(raw: str) -> str:
    """Return *raw* resolved if it is an existing file directly inside _UPLOAD_DIR, else ''."""
    if not raw:
        return ""
    resolved = os.path.realpath(raw)
    if os.path.dirname(resolved) != os.path.realpath(_UPLOAD_DIR):
        return ""
    return resolved if os.path.isfile(resolved) else ""


def _structure_to_seq(path: str) -> str:
    """Foldseek structure-aware sequence of the first chain, else a plain AA sequence."""
    try:
        parsed = get_struc_seq(FOLDSEEK_BIN, path, None, plddt_mask=False)
        first_chain = next(iter(parsed))
        _, _, struc_seq = parsed[first_chain]
        return struc_seq
    except Exception:
        # Foldseek missing (FOLDSEEK_BIN is None) or failed: fall back.
        return simple_seq_from_structure(path)


def _looks_like_selfies(text: str) -> bool:
    """True if *text* (ignoring whitespace) is only bracketed tokens, optionally '.'-separated."""
    import re

    compact = "".join(text.split())
    return re.fullmatch(r"(?:\[[^\[\]]+\]|\.)+", compact) is not None


def _normalise_drug_input(raw: str):
    """Return ``(drug_seq, error_html)`` for the user's drug input.

    SELFIES (and empty input) are passed through.  Anything else is treated
    as SMILES and converted.  On failure the raw text is returned so the user
    can edit it, together with an error banner.  ``error_html`` is None on
    success.
    """
    text = raw.strip()
    if not text or _looks_like_selfies(text):
        return text, None
    converted = smiles_to_selfies(text)
    if converted:
        return converted, None
    return text, _SMILES_ERROR_HTML


def _run_inference(protein_seq: str, drug_seq: str, drug_idx):
    """Encode one protein/drug pair, run FusionDTI and return the visualisation HTML."""
    from html import escape

    loader = DataLoader([(protein_seq, drug_seq, 1)], batch_size=1,
                        collate_fn=collate_fn)
    feats = get_case_feature(encoding, loader)
    model = FusionDTI(446, 768, args).to(DEVICE)
    ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
                        "best_model.ckpt")
    notice = ""
    if os.path.isfile(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    else:
        # Without a checkpoint the weights are presumably just the
        # constructor's initialisation, so warn instead of failing silently.
        notice = (
            "<p style='color:#b45309'><strong>Checkpoint not found "
            f"(<code>{escape(ckpt)}</code>); the attention below comes from "
            "an untrained model.</strong></p>"
        )
    return notice + visualize_attention(model, feats, drug_idx)


_INDEX_TEMPLATE = """
<!doctype html>
<html lang="en"><head><meta charset="utf-8"><title>FusionDTI </title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=Poppins:wght@500;600&display=swap" rel="stylesheet">
<style>
:root{--bg:#f3f4f6;--card:#fff;--primary:#6366f1;--primary-dark:#4f46e5;--text:#111827;--border:#e5e7eb;}
*{box-sizing:border-box;margin:0;padding:0}
body{background:var(--bg);color:var(--text);font-family:Inter,system-ui,Arial,sans-serif;line-height:1.5;padding:32px 12px;}
h1{font-family:Poppins,Inter,sans-serif;font-weight:600;font-size:1.7rem;text-align:center;margin-bottom:28px;letter-spacing:-.2px;}
.card{max-width:1000px;margin:0 auto;background:var(--card);border:1px solid var(--border);
border-radius:12px;box-shadow:0 2px 6px rgba(0,0,0,.05);padding:32px 36px;}
label{font-weight:500;margin-bottom:6px;display:block}
textarea,input[type=file]{width:100%;font-size:.9rem;font-family:monospace;padding:10px 12px;
border:1px solid var(--border);border-radius:8px;background:#fff;resize:vertical;}
textarea{min-height:90px}
.btn{appearance:none;border:none;cursor:pointer;padding:12px 22px;border-radius:8px;font-weight:500;
font-family:Inter,sans-serif;transition:all .18s ease;color:#fff;}
.btn-primary{background:var(--primary)}.btn-primary:hover{background:var(--primary-dark)}
.btn-neutral{background:#9ca3af;}.btn-neutral:hover{background:#6b7280}
.grid{display:grid;gap:22px}.grid-2{grid-template-columns:1fr 1fr}
.vis-box{margin-top:28px;border:1px solid var(--border);border-radius:10px;overflow:auto;max-height:72vh;}
pre{white-space:pre-wrap;word-break:break-all;font-family:monospace;margin-top:8px}
/* ── tidy table for Top-20 list ─────────────────────────────── */
table.tg{border-collapse:collapse;margin-top:4px;font-size:0.83rem}
table.tg th,table.tg td{border:1px solid var(--border);padding:6px 8px;text-align:left}
table.tg th{background:var(--bg);font-weight:600}
</style>
</head>
<body>
<h1> Token-level Visualiser for Drug-Target Interaction</h1>
<!-- ───────────── Project Links (larger + spaced) ───────────── -->
<div style="margin-top:24px; text-align:center;">
<a href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank"
style="display:inline-block;margin:8px 18px;padding:10px 20px;
background:linear-gradient(to right,#10b981,#059669);color:white;
font-weight:600;border-radius:8px;font-size:0.9rem;
font-family:Inter,sans-serif;text-decoration:none;
box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
🌐 Project Page
</a>
<a href="https://arxiv.org/abs/2406.01651" target="_blank"
style="display:inline-block;margin:8px 18px;padding:10px 20px;
background:linear-gradient(to right,#ef4444,#dc2626);color:white;
font-weight:600;border-radius:8px;font-size:0.9rem;
font-family:Inter,sans-serif;text-decoration:none;
box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
📄 ArXiv: 2406.01651
</a>
<a href="https://github.com/ZhaohanM/FusionDTI" target="_blank"
style="display:inline-block;margin:8px 18px;padding:10px 20px;
background:linear-gradient(to right,#3b82f6,#2563eb);color:white;
font-weight:600;border-radius:8px;font-size:0.9rem;
font-family:Inter,sans-serif;text-decoration:none;
box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
💻 GitHub Repo
</a>
</div>
<!-- ───────────── Guidelines for Use ───────────── -->
<div class="card" style="margin-bottom:24px">
<h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for Use</h2>
<ul style="margin-left:18px;line-height:1.55;list-style:decimal;">
<li><strong>Convert protein structure into a structure-aware sequence:</strong>
Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
sequence will be generated using
<a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
based on 3D structures from
<a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a> or the
<a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
<li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
you must first visit the
<a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a>
to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
<li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
You can enter a SELFIES string directly, or paste a SMILES string.
SMILES will be automatically converted to SELFIES using
<a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
If conversion fails, a red error message will be displayed.</li>
<li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
to highlight the Top-20 interacting protein residues.</li>
<li>After inference, you can use the
“Download PDF” link to export a high-resolution vector version.</li>
</ul>
</div>
<div class="card">
<form method="POST" enctype="multipart/form-data" class="grid">
<div><label>Protein Structure (.pdb / .cif)</label>
<input type="file" name="structure_file">
<input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"></div>
<div><label>Protein Sequence</label>
<textarea name="protein_sequence" placeholder="Confirm / paste sequence…">{{ protein_seq }}</textarea></div>
<div><label>Drug Sequence (SELFIES/SMILES)</label>
<textarea name="drug_sequence" placeholder="[C][C][O]/cco …">{{ drug_seq }}</textarea></div>
<label>Drug atom/substructure index (1-based) – show Top-20 related protein residues</label>
<input type="number" name="drug_idx" min="1" style="width:120px" value="{{ drug_idx_value }}">
<div class="grid grid-2">
<button class="btn btn-primary" type="submit" name="confirm_structure">Confirm Structure</button>
<button class="btn btn-primary" type="submit" name="Inference">Inference</button>
</div>
<button class="btn btn-neutral" style="width:100%" type="submit" name="clear">Clear</button>
</form>
{% if structure_seq %}
<div style="margin-top:18px"><strong>Structure-aware sequence:</strong><pre>{{ structure_seq }}</pre></div>
{% endif %}
{% if result_html %}
<div class="vis-box" style="margin-top:26px">{{ result_html|safe }}</div>
{% endif %}
</div></body></html>
"""


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the visualiser page (GET) and handle its three form actions (POST).

    The submit button's ``name`` selects the action:

    * ``clear``: reset every field.
    * ``confirm_structure``: derive a sequence from the uploaded structure
      and normalise the drug input (SMILES -> SELFIES).
    * ``Inference``: run FusionDTI and render the attention heat-map.

    ``result_html`` is rendered with ``|safe``, so everything put into it
    that derives from user input is HTML-escaped first.
    """
    from html import escape

    protein_seq = drug_seq = structure_seq = ""
    result_html = None
    tmp_structure_path = ""
    drug_idx = None

    if request.method == "POST":
        drug_idx_raw = request.form.get("drug_idx", "").strip()
        drug_idx = int(drug_idx_raw) - 1 if drug_idx_raw.isdigit() else None

        upload = request.files.get("structure_file")
        if upload and upload.filename:
            tmp_structure_path = _save_upload(upload)
        else:
            # The hidden field comes back from the client: never trust it.
            tmp_structure_path = _validated_upload_path(
                request.form.get("tmp_structure_path", ""))

        if "clear" in request.form:
            tmp_structure_path = ""
            drug_idx = None
        elif "confirm_structure" in request.form:
            drug_seq, drug_error = _normalise_drug_input(request.form.get("drug_sequence", ""))
            errors = [drug_error] if drug_error else []
            if tmp_structure_path:
                try:
                    structure_seq = _structure_to_seq(tmp_structure_path)
                except Exception as err:
                    app.logger.exception("Could not parse structure %s", tmp_structure_path)
                    errors.insert(0, "<p style='color:red'><strong>Could not read the "
                                     f"structure file: {escape(str(err))}</strong></p>")
                protein_seq = structure_seq
            else:
                protein_seq = request.form.get("protein_sequence", "")
                errors.insert(0, "<p style='color:red'><strong>Please upload a .pdb or "
                                 ".cif structure file first.</strong></p>")
            result_html = "".join(errors) or None
        elif "Inference" in request.form:
            protein_seq = request.form.get("protein_sequence", "").strip()
            drug_seq, drug_error = _normalise_drug_input(request.form.get("drug_sequence", ""))
            if drug_error:
                result_html = drug_error
            elif not (protein_seq and drug_seq):
                result_html = ("<p style='color:red'><strong>Please provide both a protein "
                               "sequence and a drug sequence.</strong></p>")
            else:
                try:
                    result_html = _run_inference(protein_seq, drug_seq, drug_idx)
                except Exception as err:
                    # Request boundary: log and show the error instead of a bare 500.
                    app.logger.exception("Inference failed")
                    result_html = ("<p style='color:red'><strong>Inference failed: "
                                   f"{escape(str(err))}</strong></p>")

    return render_template_string(
        _INDEX_TEMPLATE,
        protein_seq=protein_seq, drug_seq=drug_seq, structure_seq=structure_seq,
        result_html=result_html, tmp_structure_path=tmp_structure_path,
        drug_idx_value="" if drug_idx is None else drug_idx + 1,
    )
| # ───── run ───────────────────────────────────────────────────── | |
if __name__ == "__main__":
    # The Werkzeug debugger lets anyone who can reach the page execute
    # arbitrary Python, so it must not be on by default while listening on
    # all interfaces.  Opt in explicitly with FLASK_DEBUG=1 for local work.
    _debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=_debug, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))