File size: 25,716 Bytes
b7ce511
 
 
 
3df029f
 
b7ce511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d9281c
b7ce511
 
 
 
 
 
 
 
44cf989
 
b7ce511
 
5676c75
 
 
 
 
 
44cf989
b7ce511
 
 
 
 
91a8b2f
5676c75
fe43b7a
b7ce511
 
 
 
3df029f
b7ce511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3df029f
b7ce511
fe43b7a
91a8b2f
5676c75
 
 
b7ce511
5676c75
 
 
 
2881d33
5676c75
fe43b7a
91a8b2f
5676c75
fe43b7a
b7ce511
91a8b2f
5676c75
b7ce511
5676c75
b7ce511
91a8b2f
5676c75
91a8b2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe43b7a
91a8b2f
 
 
 
 
 
5676c75
 
 
 
 
 
 
 
 
 
 
 
 
b7ce511
3df029f
 
 
 
 
 
 
 
 
 
 
 
 
5676c75
 
3df029f
5676c75
 
 
 
 
 
 
 
 
 
3df029f
5676c75
 
 
 
3df029f
5676c75
 
 
 
 
 
3df029f
 
 
 
 
 
5676c75
e4d0f40
3df029f
5676c75
3df029f
 
 
5676c75
3df029f
 
 
 
5676c75
3df029f
 
 
 
 
5676c75
 
 
3df029f
 
 
 
5676c75
3df029f
5676c75
 
 
3df029f
 
5676c75
 
3df029f
5676c75
3df029f
5676c75
3df029f
 
5676c75
 
 
3df029f
5676c75
 
 
 
3df029f
5676c75
3df029f
5676c75
 
 
b7ce511
5676c75
 
 
 
 
 
 
3df029f
b7ce511
3df029f
 
 
b7ce511
 
3df029f
b7ce511
 
5676c75
 
3df029f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5676c75
b7ce511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3df029f
b7ce511
3df029f
b7ce511
 
 
3df029f
 
 
 
 
 
 
 
 
b7ce511
3df029f
 
 
 
 
 
 
 
 
b7ce511
3df029f
 
 
 
 
 
 
 
 
 
 
 
 
 
b7ce511
3df029f
 
 
 
 
 
 
 
 
 
fad71d9
3df029f
 
 
 
 
 
fad71d9
3df029f
 
 
 
 
 
 
 
 
 
 
 
fad71d9
3df029f
 
 
 
 
 
 
 
b7ce511
 
 
3df029f
0d9281c
3df029f
 
b7ce511
3df029f
0d9281c
3df029f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d9281c
b7ce511
 
 
3df029f
 
 
b7ce511
 
 
 
 
 
 
 
3df029f
b7ce511
 
3df029f
 
 
 
b7ce511
 
3df029f
e4d0f40
3df029f
b7ce511
 
3df029f
 
 
b7ce511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3df029f
 
b7ce511
3df029f
b7ce511
3df029f
b7ce511
 
 
 
 
 
 
 
 
 
 
 
 
 
3df029f
b7ce511
 
 
 
5676c75
fad71d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
# ─── monkey-patch gradio_client so bool schemas don't crash json_schema_to_python_type ───
import gradio_client.utils as _gc_utils

# Keep references to the originals so the patched versions can delegate.
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type


def _patched_get_type(schema):
    """Like ``gradio_client.utils.get_type`` but tolerant of boolean JSON schemas."""
    # A bare ``True``/``False`` schema is valid JSON Schema; coerce to {} so the
    # original implementation does not crash on it.
    return _orig_get_type({} if isinstance(schema, bool) else schema)


def _patched_json_schema_to_python_type(schema, defs=None):
    """Like the private ``_json_schema_to_python_type`` but tolerant of bool schemas."""
    return _orig_json2py({} if isinstance(schema, bool) else schema, defs)


_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type

# ─── now it’s safe to import Gradio and build your interface ───────────────────────────
import gradio as gr
from gradio.themes import Soft

import os
import sys
import argparse
import tempfile
import shutil
import base64
import io

import torch
import selfies
from rdkit import Chem
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
from typing import Optional

from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
from torch.utils.data import DataLoader
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Data import IUPACData

from utils.drug_tokenizer import DrugTokenizer
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
from utils.foldseek_util import get_struc_seq

# ───── Helpers ─────────────────────────────────────────────────

# Three-letter → one-letter residue lookup, plus common non-standard residues
# (selenomethionine, selenocysteine, pyrrolysine).
three2one = {code.upper(): aa for code, aa in IUPACData.protein_letters_3to1.items()}
three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})

def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter sequence of the longest chain in a .pdb/.cif file.

    Residue codes missing from ``three2one`` are emitted as "X"; an empty
    string is returned when the structure contains no chains.
    """
    if path.endswith(".cif"):
        parser = MMCIFParser(QUIET=True)
    else:
        parser = PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)
    chains = list(structure.get_chains())
    if not chains:
        return ""
    # Pick the chain with the most residues as the representative one.
    longest = max(chains, key=lambda c: len(list(c.get_residues())))
    letters = [three2one.get(res.get_resname().upper(), "X") for res in longest]
    return "".join(letters)

def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES, returning None when it is invalid."""
    try:
        # Validate with RDKit first; an unparsable SMILES yields None.
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        # selfies.encoder can raise on exotic inputs; treat those as invalid too.
        return None

def parse_config():
    """Parse command-line flags into the application's configuration namespace.

    Uses ``parse_known_args`` (instead of ``parse_args``) so that unrelated
    flags injected by hosting environments — notebook kernels, test runners,
    app launchers — are ignored rather than crashing startup. Unknown
    arguments are silently discarded.

    Returns:
        argparse.Namespace with encoder paths, fusion variant, device, etc.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2",
                   help="protein encoder checkpoint (Hugging Face hub id)")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer",
                   help="drug encoder checkpoint (Hugging Face hub id)")
    p.add_argument("--agg_mode", type=str, default="mean_all_tok",
                   help="token aggregation mode for the encoders")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--fusion", default="CAN", help="fusion module variant")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="Human")
    # Tolerate extra argv entries instead of calling sys.exit() on them.
    args, _unknown = p.parse_known_args()
    return args

# Parse configuration once at import time; DEVICE is shared by every callback.
args = parse_config()
DEVICE = args.device

# ───── Load models & tokenizers ─────────────────────────────────
# Heavy one-off global initialisation: both pretrained encoders are fetched
# from the Hugging Face hub (network access required on first run), then the
# combined pre-encoder is moved to the configured device.
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model     = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
drug_tokenizer = DrugTokenizer()
drug_model     = AutoModel.from_pretrained(args.drug_encoder_path)
encoding       = Pre_encoded(prot_model, drug_model, args).to(DEVICE)

def collate_fn(batch):
    """Tokenise a batch of (protein, drug, score) triples into padded tensors.

    Both sequences are padded/truncated to 512 tokens. Returns
    (prot_ids, prot_mask, drug_ids, drug_mask, scores) where the masks are
    boolean attention masks.
    """
    prots, drugs, labels = zip(*batch)

    # Identical tokenisation settings for both encoders.
    tok_kwargs = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    enc_prot = prot_tokenizer.batch_encode_plus(list(prots), **tok_kwargs)
    enc_drug = drug_tokenizer.batch_encode_plus(list(drugs), **tok_kwargs)

    return (
        enc_prot["input_ids"],
        enc_prot["attention_mask"].bool(),
        enc_drug["input_ids"],
        enc_drug["attention_mask"].bool(),
        torch.tensor(list(labels)),
    )


def get_case_feature(model, loader):
    """Encode the first batch from *loader* and return its features on CPU.

    Returns a one-element list containing
    (prot_emb, drug_emb, prot_ids, drug_ids, prot_mask, drug_mask, None),
    or None when the loader yields nothing (implicit fall-through).
    """
    model.eval()
    with torch.no_grad():
        for batch in loader:
            # First four entries are ids/masks; the score is not needed here.
            p_ids, p_mask, d_ids, d_mask = (t.to(DEVICE) for t in batch[:4])
            p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask)
            case = (p_emb.cpu(), d_emb.cpu(),
                    p_ids.cpu(), d_ids.cpu(),
                    p_mask.cpu(), d_mask.cpu(), None)
            return [case]

# ─────────────── visualisation ───────────────────────────────────────────
def _safe_is_special(tokenizer, tok: str) -> bool:
    # Some tokenisers expose different special token sets; fall back conservatively.
    special_sets = []
    if hasattr(tokenizer, "all_special_tokens"):
        special_sets.append(set(tokenizer.all_special_tokens))
    if hasattr(tokenizer, "special_tokens_map"):
        special_sets.extend(set(v) if isinstance(v, list) else {v}
                            for v in tokenizer.special_tokens_map.values())
    for s in special_sets:
        if tok in s:
            return True
    return False

def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """
    Render a Protein → Drug cross-attention heat-map and optional Top-30 residue table.

    Args:
        model: fusion model whose forward returns ``(score, attention)`` with
            the attention shaped (B, n_prot, n_drug) — assumed from the
            squeeze/indexing below.
        feats: one-element feature list produced by ``get_case_feature``.
        drug_idx: optional 0-based drug token index; when valid, a Top-30
            protein-residue table for that drug token is prepended.

    Returns:
        An HTML fragment embedding the heat-map PNG (with a PDF download
        link) and, when requested, the Top-30 residue table.

    Note: the original source contained mojibake ("β†’") in the title and
    table-header strings; those are fixed here to a proper "→" arrow.
    """
    model.eval()
    with torch.no_grad():
        # ── unpack single-case tensors ───────────────────────────────────────────
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        # ── forward pass: Protein → Drug attention (B, n_p, n_d) ───────────────
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        attn = att_pd.squeeze(0).cpu()  # (n_p, n_d)

        # ── decode tokens (skip special symbols) ────────────────────────────────
        def clean_ids(ids, tokenizer):
            toks = tokenizer.convert_ids_to_tokens(ids.tolist())
            return [t for t in toks if not _safe_is_special(tokenizer, t)]

        p_tokens_full  = clean_ids(p_ids[0],  prot_tokenizer)
        p_indices_full = list(range(1, len(p_tokens_full)  + 1))  # 1-based display positions
        d_tokens_full  = clean_ids(d_ids[0],  drug_tokenizer)
        d_indices_full = list(range(1, len(d_tokens_full)  + 1))

        # ── safety cut-off to match attn mat size ──────────────────────────────
        p_tokens = p_tokens_full[: attn.size(0)]
        p_indices = p_indices_full[: attn.size(0)]
        d_tokens = d_tokens_full[: attn.size(1)]
        d_indices = d_indices_full[: attn.size(1)]
        attn = attn[: len(p_tokens), : len(d_tokens)]

        # Keep the unpruned matrix for the drug_idx bounds check below.
        orig_attn = attn.clone()

        # ── adaptive sparsity pruning: drop rows/cols whose peak < 5% of max ───
        thr = attn.max().item() * 0.05 if attn.numel() > 0 else 0.0
        row_keep = (attn.max(dim=1).values > thr) if attn.size(0) else torch.tensor([], dtype=torch.bool)
        col_keep = (attn.max(dim=0).values > thr) if attn.size(1) else torch.tensor([], dtype=torch.bool)

        # Never prune below 3 rows/cols — keep the plot legible.
        if row_keep.sum().item() < 3 and attn.size(0) > 0:
            row_keep = torch.ones(attn.size(0), dtype=torch.bool)
        if col_keep.sum().item() < 3 and attn.size(1) > 0:
            col_keep = torch.ones(attn.size(1), dtype=torch.bool)

        attn      = attn[row_keep][:, col_keep]
        p_tokens  = [tok for keep, tok in zip(row_keep.tolist(), p_tokens)       if keep]
        p_indices = [idx for keep, idx in zip(row_keep.tolist(), p_indices)      if keep]
        d_tokens  = [tok for keep, tok in zip(col_keep.tolist(), d_tokens)       if keep]
        d_indices = [idx for keep, idx in zip(col_keep.tolist(), d_indices)      if keep]

        # ── cap column count at 150 for readability ─────────────────────────────
        if attn.size(1) > 150:
            topc = torch.topk(attn.sum(0), k=150).indices
            attn = attn[:, topc]
            d_tokens  = [d_tokens[i]  for i in topc]
            d_indices = [d_indices[i] for i in topc]

        # ── draw heat-map ──────────────────────────────────────────────────────
        x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
        y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]

        # Scale figure with label counts, clamped to sane bounds.
        fig_w = min(22, max(8, len(x_labels) * 0.6))
        fig_h = min(24, max(6, len(y_labels) * 0.8))

        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        im = ax.imshow(attn.numpy(), aspect="auto", cmap=cm.viridis, interpolation="nearest")

        ax.set_title("Protein → Drug Attention", pad=8, fontsize=11)
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8, ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False, pad=27)

        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        ax.tick_params(axis="y", top=True, bottom=False, labeltop=True, labelbottom=False, pad=10)

        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # build PNG / PDF in-memory (PNG for display, PDF for download)
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        buf_png.seek(0)

        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")
        buf_pdf.seek(0)
        plt.close(fig)

        png_b64 = base64.b64encode(buf_png.getvalue()).decode()
        pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()

        html_heat = (
            f"<div class='heatmap-card' style='position: relative; width: 100%;'>"
              f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
                 "style='position:absolute; top:12px; right:12px; "
                 "background: var(--primary); color:#fff; padding:8px 16px; border-radius:8px; "
                 "font-size:.92rem; font-weight:600; text-decoration:none;'>Download PDF</a>"
              f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
                f"<img src='data:image/png;base64,{png_b64}' "
                   "style='display:block; width:100%; height:auto; cursor:zoom-in;'/>"
              "</a>"
            "</div>"
        )

        # ───────────────────── Top-30 table (optional) ─────────────────────
        table_html = ""
        if drug_idx is not None and orig_attn.size(1) > 0 and 0 <= drug_idx < orig_attn.size(1):
            # map original 0-based drug_idx → pruned column
            col_pos = None
            if (drug_idx + 1) in d_indices:
                col_pos = d_indices.index(drug_idx + 1)
            elif 0 <= drug_idx < len(d_tokens):
                # Fallback: treat the index as a position in the pruned list.
                col_pos = drug_idx

            if col_pos is not None:
                col_vec = attn[:, col_pos]
                k = min(30, len(col_vec))
                if k > 0:
                    topk = torch.topk(col_vec, k=k).indices.tolist()

                    # header cells
                    header_cells = (
                        "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Rank</th>"
                        + "".join(
                            f"<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center'>{r+1}</th>"
                            for r in range(len(topk))
                        )
                    )
                    residue_cells = (
                        "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Residue</th>"
                        + "".join(
                            f"<td style='border:1px solid #e5e7eb; padding:6px; text-align:center'>{p_tokens[i]}</td>"
                            for i in topk
                        )
                    )
                    position_cells = (
                        "<th style='border:1px solid #e5e7eb; padding:6px; background:#f8fafc; text-align:center;'>Position</th>"
                        + "".join(
                            f"<td style='border:1px solid #e5e7eb; padding:6px; text-align:center'>{p_indices[i]}</td>"
                            for i in topk
                        )
                    )

                    drug_tok_text = d_tokens[col_pos]
                    orig_idx_disp = d_indices[col_pos]

                    table_html = (
                        f"<div class='card' style='margin-top:18px'>"
                        f"<h4 style='margin:0 0 12px; font-size:1rem;'>"
                        f"Drug atom #{orig_idx_disp} <code>{drug_tok_text}</code> → Top-30 Protein residues"
                        f"</h4>"
                        f"<table style='border-collapse:collapse; margin:0 auto 4px; font-size:.95rem'>"
                        f"<tr>{header_cells}</tr>"
                        f"<tr>{residue_cells}</tr>"
                        f"<tr>{position_cells}</tr>"
                        f"</table>"
                        f"</div>"
                    )

        return table_html + html_heat

# ───── Gradio Callbacks ─────────────────────────────────────────

# Resolve the bundled foldseek binary relative to this file's directory.
ROOT = os.path.abspath(os.path.dirname(__file__))
FOLDSEEK_BIN = os.path.join(ROOT, "bin", "foldseek")

def extract_sequence_cb(structure_file):
    """Gradio callback: derive a structure-aware sequence from an uploaded file.

    Runs Foldseek on the uploaded .pdb/.cif file and returns the
    structure-aware sequence of the first parsed chain, or "" when no file
    was provided, the path is missing, or Foldseek yields no chains.
    """
    if structure_file is None or not os.path.exists(structure_file.name):
        return ""
    parsed = get_struc_seq(FOLDSEEK_BIN, structure_file.name, None, plddt_mask=False)
    if not parsed:
        # Foldseek produced no chains — avoid StopIteration on next(iter(...)).
        return ""
    first_chain = next(iter(parsed))
    # Each entry is (aa_seq, struct_seq, combined_seq) — assumed from the
    # 3-tuple unpack; TODO confirm against utils.foldseek_util.get_struc_seq.
    _, _, struct_seq = parsed[first_chain]
    return struct_seq

def inference_cb(prot_seq, drug_seq, atom_idx):
    """Gradio callback: run the DTI model and return the attention HTML.

    Normalises both inputs (None-safe, whitespace-stripped) before
    validation — the original accepted a whitespace-only protein sequence
    and crashed on a None drug input. SMILES input (anything not starting
    with "[") is converted to SELFIES first. ``atom_idx`` is 1-based in the
    UI and converted to 0-based here; a missing checkpoint silently falls
    back to randomly-initialised fusion weights, as before.
    """
    prot_seq = (prot_seq or "").strip()
    drug_seq = (drug_seq or "").strip()
    if not prot_seq:
        return "<p style='color:red'>Please extract or enter a protein sequence first.</p>"
    if not drug_seq:
        return "<p style='color:red'>Please enter a drug sequence.</p>"
    if not drug_seq.startswith("["):
        conv = smiles_to_selfies(drug_seq)
        if conv is None:
            return "<p style='color:red'>SMILES→SELFIES conversion failed.</p>"
        drug_seq = conv
    loader = DataLoader([(prot_seq, drug_seq, 1)], batch_size=1, collate_fn=collate_fn)
    feats = get_case_feature(encoding, loader)
    model = FusionDTI(446, 768, args).to(DEVICE)
    ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}", "best_model.ckpt")
    if os.path.isfile(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    return visualize_attention(model, feats, int(atom_idx) - 1 if atom_idx else None)

def clear_cb():
    """Reset the protein, drug, atom-index, result-HTML and file widgets."""
    empty = ""
    # Order must match the `outputs` list wired to the Clear button.
    return empty, empty, None, empty, None

# ───── Theme & CSS ─────────────────────────────────────────────

# Custom stylesheet injected into gr.Blocks: design tokens in :root, card
# layout, gradient link buttons, form controls, and a small-screen tweak.
css = """
:root {
  --bg:#f7f7fb;
  --card:#ffffff;
  --border:#e6e7eb;
  --primary:#4f46e5;
  --primary-dark:#4338ca;
  --text:#0f172a;
  --muted:#6b7280;
  --radius:14px;
  --shadow:0 10px 30px rgba(15,23,42,.06);
}
*{box-sizing:border-box}
html,body{background:var(--bg)!important;color:var(--text)!important;font-family:Inter,system-ui,Arial,sans-serif}
h1{font-weight:700;font-size:32px;margin:22px 0 10px;text-align:center;letter-spacing:.2px}
p,li,button,.gr-button,label,.gr-text{font-size:14px}

/* Cards */
.card{
  background:var(--card); border:1px solid var(--border); border-radius:var(--radius);
  box-shadow:var(--shadow); padding:24px; max-width:1100px; margin:0 auto 28px;
}

/* Project links */
.link-btn{
  display:inline-flex;               /* icon + text centred vertically */
  align-items:center;
  justify-content:center;
  margin:0 8px;
  padding:10px 18px;
  border-radius:10px;
  color:#fff;
  font-weight:650;
  text-decoration:none;
  box-shadow:0 6px 18px rgba(79,70,229,.18);
  transition:transform .12s ease,filter .12s ease;
}
.link-btn:hover{transform:translateY(-1px);filter:brightness(1.03)}
.link-btn svg{margin-right:6px;vertical-align:middle}
.link-btn.project{background:linear-gradient(135deg,#10b981,#059669)}
.link-btn.arxiv  {background:linear-gradient(135deg,#ef4444,#dc2626)}
.link-btn.github {background:linear-gradient(135deg,#3b82f6,#2563eb)}

/* Labels & inputs */
#input-card label{font-weight:650!important;color:var(--text)!important}
textarea, input, .gr-textbox, .gr-number{
  border-radius:12px!important; border:1px solid var(--border)!important;
}
#input-card .gr-row, #input-card .gr-cols{gap:16px}

/* Buttons */
.gr-button{min-height:42px!important; padding:0 18px!important; border-radius:12px!important; font-weight:700!important}
.gr-button.primary, .gr-button-primary{
  background:var(--primary)!important; border-color:var(--primary)!important; color:#fff!important
}
.gr-button.primary:hover, .gr-button-primary:hover{background:var(--primary-dark)!important;border-color:var(--primary-dark)!important}

/* Action buttons row */
#action-buttons{gap:12px}
#extract-btn, #inference-btn{flex:1 1 260px!important; min-width:180px!important}
#clear-btn{width:100%!important}

/* Output */
#output-card{padding-top:0}
#result-html{padding:0; margin:0}
#result-html .heatmap-card{
  background:var(--card); border:1px solid var(--border); border-radius:12px; padding:12px; box-shadow:var(--shadow)
}

/* Guidance */
#guidelines-card h2{font-size:18px;margin-bottom:14px;text-align:center}
#guidelines-card ul{margin-left:18px;line-height:1.6}

/* Small screens */
@media (max-width: 900px){
  .card{margin:0 12px 24px}
}
"""

# ───── Gradio Interface Definition ───────────────────────────────
# Build the Gradio UI. Fix applied: the guidelines bullet contained mojibake
# ("β€œDownload PDF”") which is restored to proper curly quotes.
with gr.Blocks(theme=Soft(primary_hue="indigo", neutral_hue="slate"), css=css) as demo:
    # ───────────── Title ─────────────
    gr.Markdown("<h1 style='text-align: center;'>Token-level Visualiser for Drug-Target Interaction</h1>")

    # ───────────── Project Links (SVG icons) ─────────────
    gr.HTML("""
        <div style="text-align:center;margin-bottom:32px;">
          <a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank" rel="noopener noreferrer" aria-label="Project Page">
            <!-- globe icon -->
            <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
              <path d="M12 2a10 10 0 1 0 10 10A10.012 10.012 0 0 0 12 2Zm7.93 9h-3.18a15.84 15.84 0 0 0-1.19-5.02A8.02 8.02 0 0 1 19.93 11ZM12 4c.86 0 2.25 1.86 3.01 6H8.99C9.75 5.86 11.14 4 12 4ZM4.07 13h3.18c.2 1.79.66 3.47 1.19 5.02A8.02 8.02 0 0 1 4.07 13Zm3.18-2H4.07A8.02 8.02 0 0 1 8.44 5.98 15.84 15.84 0 0 0 7.25 11Zm1.37 2h6.76c-.76 4.14-2.15 6-3.01 6s-2.25-1.86-3.01-6Zm9.05 0h3.18a8.02 8.02 0 0 1-4.37 5.02 15.84 15.84 0 0 0 1.19-5.02Z"/>
            </svg>
            Project Page
          </a>
          <a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank" rel="noopener noreferrer" aria-label="ArXiv: 2406.01651">
            <!-- arXiv-like paper icon -->
            <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
              <path d="M6 2h9l5 5v13a2 2 0 0 1-2 2H6a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2Zm8 1.5V8h4.5L14 3.5ZM7 12h10v2H7v-2Zm0 4h10v2H7v-2Zm0-8h6v2H7V8Z"/>
            </svg>
            ArXiv: 2406.01651
          </a>
          <a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank" rel="noopener noreferrer" aria-label="GitHub Repo">
            <!-- GitHub mark -->
            <svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="currentColor" aria-hidden="true">
              <path d="M12 .5A12 12 0 0 0 0 12.76c0 5.4 3.44 9.98 8.2 11.6.6.12.82-.28.82-.6v-2.3c-3.34.74-4.04-1.44-4.04-1.44-.54-1.38-1.32-1.74-1.32-1.74-1.08-.76.08-.74.08-.74 1.2.08 1.84 1.26 1.84 1.26 1.06 1.86 2.78 1.32 3.46 1.02.1-.8.42-1.32.76-1.62-2.66-.32-5.46-1.36-5.46-6.02 0-1.34.46-2.44 1.22-3.3-.12-.32-.54-1.64.12-3.42 0 0 1-.34 3.32 1.26.96-.28 1.98-.42 3-.42s2.04.14 3 .42c2.32-1.6 3.32-1.26 3.32-1.26.66 1.78.24 3.1.12 3.42.76.86 1.22 1.96 1.22 3.3 0 4.68-2.8 5.68-5.48 6 .44.38.84 1.12.84 2.28v3.38c0 .32.22.74.84.6A12.02 12.02 0 0 0 24 12.76 12 12 0 0 0 12 .5Z"/>
            </svg>
            GitHub Repo
          </a>
        </div>
        """)

    # ───────────── Guidelines Card ─────────────
    gr.HTML(
        """
        <div class="card" id="guidelines-card" style="margin-bottom:24px">
          <h2>Guidelines for Users</h2>
          <ul style="list-style:decimal;">
            <li><strong>Convert protein structure into a structure-aware sequence:</strong>
                Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
                sequence will be generated using
                <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
                based on 3D structures from
                <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
                <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
            <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
                please first visit the
                <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
                or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
                to download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
            <li><strong>Drug input supports both SELFIES and SMILES:</strong>
                Enter a SELFIES string directly, or paste a SMILES string. SMILES will
                be converted to SELFIES using the
                <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
                If conversion fails, a red error message will be displayed.</li>
            <li>Optionally enter a <strong>1-based</strong> drug atom/substructure index
                to highlight the Top-30 interacting protein residues.</li>
            <li>After inference, use “Download PDF” to export a high-resolution vector figure.</li>
          </ul>
        </div>
        """
    )

    # ───────────── Input Card ─────────────
    with gr.Column(elem_id="input-card", elem_classes="card"):
        protein_seq = gr.Textbox(
            label="Protein Structure-aware Sequence",
            lines=3,
            elem_id="protein-seq"
        )
        drug_seq = gr.Textbox(
            label="Drug Sequence (SELFIES/SMILES)",
            lines=3,
            elem_id="drug-seq"
        )
        structure_file = gr.File(
            label="Upload Protein Structure (.pdb/.cif)",
            file_types=[".pdb", ".cif"],
            elem_id="structure-file"
        )
        drug_idx = gr.Number(
            label="Drug atom/substructure index (1-based)",
            value=None,
            precision=0,
            elem_id="drug-idx"
        )

    # ───────────── Action Buttons ─────────────
    with gr.Row(elem_id="action-buttons", equal_height=True):
        btn_extract = gr.Button("Extract sequence", variant="primary", elem_id="extract-btn")
        btn_infer   = gr.Button("Inference",        variant="primary", elem_id="inference-btn")
    with gr.Row():
        clear_btn   = gr.Button("Clear", variant="secondary", elem_id="clear-btn")

    # ───────────── Output Visualisation ─────────────
    output_html  = gr.HTML(elem_id="result-html")

    # ───────────── Event Wiring ─────────────
    btn_extract.click(
        fn=extract_sequence_cb,
        inputs=[structure_file],
        outputs=[protein_seq]
    )
    btn_infer.click(
        fn=inference_cb,
        inputs=[protein_seq, drug_seq, drug_idx],
        outputs=[output_html]
    )
    # clear_cb's return order must match this outputs list.
    clear_btn.click(
        fn=clear_cb,
        inputs=[],
        outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file]
    )

if __name__ == "__main__":
    # Bind on all interfaces for containerised deployment; share=True also
    # publishes a temporary public Gradio link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)