| import chromadb | |
| from .BaseDB import BaseDB | |
| import random | |
| import string | |
| import os | |
| from tqdm import tqdm | |
class ChromaDB(BaseDB):
    """Vector store backed by Chroma, keeping one collection per ``db_name``.

    Collections opened through this object are cached in ``self.collections``
    (a dict mapping collection name to the Chroma collection handle).
    ``search`` and ``delete`` only look at that cache, so a collection must
    first be opened with ``init_from_data`` or ``add``.
    """

    def __init__(self, embedding, save_type="persistent"):
        """Create the Chroma client.

        Args:
            embedding: Embedding function handed to every collection.
            save_type: ``"persistent"`` stores data in a ``chromadb_saves``
                directory next to this module; any other value creates a
                client with ``chromadb.Client()``.
        """
        self.collections = {}
        self.embedding = embedding
        # Always define the attribute so callers can test it on either
        # client type; it stays None when no on-disk path is used.
        self.path = None
        if save_type == "persistent":
            base_dir = os.path.dirname(os.path.abspath(__file__))
            self.path = os.path.join(base_dir, "./chromadb_saves/")
            self.client = chromadb.PersistentClient(path=self.path)
        else:
            self.client = chromadb.Client()

    def _open_collection(self, db_name):
        """Fetch (or create) the collection ``db_name`` and cache its handle."""
        # get_or_create avoids the "already exists" failure of
        # create_collection when the collection is present in the client
        # (e.g. persisted by an earlier run) but not yet in our cache.
        collection = self.client.get_or_create_collection(
            name=db_name,
            embedding_function=self.embedding,
        )
        self.collections[db_name] = collection
        return collection

    def init_from_data(self, data, db_name):
        """Open collection ``db_name`` and load the not-yet-stored tail of ``data``.

        Documents are stored under the id ``str(index)``. Loading resumes at
        the collection's current document count, so calling this again with
        the same ``data`` after an interrupted run only inserts the missing
        items.

        Args:
            data: Indexable sequence of document texts.
            db_name: Name of the collection to open or create.
        """
        collection = self._open_collection(db_name)
        first_missing = collection.count()
        if first_missing >= len(data):
            return
        for i in tqdm(range(first_missing, len(data))):
            # `add` inserts new records; `update` only touches ids that
            # already exist and therefore cannot be used for the initial load.
            collection.add(
                documents=[data[i]],
                ids=[str(i)],
            )

    def search(self, query, n_results, db_name):
        """Return up to ``n_results`` documents matching ``query``.

        Args:
            query: Query text.
            n_results: Maximum number of documents wanted.
            db_name: Name of a collection previously opened on this object.

        Returns:
            A list of document texts; empty when the collection is not in
            the cache, holds no documents, or ``n_results`` is below 1.
        """
        collection = self.collections.get(db_name)
        if collection is None:
            return []
        # Never request more results than the collection holds.
        n_results = min(collection.count(), n_results)
        if n_results < 1:
            return []
        results = collection.query(query_texts=[query], n_results=n_results)
        # One query text was sent, so the first result list is ours.
        return results['documents'][0]

    def add(self, text, idx, db_name=""):
        """Insert ``text`` under id ``idx``, replacing any document with that id.

        The collection is opened (and created if necessary) on first use.

        Args:
            text: Document text to store.
            idx: Id of the document.
            db_name: Name of the target collection.
                NOTE(review): Chroma validates collection names, so the
                empty default is probably not usable -- confirm.
        """
        collection = self.collections.get(db_name)
        if collection is None:
            collection = self._open_collection(db_name)
        existing_doc = collection.get(ids=[idx])
        if existing_doc and existing_doc['ids']:
            collection.update(
                documents=[text],
                ids=[idx],
            )
        else:
            collection.add(
                documents=[text],
                ids=[idx],
            )

    def delete(self, idx, db_name):
        """Delete the document with id ``idx`` from collection ``db_name``.

        Raises:
            KeyError: If ``db_name`` has not been opened on this object.
        """
        self.collections[db_name].delete(ids=[idx])