github-actions[bot] committed on
Commit
06c3276
·
1 Parent(s): 2bc66d4

Auto-sync from demo at Mon Dec 1 10:51:45 UTC 2025

Browse files
graphgen/configs/search_dna_config.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pipeline:
2
+ - name: read_step
3
+ op_key: read
4
+ params:
5
+ input_file: resources/input_examples/search_dna_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
6
+
7
+ - name: search_step
8
+ op_key: search
9
+ deps: [read_step] # search_step depends on read_step
10
+ params:
11
+ data_sources: [ncbi] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
12
+ ncbi_params:
13
+ email: test@example.com # NCBI requires an email address
14
+ tool: GraphGen # tool name for NCBI API
15
+ use_local_blast: true # whether to use local blast for DNA search
16
+ local_blast_db: /your_path/refseq_241 # path to local BLAST database (without .nhr extension)
17
+
graphgen/configs/{search_config.yaml → search_protein_config.yaml} RENAMED
@@ -2,7 +2,7 @@ pipeline:
2
  - name: read_step
3
  op_key: read
4
  params:
5
- input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
6
 
7
  - name: search_step
8
  op_key: search
@@ -11,4 +11,5 @@ pipeline:
11
  data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
12
  uniprot_params:
13
  use_local_blast: true # whether to use local blast for uniprot search
14
- local_blast_db: /your_path/uniprot_sprot
 
 
2
  - name: read_step
3
  op_key: read
4
  params:
5
+ input_file: resources/input_examples/search_protein_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
6
 
7
  - name: search_step
8
  op_key: search
 
11
  data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
12
  uniprot_params:
13
  use_local_blast: true # whether to use local blast for uniprot search
14
+ local_blast_db: /your_path/2024_01/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot
15
+ # options: uniprot_sprot (recommended, high quality), uniprot_trembl, or uniprot_${RELEASE} (merged database)
graphgen/configs/search_rna_config.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pipeline:
2
+ - name: read_step
3
+ op_key: read
4
+ params:
5
+ input_file: resources/input_examples/search_rna_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
6
+
7
+ - name: search_step
8
+ op_key: search
9
+ deps: [read_step] # search_step depends on read_step
10
+ params:
11
+ data_sources: [rnacentral] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
12
+ rnacentral_params:
13
+ use_local_blast: true # whether to use local blast for RNA search
14
+ local_blast_db: /your_path/refseq_rna_241 # format: /path/to/refseq_rna_${RELEASE}
15
+ # can also use DNA database with RNA sequences (if already built)
16
+
graphgen/graphgen.py CHANGED
@@ -45,7 +45,7 @@ class GraphGen:
45
 
46
  # llm
47
  self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
48
- model_name=os.getenv("TOKENIZER_MODEL")
49
  )
50
 
51
  self.synthesizer_llm_client: BaseLLMWrapper = (
 
45
 
46
  # llm
47
  self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
48
+ model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base")
49
  )
50
 
51
  self.synthesizer_llm_client: BaseLLMWrapper = (
graphgen/models/__init__.py CHANGED
@@ -26,6 +26,8 @@ from .reader import (
26
  RDFReader,
27
  TXTReader,
28
  )
 
 
29
  from .searcher.db.uniprot_searcher import UniProtSearch
30
  from .searcher.kg.wiki_search import WikiSearch
31
  from .searcher.web.bing_search import BingSearch
 
26
  RDFReader,
27
  TXTReader,
28
  )
29
+ from .searcher.db.ncbi_searcher import NCBISearch
30
+ from .searcher.db.rnacentral_searcher import RNACentralSearch
31
  from .searcher.db.uniprot_searcher import UniProtSearch
32
  from .searcher.kg.wiki_search import WikiSearch
33
  from .searcher.web.bing_search import BingSearch
graphgen/models/searcher/db/ncbi_searcher.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import re
4
+ import subprocess
5
+ import tempfile
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ from functools import lru_cache
8
+ from http.client import IncompleteRead
9
+ from typing import Dict, Optional
10
+
11
+ from Bio import Entrez, SeqIO
12
+ from Bio.Blast import NCBIWWW, NCBIXML
13
+ from requests.exceptions import RequestException
14
+ from tenacity import (
15
+ retry,
16
+ retry_if_exception_type,
17
+ stop_after_attempt,
18
+ wait_exponential,
19
+ )
20
+
21
+ from graphgen.bases import BaseSearcher
22
+ from graphgen.utils import logger
23
+
24
+
25
+ @lru_cache(maxsize=None)
26
+ def _get_pool():
27
+ return ThreadPoolExecutor(max_workers=10)
28
+
29
+
30
+ # ensure only one NCBI request at a time
31
+ _ncbi_lock = asyncio.Lock()
32
+
33
+
34
+ class NCBISearch(BaseSearcher):
35
+ """
36
+ NCBI Search client to search DNA/GenBank/Entrez databases.
37
+ 1) Get the gene/DNA by accession number or gene ID.
38
+ 2) Search with keywords or gene names (fuzzy search).
39
+ 3) Search with FASTA sequence (BLAST search for DNA sequences).
40
+
41
+ API Documentation: https://www.ncbi.nlm.nih.gov/home/develop/api/
42
+ Note: NCBI has rate limits (max 3 requests per second), delays are required between requests.
43
+ """
44
+
45
def __init__(
    self,
    use_local_blast: bool = False,
    local_blast_db: str = "nt_db",
    email: str = "email@example.com",
    api_key: str = "",
    tool: str = "GraphGen",
):
    """
    Initialize the NCBI Search client.

    Configures the module-level ``Bio.Entrez`` settings (timeout, email, tool,
    optional API key, retry counts) and validates the local BLAST database path.

    Args:
        use_local_blast (bool): Whether to use local BLAST database.
        local_blast_db (str): Path prefix of the local BLAST database (no file extension).
        email (str): Email address for NCBI API requests.
        api_key (str): API key for NCBI API requests, see https://account.ncbi.nlm.nih.gov/settings/.
        tool (str): Tool name for NCBI API requests.
    """
    super().__init__()
    Entrez.timeout = 60  # 60 seconds timeout
    Entrez.email = email
    Entrez.tool = tool
    if api_key:
        Entrez.api_key = api_key
    Entrez.max_tries = 10 if api_key else 3
    Entrez.sleep_between_tries = 5
    self.use_local_blast = use_local_blast
    self.local_blast_db = local_blast_db
    if self.use_local_blast:
        # A single-volume nucleotide database ships "<db>.nhr"; large databases are split
        # into volumes ("<db>.00.nhr", ...) tied together by a "<db>.nal" alias file.
        # Accept any of these so multi-volume databases are not rejected.
        header_suffixes = (".nhr", ".nal", ".00.nhr")
        db_found = any(
            os.path.isfile(f"{self.local_blast_db}{suffix}") for suffix in header_suffixes
        )
        if not db_found:
            logger.error("Local BLAST database files not found. Please check the path.")
            self.use_local_blast = False
76
+
77
+ @staticmethod
78
+ def _nested_get(data: dict, *keys, default=None):
79
+ """Safely traverse nested dictionaries."""
80
+ for key in keys:
81
+ if not isinstance(data, dict):
82
+ return default
83
+ data = data.get(key, default)
84
+ return data
85
+
86
+ def _gene_record_to_dict(self, gene_record, gene_id: str) -> dict:
87
+ """
88
+ Convert an Entrez gene record to a dictionary.
89
+ All extraction logic is inlined for maximum clarity and performance.
90
+ """
91
+ if not gene_record:
92
+ raise ValueError("Empty gene record")
93
+
94
+ data = gene_record[0]
95
+ locus = (data.get("Entrezgene_locus") or [{}])[0]
96
+
97
+ # Extract common nested paths once
98
+ gene_ref = self._nested_get(data, "Entrezgene_gene", "Gene-ref", default={})
99
+ biosource = self._nested_get(data, "Entrezgene_source", "BioSource", default={})
100
+
101
+ # Process synonyms
102
+ synonyms_raw = gene_ref.get("Gene-ref_syn", [])
103
+ gene_synonyms = []
104
+ if isinstance(synonyms_raw, list):
105
+ for syn in synonyms_raw:
106
+ gene_synonyms.append(syn.get("Gene-ref_syn_E") if isinstance(syn, dict) else str(syn))
107
+ elif synonyms_raw:
108
+ gene_synonyms.append(str(synonyms_raw))
109
+
110
+ # Extract location info
111
+ label = locus.get("Gene-commentary_label", "")
112
+ chromosome_match = re.search(r"Chromosome\s+(\S+)", str(label)) if label else None
113
+
114
+ seq_interval = self._nested_get(
115
+ locus, "Gene-commentary_seqs", 0, "Seq-loc_int", "Seq-interval", default={}
116
+ )
117
+ genomic_location = (
118
+ f"{seq_interval.get('Seq-interval_from')}-{seq_interval.get('Seq-interval_to')}"
119
+ if seq_interval.get('Seq-interval_from') and seq_interval.get('Seq-interval_to')
120
+ else None
121
+ )
122
+
123
+ # Extract representative accession
124
+ representative_accession = next(
125
+ (
126
+ product.get("Gene-commentary_accession")
127
+ for product in locus.get("Gene-commentary_products", [])
128
+ if product.get("Gene-commentary_type") == "3"
129
+ ),
130
+ None,
131
+ )
132
+
133
+ # Extract function
134
+ function = data.get("Entrezgene_summary") or next(
135
+ (
136
+ comment.get("Gene-commentary_comment")
137
+ for comment in data.get("Entrezgene_comments", [])
138
+ if isinstance(comment, dict)
139
+ and "function" in str(comment.get("Gene-commentary_heading", "")).lower()
140
+ ),
141
+ None,
142
+ )
143
+
144
+ return {
145
+ "molecule_type": "DNA",
146
+ "database": "NCBI",
147
+ "id": gene_id,
148
+ "gene_name": gene_ref.get("Gene-ref_locus", "N/A"),
149
+ "gene_description": gene_ref.get("Gene-ref_desc", "N/A"),
150
+ "organism": self._nested_get(
151
+ biosource, "BioSource_org", "Org-ref", "Org-ref_taxname", default="N/A"
152
+ ),
153
+ "url": f"https://www.ncbi.nlm.nih.gov/gene/{gene_id}",
154
+ "gene_synonyms": gene_synonyms or None,
155
+ "gene_type": {
156
+ "1": "protein-coding",
157
+ "2": "pseudo",
158
+ "3": "rRNA",
159
+ "4": "tRNA",
160
+ "5": "snRNA",
161
+ "6": "ncRNA",
162
+ "7": "other",
163
+ }.get(str(data.get("Entrezgene_type")), f"type_{data.get('Entrezgene_type')}"),
164
+ "chromosome": chromosome_match.group(1) if chromosome_match else None,
165
+ "genomic_location": genomic_location,
166
+ "function": function,
167
+ # Fields from accession-based queries
168
+ "title": None,
169
+ "sequence": None,
170
+ "sequence_length": None,
171
+ "gene_id": gene_id,
172
+ "molecule_type_detail": None,
173
+ "_representative_accession": representative_accession,
174
+ }
175
+
176
def get_by_gene_id(self, gene_id: str, preferred_accession: Optional[str] = None) -> Optional[dict]:
    """Get gene information by Gene ID.

    Fetches the record from the Entrez ``gene`` database, converts it with
    ``_gene_record_to_dict`` and, when an accession is available
    (``preferred_accession`` or the record's representative accession), enriches
    the result with sequence data from the ``nuccore`` GenBank entry.

    Args:
        gene_id: NCBI Gene ID.
        preferred_accession: Accession to fetch sequence data from instead of the
            record's representative accession.

    Returns:
        The result dictionary, or ``None`` when the record is empty or any
        non-network error occurs (the error is logged).

    Raises:
        RequestException, IncompleteRead: Re-raised unchanged so callers can retry.
    """

    def _extract_from_genbank(result: dict, accession: str):
        """Enrich result dictionary with sequence and summary information from accession."""
        with Entrez.efetch(db="nuccore", id=accession, rettype="gb", retmode="text") as handle:
            record = SeqIO.read(handle, "genbank")

        result["sequence"] = str(record.seq)
        result["sequence_length"] = len(record.seq)
        result["title"] = record.description
        result["molecule_type_detail"] = (
            "mRNA" if accession.startswith(("NM_", "XM_")) else
            "genomic DNA" if accession.startswith(("NC_", "NT_")) else
            "RNA" if accession.startswith(("NR_", "XR_")) else
            "genomic region" if accession.startswith("NG_") else "N/A"
        )

        # Only the first "source" feature is consulted.
        for feature in record.features:
            if feature.type == "source":
                if "chromosome" in feature.qualifiers:
                    result["chromosome"] = feature.qualifiers["chromosome"][0]

                # Use the feature span only as a fallback so coordinates already
                # taken from the gene record are not overwritten.
                if feature.location and not result.get("genomic_location"):
                    start = int(feature.location.start) + 1
                    end = int(feature.location.end)
                    result["genomic_location"] = f"{start}-{end}"

                break

        # The gene-record conversion uses "N/A" as its placeholder, which is truthy;
        # treat it as missing so the GenBank organism annotation can fill the gap.
        if result.get("organism") in (None, "", "N/A") and "organism" in record.annotations:
            result["organism"] = record.annotations["organism"]

        return result

    try:
        with Entrez.efetch(db="gene", id=gene_id, retmode="xml") as handle:
            gene_record = Entrez.read(handle)
        if not gene_record:
            return None

        result = self._gene_record_to_dict(gene_record, gene_id)
        if accession := (preferred_accession or result.get("_representative_accession")):
            result = _extract_from_genbank(result, accession)

        # Internal helper key; never exposed to callers.
        result.pop("_representative_accession", None)
        return result
    except (RequestException, IncompleteRead):
        raise
    except Exception as exc:
        logger.error("Gene ID %s not found: %s", gene_id, exc)
        return None
226
+
227
def get_by_accession(self, accession: str) -> Optional[dict]:
    """Look up sequence information for a nucleotide accession.

    The accession is mapped to a Gene ID through an Entrez ``elink`` query
    (``nuccore`` -> ``gene``) and then resolved via ``get_by_gene_id`` with the
    accession as the preferred sequence source. On success the result's ``id``
    and ``url`` are rewritten to point at the accession. Network errors
    (``RequestException``, ``IncompleteRead``) propagate; anything else is logged
    and yields ``None``.
    """

    def _extract_gene_id(link_handle):
        """Return the first linked Gene ID from an elink response, or None."""
        parsed = Entrez.read(link_handle)
        if not parsed or "LinkSetDb" not in parsed[0]:
            return None

        gene_sets = (entry for entry in parsed[0]["LinkSetDb"] if entry.get("DbTo") == "gene")
        chosen = next(gene_sets, None)
        if chosen is None:
            return None

        candidates = chosen.get("Link") or chosen.get("IdList", [{}])
        head = candidates[0]
        if isinstance(head, dict):
            return str(head.get("Id"))
        return str(head)

    try:
        # TODO: support accession number with version number (e.g., NM_000546.3)
        with Entrez.elink(dbfrom="nuccore", db="gene", id=accession) as link_handle:
            gene_id = _extract_gene_id(link_handle)

        if not gene_id:
            logger.warning("Accession %s has no associated GeneID", accession)
            return None

        record = self.get_by_gene_id(gene_id, preferred_accession=accession)
        if record:
            record.update(
                {
                    "id": accession,
                    "url": f"https://www.ncbi.nlm.nih.gov/nuccore/{accession}",
                }
            )
        return record
    except (RequestException, IncompleteRead):
        raise
    except Exception as exc:
        logger.error("Accession %s not found: %s", accession, exc)
        return None
261
+
262
def get_best_hit(self, keyword: str) -> Optional[dict]:
    """Search NCBI Gene database with a keyword and return the best hit.

    Two search terms are tried in order: a fielded query
    (``<keyword>[Gene] OR <keyword>[All Fields]``) and then the bare keyword.
    The first Gene ID returned by the first term with any result is resolved
    through ``get_by_gene_id``.

    Args:
        keyword: Gene name or free-text keyword.

    Returns:
        The gene dictionary for the best hit, or ``None`` when the keyword is
        blank, nothing matches, or a non-network error occurs (logged).

    Raises:
        RequestException, IncompleteRead: Re-raised unchanged so callers can retry.
    """
    if not keyword or not keyword.strip():
        return None

    try:
        for search_term in (f"{keyword}[Gene] OR {keyword}[All Fields]", keyword):
            with Entrez.esearch(db="gene", term=search_term, retmax=1, sort="relevance") as search_handle:
                search_results = Entrez.read(search_handle)
            id_list = search_results.get("IdList", [])
            if id_list:
                # Pass the single best Gene ID, not the whole list: the ID is
                # interpolated into the result's "id" and "url" fields.
                return self.get_by_gene_id(str(id_list[0]))
    except (RequestException, IncompleteRead):
        raise
    except Exception as e:
        logger.error("Keyword %s not found: %s", keyword, e)
    return None
278
+
279
def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
    """Perform local BLAST search using local BLAST database.

    Writes the query to a temporary FASTA file, runs ``blastn`` against
    ``self.local_blast_db`` and returns the subject accession of the first output
    line.

    Args:
        seq: Query nucleotide sequence (no FASTA header).
        threshold: E-value cut-off passed to ``-evalue``.

    Returns:
        The first reported subject accession, or ``None`` when there is no hit or
        the command fails (the failure is logged).
    """
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp:
            tmp.write(f">query\n{seq}\n")
            tmp_name = tmp.name

        cmd = [
            "blastn", "-db", self.local_blast_db, "-query", tmp_name,
            "-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc"
        ]
        logger.debug("Running local blastn: %s", " ".join(cmd))
        out = subprocess.check_output(cmd, text=True).strip()
        return out.split("\n", maxsplit=1)[0] if out else None
    except Exception as exc:
        logger.error("Local blastn failed: %s", exc)
        return None
    finally:
        # The file is created with delete=False, so remove it on every path -
        # including blastn failures, which previously leaked the temp file.
        if tmp_name is not None:
            try:
                os.remove(tmp_name)
            except OSError:
                pass  # best-effort cleanup; the search outcome is unaffected
297
+
298
def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
    """Search NCBI with a DNA sequence using BLAST.

    The input may be a raw sequence or FASTA text (first line starting with
    ``>``). Local BLAST is tried first when enabled; otherwise, or when it yields
    no accession, the query is sent to NCBI via ``NCBIWWW.qblast``.

    Args:
        sequence: Raw DNA sequence or FASTA-formatted text.
        threshold: Maximum accepted E-value.

    Returns:
        A result dictionary, or ``None`` when the sequence is invalid, no hit
        passes the threshold, or a non-network error occurs (logged).

    Raises:
        RequestException, IncompleteRead: Re-raised unchanged so callers can retry.
    """

    def _extract_and_normalize_sequence(sequence: str) -> Optional[str]:
        """Drop an optional FASTA header and all whitespace; validate the alphabet."""
        text = sequence.strip()
        if text.startswith(">"):
            text = "\n".join(text.splitlines()[1:])
        # Remove every kind of whitespace (spaces, tabs, CR/LF) so wrapped or
        # CRLF-terminated FASTA bodies are not rejected by the alphabet check.
        seq = "".join(text.split())
        return seq if re.fullmatch(r"[ATCGN]+", seq, re.I) else None

    def _process_network_blast_result(blast_record, seq: str, threshold: float) -> Optional[dict]:
        """Process network BLAST result and return dictionary or None."""
        if not blast_record.alignments:
            logger.info("No BLAST hits found for the given sequence.")
            return None

        best_alignment = blast_record.alignments[0]
        best_hsp = best_alignment.hsps[0]
        if best_hsp.expect > threshold:
            logger.info("No BLAST hits below the threshold E-value.")
            return None

        hit_id = best_alignment.hit_id
        if accession_match := re.search(r"ref\|([^|]+)", hit_id):
            # Strip the ".<version>" suffix before the accession lookup.
            return self.get_by_accession(accession_match.group(1).split(".")[0])

        # If unable to extract accession, return basic information
        return {
            "molecule_type": "DNA",
            "database": "NCBI",
            "id": hit_id,
            "title": best_alignment.title,
            "sequence_length": len(seq),
            "e_value": best_hsp.expect,
            "identity": best_hsp.identities / best_hsp.align_length if best_hsp.align_length > 0 else 0,
            "url": f"https://www.ncbi.nlm.nih.gov/nuccore/{hit_id}",
        }

    try:
        if not (seq := _extract_and_normalize_sequence(sequence)):
            logger.error("Empty or invalid DNA sequence provided.")
            return None

        # Try local BLAST first if enabled
        if self.use_local_blast and (accession := self._local_blast(seq, threshold)):
            logger.debug("Local BLAST found accession: %s", accession)
            return self.get_by_accession(accession)

        # Fall back to network BLAST
        logger.debug("Falling back to NCBIWWW.qblast")

        # NOTE(review): the database is "nr"; Biopython's blastn examples use "nt" -
        # confirm which nucleotide collection is intended.
        with NCBIWWW.qblast("blastn", "nr", seq, hitlist_size=1, expect=threshold) as result_handle:
            return _process_network_blast_result(NCBIXML.read(result_handle), seq, threshold)
    except (RequestException, IncompleteRead):
        raise
    except Exception as e:
        logger.error("BLAST search failed: %s", e)
        return None
358
+
359
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RequestException, IncompleteRead)),
    reraise=True,
)
async def search(self, query: str, threshold: float = 0.01, **kwargs) -> Optional[Dict]:
    """Dispatch a query to the matching NCBI lookup and return its result.

    The stripped query is classified in this order: DNA sequence / FASTA text,
    all-digit Gene ID, accession pattern, otherwise keyword. The chosen blocking
    lookup runs in the shared thread pool while the module-level lock is held, so
    only one lookup is in flight at a time. A truthy result is tagged with the
    query under ``_search_query``.
    """
    if not (query and isinstance(query, str)):
        logger.error("Empty or non-string input.")
        return None

    query = query.strip()
    logger.debug("NCBI search query: %s", query)

    loop = asyncio.get_running_loop()

    # limit concurrent requests (NCBI rate limit: max 3 requests per second)
    async with _ncbi_lock:
        # Pick the handler (and its arguments) first, then submit it once.
        if query.startswith(">") or re.fullmatch(r"[ATCGN\s]+", query, re.I):
            job = (self.get_by_fasta, query, threshold)
        elif re.fullmatch(r"^\d+$", query):
            job = (self.get_by_gene_id, query)
        elif re.fullmatch(r"[A-Z]{2}_\d+\.?\d*", query, re.I):
            job = (self.get_by_accession, query)
        else:
            job = (self.get_best_hit, query)
        result = await loop.run_in_executor(_get_pool(), *job)

    if result:
        result["_search_query"] = query
    return result
graphgen/models/searcher/db/rnacentral_searcher.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import re
4
+ import subprocess
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from functools import lru_cache
7
+ import tempfile
8
+ from typing import Dict, Optional, List, Any, Set
9
+
10
+ import hashlib
11
+ import requests
12
+ import aiohttp
13
+ from tenacity import (
14
+ retry,
15
+ retry_if_exception_type,
16
+ stop_after_attempt,
17
+ wait_exponential,
18
+ )
19
+
20
+ from graphgen.bases import BaseSearcher
21
+ from graphgen.utils import logger
22
+
23
+
24
+ @lru_cache(maxsize=None)
25
+ def _get_pool():
26
+ return ThreadPoolExecutor(max_workers=10)
27
+
28
+ class RNACentralSearch(BaseSearcher):
29
+ """
30
+ RNAcentral Search client to search RNA databases.
31
+ 1) Get RNA by RNAcentral ID.
32
+ 2) Search with keywords or RNA names (fuzzy search).
33
+ 3) Search with RNA sequence.
34
+
35
+ API Documentation: https://rnacentral.org/api/v1
36
+ """
37
+
38
def __init__(self, use_local_blast: bool = False, local_blast_db: str = "rna_db"):
    """
    Initialize the RNAcentral Search client.

    Args:
        use_local_blast (bool): Whether to try a local BLAST database before the API.
        local_blast_db (str): Path prefix of the local BLAST database (no file extension).
    """
    super().__init__()
    self.base_url = "https://rnacentral.org/api/v1"
    self.headers = {"Accept": "application/json"}
    self.use_local_blast = use_local_blast
    self.local_blast_db = local_blast_db
    if self.use_local_blast:
        # A single-volume nucleotide database ships "<db>.nhr"; large databases are split
        # into volumes ("<db>.00.nhr", ...) tied together by a "<db>.nal" alias file.
        # Accept any of these so multi-volume databases are not rejected.
        header_suffixes = (".nhr", ".nal", ".00.nhr")
        db_found = any(
            os.path.isfile(f"{self.local_blast_db}{suffix}") for suffix in header_suffixes
        )
        if not db_found:
            logger.error("Local BLAST database files not found. Please check the path.")
            self.use_local_blast = False
47
+
48
+ @staticmethod
49
+ def _rna_data_to_dict(
50
+ rna_id: str,
51
+ rna_data: Dict[str, Any],
52
+ xrefs_data: Optional[List[Dict[str, Any]]] = None
53
+ ) -> Dict[str, Any]:
54
+ organisms, gene_names, so_terms = set(), set(), set()
55
+ modifications: List[Any] = []
56
+
57
+ for xref in xrefs_data or []:
58
+ acc = xref.get("accession", {})
59
+ if s := acc.get("species"):
60
+ organisms.add(s)
61
+ if g := acc.get("gene", "").strip():
62
+ gene_names.add(g)
63
+ if m := xref.get("modifications"):
64
+ modifications.extend(m)
65
+ if b := acc.get("biotype"):
66
+ so_terms.add(b)
67
+
68
+ def format_unique_values(values: Set[str]) -> Optional[str]:
69
+ if not values:
70
+ return None
71
+ if len(values) == 1:
72
+ return next(iter(values))
73
+ return ", ".join(sorted(values))
74
+
75
+ xrefs_info = {
76
+ "organism": format_unique_values(organisms),
77
+ "gene_name": format_unique_values(gene_names),
78
+ "related_genes": list(gene_names) if gene_names else None,
79
+ "modifications": modifications or None,
80
+ "so_term": format_unique_values(so_terms),
81
+ }
82
+
83
+ fallback_rules = {
84
+ "organism": ["organism", "species"],
85
+ "related_genes": ["related_genes", "genes"],
86
+ "gene_name": ["gene_name", "gene"],
87
+ "so_term": ["so_term"],
88
+ "modifications": ["modifications"],
89
+ }
90
+
91
+ def resolve_field(field_name: str) -> Any:
92
+ if (value := xrefs_info.get(field_name)) is not None:
93
+ return value
94
+
95
+ for key in fallback_rules[field_name]:
96
+ if (value := rna_data.get(key)) is not None:
97
+ return value
98
+
99
+ return None
100
+
101
+ organism = resolve_field("organism")
102
+ gene_name = resolve_field("gene_name")
103
+ so_term = resolve_field("so_term")
104
+ modifications = resolve_field("modifications")
105
+
106
+ related_genes = resolve_field("related_genes")
107
+ if not related_genes and (single_gene := rna_data.get("gene_name")):
108
+ related_genes = [single_gene]
109
+
110
+ sequence = rna_data.get("sequence", "")
111
+
112
+ return {
113
+ "molecule_type": "RNA",
114
+ "database": "RNAcentral",
115
+ "id": rna_id,
116
+ "rnacentral_id": rna_data.get("rnacentral_id", rna_id),
117
+ "sequence": sequence,
118
+ "sequence_length": rna_data.get("length", len(sequence)),
119
+ "rna_type": rna_data.get("rna_type", "N/A"),
120
+ "description": rna_data.get("description", "N/A"),
121
+ "url": f"https://rnacentral.org/rna/{rna_id}",
122
+ "organism": organism,
123
+ "related_genes": related_genes or None,
124
+ "gene_name": gene_name,
125
+ "so_term": so_term,
126
+ "modifications": modifications,
127
+ }
128
+
129
+ @staticmethod
130
+ def _calculate_md5(sequence: str) -> str:
131
+ """
132
+ Calculate MD5 hash for RNA sequence as per RNAcentral spec.
133
+ - Replace U with T
134
+ - Convert to uppercase
135
+ - Encode as ASCII
136
+ """
137
+ # Normalize sequence
138
+ normalized_seq = sequence.replace("U", "T").replace("u", "t").upper()
139
+ if not re.fullmatch(r"[ATCGN]+", normalized_seq):
140
+ raise ValueError(f"Invalid sequence characters after normalization: {normalized_seq[:50]}...")
141
+
142
+ return hashlib.md5(normalized_seq.encode("ascii")).hexdigest()
143
+
144
def get_by_rna_id(self, rna_id: str) -> Optional[dict]:
    """
    Fetch one RNAcentral entry and convert it to the result dictionary.

    Requests ``<base_url>/rna/<rna_id>?flat=true`` and passes the payload, with
    its ``xrefs`` value, to ``_rna_data_to_dict``.

    :param rna_id: RNAcentral ID (e.g., URS0000000001).
    :return: A dictionary containing RNA information, or None on any error
        (network and unexpected errors are logged separately).
    """
    try:
        endpoint = f"{self.base_url}/rna/{rna_id}" + "?flat=true"
        response = requests.get(endpoint, headers=self.headers, timeout=30)
        response.raise_for_status()

        payload = response.json()
        return self._rna_data_to_dict(rna_id, payload, payload.get("xrefs", []))
    except requests.RequestException as e:
        logger.error("Network error getting RNA ID %s: %s", rna_id, e)
    except Exception as e:  # pylint: disable=broad-except
        logger.error("Unexpected error getting RNA ID %s: %s", rna_id, e)
    return None
166
+
167
def get_best_hit(self, keyword: str) -> Optional[dict]:
    """
    Run a keyword search against RNAcentral and return the top result.

    The first search result is expanded through ``get_by_rna_id`` when it carries
    an ``rnacentral_id``; if that lookup yields nothing, the raw search result is
    converted instead.

    :param keyword: The search keyword (e.g., miRNA name, RNA name).
    :return: Dictionary with RNA information or None.
    """
    keyword = keyword.strip()
    if not keyword:
        logger.warning("Empty keyword provided to get_best_hit")
        return None

    try:
        response = requests.get(
            f"{self.base_url}/rna",
            params={"search": keyword, "format": "json"},
            headers=self.headers,
            timeout=30,
        )
        response.raise_for_status()

        hits = response.json().get("results", [])
        if not hits:
            logger.info("No search results for keyword: %s", keyword)
            return None

        top_hit = hits[0]
        rna_id = top_hit.get("rnacentral_id")

        # Prefer the detailed record; fall back to the search payload itself.
        detailed = self.get_by_rna_id(rna_id) if rna_id else None
        if detailed:
            return detailed
        logger.debug("Using search result data for %s", rna_id or "unknown")
        return self._rna_data_to_dict(rna_id or "", top_hit)

    except requests.RequestException as e:
        logger.error("Network error searching keyword '%s': %s", keyword, e)
    except Exception as e:
        logger.error("Unexpected error searching keyword '%s': %s", keyword, e)
    return None
207
+
208
def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
    """Perform local BLAST search using local BLAST database.

    Writes the query to a temporary FASTA file, runs ``blastn`` against
    ``self.local_blast_db`` and returns the subject accession of the first output
    line.

    Args:
        seq: Query sequence (no FASTA header).
        threshold: E-value cut-off passed to ``-evalue``.

    Returns:
        The first reported subject accession, or ``None`` when there is no hit or
        the command fails (the failure is logged).
    """
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode="w+", suffix=".fa", delete=False) as tmp:
            tmp.write(f">query\n{seq}\n")
            tmp_name = tmp.name

        cmd = [
            "blastn", "-db", self.local_blast_db, "-query", tmp_name,
            "-evalue", str(threshold), "-max_target_seqs", "1", "-outfmt", "6 sacc"
        ]
        logger.debug("Running local blastn for RNA: %s", " ".join(cmd))
        out = subprocess.check_output(cmd, text=True).strip()
        return out.split("\n", maxsplit=1)[0] if out else None
    except Exception as exc:
        logger.error("Local blastn failed: %s", exc)
        return None
    finally:
        # The file is created with delete=False, so remove it on every path -
        # including blastn failures, which previously leaked the temp file.
        if tmp_name is not None:
            try:
                os.remove(tmp_name)
            except OSError:
                pass  # best-effort cleanup; the search outcome is unaffected
226
+
227
def get_by_fasta(self, sequence: str, threshold: float = 0.01) -> Optional[dict]:
    """
    Search RNAcentral with an RNA sequence.
    Tries local BLAST first if enabled, falls back to an exact-match (MD5) lookup
    through the RNAcentral API.
    Unified approach: find the RNA ID from the sequence search, then call
    get_by_rna_id() for complete information.

    :param sequence: RNA sequence (FASTA format or raw sequence). Both ``U`` and
        ``T`` are accepted; the MD5 step normalizes ``U`` to ``T``.
    :param threshold: E-value threshold for BLAST search.
    :return: A dictionary containing complete RNA information or None if not found.
    """
    def _extract_sequence(sequence: str) -> Optional[str]:
        """Drop an optional FASTA header and all whitespace; validate the alphabet."""
        text = sequence.strip()
        if text.startswith(">"):
            text = "\n".join(text.splitlines()[1:])
        # Remove every kind of whitespace so the value handed to BLAST and to the
        # MD5 step is a bare sequence (whitespace used to pass validation here and
        # then fail in _calculate_md5).
        seq = "".join(text.split())
        return seq if seq and re.fullmatch(r"[ACGTUN]+", seq, re.I) else None

    try:
        seq = _extract_sequence(sequence)
        if not seq:
            logger.error("Empty or invalid RNA sequence provided.")
            return None

        # Try local BLAST first if enabled
        if self.use_local_blast:
            accession = self._local_blast(seq, threshold)
            if accession:
                logger.debug("Local BLAST found accession: %s", accession)
                return self.get_by_rna_id(accession)

        # Fall back to RNAcentral API if local BLAST didn't find result
        logger.debug("Falling back to RNAcentral API.")

        md5_hash = self._calculate_md5(seq)
        search_url = f"{self.base_url}/rna"
        params = {"md5": md5_hash, "format": "json"}

        resp = requests.get(search_url, params=params, headers=self.headers, timeout=60)
        resp.raise_for_status()

        search_results = resp.json()
        results = search_results.get("results", [])

        if not results:
            logger.info("No exact match found in RNAcentral for sequence")
            return None
        rna_id = results[0].get("rnacentral_id")
        if not rna_id:
            logger.error("No RNAcentral ID found in search results.")
            return None
        return self.get_by_rna_id(rna_id)
    except Exception as e:
        logger.error("Sequence search failed: %s", e)
        return None
282
+
283
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError)),
    reraise=True,
)
async def search(self, query: str, threshold: float = 0.1, **kwargs) -> Optional[Dict]:
    """Dispatch a query to the matching RNAcentral lookup and return its result.

    The stripped query is classified in this order: FASTA text or a raw sequence
    that contains ``U``, an ``URS<digits>`` identifier, otherwise keyword. The
    chosen blocking lookup runs in the shared thread pool. A truthy result is
    tagged with the query under ``_search_query``.
    """
    if not (query and isinstance(query, str)):
        logger.error("Empty or non-string input.")
        return None

    query = query.strip()
    logger.debug("RNAcentral search query: %s", query)

    loop = asyncio.get_running_loop()

    # check if RNA sequence (AUCG characters, contains U)
    looks_like_sequence = query.startswith(">") or (
        re.fullmatch(r"[AUCGN\s]+", query, re.I) and "U" in query.upper()
    )
    if looks_like_sequence:
        job = (self.get_by_fasta, query, threshold)
    elif re.fullmatch(r"URS\d+", query, re.I):
        # check if RNAcentral ID (typically starts with URS)
        job = (self.get_by_rna_id, query)
    else:
        # otherwise treat as keyword
        job = (self.get_best_hit, query)
    result = await loop.run_in_executor(_get_pool(), *job)

    if result:
        result["_search_query"] = query
    return result
graphgen/operators/search/search_all.py CHANGED
@@ -27,6 +27,10 @@ async def search_all(
27
  data_sources = search_config.get("data_sources", [])
28
 
29
  for data_source in data_sources:
 
 
 
 
30
  if data_source == "uniprot":
31
  from graphgen.models import UniProtSearch
32
 
@@ -34,19 +38,46 @@ async def search_all(
34
  **search_config.get("uniprot_params", {})
35
  )
36
 
37
- data = list(seed_data.values())
38
- data = [d["content"] for d in data if "content" in d]
39
- data = list(set(data)) # Remove duplicates
40
  uniprot_results = await run_concurrent(
41
  uniprot_search_client.search,
42
  data,
43
  desc="Searching UniProt database",
44
  unit="keyword",
45
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  else:
47
  logger.error("Data source %s not supported.", data_source)
48
  continue
49
 
50
- results[data_source] = uniprot_results
51
-
52
  return results
 
27
  data_sources = search_config.get("data_sources", [])
28
 
29
  for data_source in data_sources:
30
+ data = list(seed_data.values())
31
+ data = [d["content"] for d in data if "content" in d]
32
+ data = list(set(data)) # Remove duplicates
33
+
34
  if data_source == "uniprot":
35
  from graphgen.models import UniProtSearch
36
 
 
38
  **search_config.get("uniprot_params", {})
39
  )
40
 
 
 
 
41
  uniprot_results = await run_concurrent(
42
  uniprot_search_client.search,
43
  data,
44
  desc="Searching UniProt database",
45
  unit="keyword",
46
  )
47
+ results[data_source] = uniprot_results
48
+
49
+ elif data_source == "ncbi":
50
+ from graphgen.models import NCBISearch
51
+
52
+ ncbi_search_client = NCBISearch(
53
+ **search_config.get("ncbi_params", {})
54
+ )
55
+
56
+ ncbi_results = await run_concurrent(
57
+ ncbi_search_client.search,
58
+ data,
59
+ desc="Searching NCBI database",
60
+ unit="keyword",
61
+ )
62
+ results[data_source] = ncbi_results
63
+
64
+ elif data_source == "rnacentral":
65
+ from graphgen.models import RNACentralSearch
66
+
67
+ rnacentral_search_client = RNACentralSearch(
68
+ **search_config.get("rnacentral_params", {})
69
+ )
70
+
71
+ rnacentral_results = await run_concurrent(
72
+ rnacentral_search_client.search,
73
+ data,
74
+ desc="Searching RNAcentral database",
75
+ unit="keyword",
76
+ )
77
+ results[data_source] = rnacentral_results
78
+
79
  else:
80
  logger.error("Data source %s not supported.", data_source)
81
  continue
82
 
 
 
83
  return results
requirements.txt CHANGED
@@ -21,6 +21,7 @@ fastapi
21
  trafilatura
22
  aiohttp
23
  diskcache
 
24
 
25
  leidenalg
26
  igraph
 
21
  trafilatura
22
  aiohttp
23
  diskcache
24
+ socksio
25
 
26
  leidenalg
27
  igraph