Spaces:
Sleeping
Sleeping
| # General | |
| import os | |
| import kagglehub | |
| import pandas as pd | |
| import json | |
| from typing import Literal | |
| from datasets import load_dataset | |
| import random | |
| #Markdown | |
| from IPython.display import Markdown, display, Image | |
| # Image | |
| from PIL import Image | |
| # langchain for llms | |
| from langchain_groq import ChatGroq | |
| # Langchain | |
| from langchain.prompts import PromptTemplate, ChatPromptTemplate | |
| from langchain.output_parsers import StructuredOutputParser, ResponseSchema | |
| from langchain_core.output_parsers import JsonOutputParser | |
| from langchain_core.messages import HumanMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.graph import END, START, StateGraph, MessagesState | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_core.tools import tool | |
| # Hugging Face | |
| from transformers import AutoModelForImageClassification, AutoProcessor | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| # Extra libraries | |
| from pydantic import BaseModel, Field, model_validator | |
| # Advanced RAG | |
| from langchain_core.documents import Document | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| from langchain.retrievers.multi_query import MultiQueryRetriever | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.output_parsers import StrOutputParser | |
# ## APIs
# Read the API keys from the environment (None when a key is not set).
# The previous `os.environ[KEY] = os.getenv(KEY)` re-assignment was a no-op when
# the key existed and raised TypeError when it did not (os.environ rejects None).
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
_missing_keys = [
    key_name
    for key_name, key_value in (
        ("SERPER_API_KEY", SERPER_API_KEY),
        ("GROQ_API_KEY", GROQ_API_KEY),
        ("HF_TOKEN", HF_TOKEN),
    )
    if not key_value
]
if _missing_keys:
    # Do not crash here; the client that needs a key will report it when used.
    print(f"Warning: missing environment variables: {', '.join(_missing_keys)}")
# ## Setup LLM (Llama 3.3 via Groq)
# Note: the Llama 3.2 70B model is no longer available on Groq,
# so Llama 3.3 is used from now on.
# ChatGroq reads GROQ_API_KEY from the environment. Only export it when it is
# actually set: assigning None to os.environ raises TypeError.
if GROQ_API_KEY:
    os.environ["GROQ_API_KEY"] = GROQ_API_KEY
# model_3_2 = 'llama-3.2-11b-text-preview'  => this model has been removed from the Groq platform
model_3_2_small = 'llama-3.1-8b-instant'  # Smaller model (8 billion parameters) if you need speed
model_3_3 = 'llama-3.3-70b-versatile'  # Very large, versatile model with 70 billion parameters
llm = ChatGroq(
    model=model_3_3,
    temperature=0,  # deterministic answers
    max_tokens=None,
    timeout=None,
    max_retries=2,
)
# A smoke-test message to check that the LLM is reachable.
response = llm.invoke("hi, Please generate 10 unique Dutch names for both male and female?")
response
display(Markdown(response.content))
# # First Agent: Chatbot Agent
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages


class ChatState(TypedDict):
    """Graph state for the chatbot: the running list of messages.

    The `add_messages` reducer in the annotation tells LangGraph to append the
    messages a node returns to this list instead of overwriting it.
    """
    messages: Annotated[list, add_messages]


chat_graph = StateGraph(ChatState)


def chatbot_agent(state: ChatState):
    """Graph node: send the whole conversation to the LLM and return its reply."""
    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}


# add_node(name, callable): the name is the node's unique id and the callable
# runs whenever the node is visited. Topology: START -> chatbot_agent -> END.
chat_graph.add_node("chatbot_agent", chatbot_agent)
for _edge_source, _edge_target in ((START, "chatbot_agent"), ("chatbot_agent", END)):
    chat_graph.add_edge(_edge_source, _edge_target)

# compile() turns the graph description into a runnable application.
graph_app = chat_graph.compile()

# Conversation history kept across turns; it starts empty.
persistent_state = {"messages": []}

from IPython.display import Image, display
display(Image(graph_app.get_graph(xray=True).draw_mermaid_png()))
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from IPython.display import display, Markdown


# Rebuild the chatbot graph so this cell is self-contained.
class ChatState(TypedDict):
    """Chatbot state; `add_messages` appends node outputs to `messages`."""
    messages: Annotated[list, add_messages]


def chatbot_agent(state: ChatState):
    """Answer using the full conversation history in ``state["messages"]``."""
    return {"messages": [llm.invoke(state["messages"])]}


chat_graph = StateGraph(ChatState)
chat_graph.add_node("chatbot_agent", chatbot_agent)
chat_graph.add_edge(START, "chatbot_agent")
chat_graph.add_edge("chatbot_agent", END)
graph_app = chat_graph.compile()

# Conversation history kept across turns; it starts empty.
persistent_state = {"messages": []}
def stream_graph_updates(user_input: str):
    """Run one user turn through `graph_app` and render the assistant's reply.

    The user's message and every assistant reply are appended to the
    module-level `persistent_state`, so later turns see the whole history.
    Streaming stops at the first reply whose finish_reason is "stop".
    """
    persistent_state["messages"].append(("user", user_input))
    for event in graph_app.stream(persistent_state):
        for node_update in event.values():
            reply = node_update["messages"][-1]
            display(Markdown("Assistant: " + reply.content))
            # Remember the reply so the next turn has the full conversation.
            persistent_state["messages"].append(("assistant", reply.content))
            if reply.response_metadata.get("finish_reason") == "stop":
                return
# Interactive chat loop: type quit / exit / q (or send EOF / Ctrl-C) to stop.
while True:
    try:
        user_input = input('User:')
    except (EOFError, KeyboardInterrupt):
        # Input closed (e.g. non-interactive run) or interrupted: end cleanly
        # instead of reporting it as an error.
        print("Thank you and Goodbye!")
        break
    # Tolerate surrounding whitespace around the quit commands.
    if user_input.strip().lower() in ["quit", "exit", "q"]:
        print("Thank you and Goodbye!")
        break
    try:
        stream_graph_updates(user_input)
    except Exception as e:
        # Top-level boundary of the chat session: report and stop.
        print(f"An error occurred: {e}")
        break
# # Second Agent: Add Search to Chatbot to make it Stronger
from langchain_community.tools import GoogleSerperResults
from typing import List, Annotated
from langchain_core.messages import BaseMessage
from langgraph.prebuilt import ToolNode, create_react_agent
import operator
import functools


class ChatState(TypedDict):
    """Chat state; `add_messages` appends returned messages instead of replacing them."""
    messages: Annotated[list, add_messages]


def agent_node(state, agent, name):
    """Run `agent` on `state` and re-emit its final answer as a named HumanMessage.

    Only the content of the last message of the agent's result is kept, so the
    outer graph receives one message per agent run, tagged with `name`.
    """
    final_message = agent.invoke(state)["messages"][-1]
    return {"messages": [HumanMessage(content=final_message.content, name=name)]}


class SearchState(TypedDict):
    """Search-graph state; `operator.add` concatenates the message lists nodes return."""
    messages: Annotated[List[BaseMessage], operator.add]


# Google search tool (Serper) returning the top 5 results.
serper_tool = GoogleSerperResults(num_results=5)

# ReAct agent that may call the search tool, bound into a graph node.
search_agent = create_react_agent(llm, tools=[serper_tool])
search_node = functools.partial(agent_node, agent=search_agent, name="search_agent")

# add_node(name, callable): the name is the node's unique id and the callable
# runs whenever the node is visited. Topology: START -> search_agent -> END.
search_graph = StateGraph(SearchState)
search_graph.add_node("search_agent", search_node)
search_graph.add_edge(START, "search_agent")
search_graph.add_edge("search_agent", END)
# compile() turns the graph description into a runnable application.
search_app = search_graph.compile()

from IPython.display import Image, display
display(Image(search_app.get_graph(xray=True).draw_mermaid_png()))
from langchain_community.tools import GoogleSerperResults
from typing import List, Annotated
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.prebuilt import ToolNode, create_react_agent
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from IPython.display import display, Markdown
import operator
import functools


class ChatState(TypedDict):
    """Chat state; message lists returned by nodes are concatenated with `operator.add`."""
    messages: Annotated[List[BaseMessage], operator.add]


def agent_node(state, agent, name):
    """Invoke `agent`; return only its last message, re-wrapped as a HumanMessage named `name`."""
    result = agent.invoke(state)
    answer_text = result["messages"][-1].content
    return {"messages": [HumanMessage(content=answer_text, name=name)]}


class SearchState(TypedDict):
    """Search-graph state using the same concatenating reducer as `ChatState`."""
    messages: Annotated[List[BaseMessage], operator.add]


serper_tool = GoogleSerperResults(num_results=5)  # number of Google results to return
search_agent = create_react_agent(llm, tools=[serper_tool])
search_node = functools.partial(agent_node, agent=search_agent, name="search_agent")

# Single-node graph: START -> search_agent -> END
search_graph = StateGraph(SearchState)
search_graph.add_node("search_agent", search_node)
for _edge_source, _edge_target in ((START, "search_agent"), ("search_agent", END)):
    search_graph.add_edge(_edge_source, _edge_target)
search_app = search_graph.compile()

# Conversation history kept across turns; it starts empty.
persistent_state = {"messages": []}
def stream_graph_updates(user_input: str):
    """Run one user turn through the web-search graph and render each reply.

    Appends the user's HumanMessage and each reply to the module-level
    `persistent_state`, so the history carries over between turns. Stops at the
    first reply whose metadata reports finish_reason "stop".
    """
    persistent_state["messages"].append(HumanMessage(content=user_input))
    display(Markdown("**Assistant:** Searching the Web Now..."))
    for event in search_app.stream(persistent_state):
        for node_update in event.values():
            reply = node_update["messages"][-1]
            display(Markdown("**Assistant:** " + reply.content))
            persistent_state["messages"].append(reply)
            if reply.response_metadata.get("finish_reason") == "stop":
                return
# Interactive web-search chat loop: type quit / exit / q (or send EOF / Ctrl-C) to stop.
while True:
    try:
        user_input = input('User:')
    except (EOFError, KeyboardInterrupt):
        # Input closed (e.g. non-interactive run) or interrupted: end cleanly
        # instead of reporting it as an error.
        print("Thank you and Goodbye!")
        break
    # Tolerate surrounding whitespace around the quit commands.
    if user_input.strip().lower() in ["quit", "exit", "q"]:
        print("Thank you and Goodbye!")
        break
    try:
        stream_graph_updates(user_input)
    except Exception as e:
        # Top-level boundary of the chat session: report and stop.
        print(f"An error occurred: {e}")
        break
# # Step 1: Medical Database Preparation
# This step prepares and enriches the patient data used throughout the simulation.
# ## 1.1 Load Dataset
# ### 1.1.1 Disease Symptoms and Patient Profile Dataset
# kagglehub downloads the latest version of the dataset and returns the local
# directory that contains the CSV:
# - https://www.kaggle.com/datasets/uom190346a/disease-symptoms-and-patient-profile-dataset
path = kagglehub.dataset_download("uom190346a/disease-symptoms-and-patient-profile-dataset")
print("Path to dataset files:", path)
patient_df = pd.read_csv(os.path.join(path, 'Disease_symptom_and_patient_profile_dataset.csv'))
patient_df.shape
patient_df.head()
# Count the patients of each gender
female_count = patient_df[patient_df['Gender'] == 'Female'].shape[0]
male_count = patient_df[patient_df['Gender'] == 'Male'].shape[0]
# Female-to-male ratio; guarded so a dataset without male patients does not
# raise ZeroDivisionError
ratio = female_count / male_count if male_count else float('inf')
print(f"The ratio of Female to Male is {ratio}:1")
# The 20 most frequent diseases
patient_df['Disease'].value_counts().head(20)
# **prepare_medical_dataset: the loading code in one place**
def prepare_medical_dataset(path, file_name):
    """Load the patient-profile CSV into a DataFrame.

    Args:
        path: Directory containing the dataset (e.g. the kagglehub download path).
        file_name: CSV file name inside `path`. A leading path separator, as in
            '/Disease_symptom_and_patient_profile_dataset.csv', is accepted.

    Returns:
        The CSV contents as a pandas DataFrame.
    """
    # os.path.join would treat a leading separator as an absolute path and drop
    # `path`, so strip it first. Joining (instead of `path + file_name`) also
    # works when the caller passes a bare file name or a trailing-slash path.
    csv_path = os.path.join(path, file_name.lstrip("/\\"))
    patient_df = pd.read_csv(csv_path)
    return patient_df
# Download the Kaggle dataset and load its CSV through the helper above.
file_name = '/Disease_symptom_and_patient_profile_dataset.csv'
path = kagglehub.dataset_download("uom190346a/disease-symptoms-and-patient-profile-dataset")
patient_df = prepare_medical_dataset(path=path, file_name=file_name)
# ### 1.1.2 Chest X-Ray Images (Pneumonia)
#
# - https://huggingface.co/lxyuan/vit-xray-pneumonia-classification
# - https://huggingface.co/datasets/keremberke/chest-xray-classification
#
# Load the chest X-ray dataset once (it was previously downloaded twice) and
# pick one random training image as the current patient's X-ray.
x_ray_ds = load_dataset("keremberke/chest-xray-classification", name="full")
x_ray_ds['train'].shape[0]
random_index = random.randint(0, x_ray_ds['train'].shape[0] - 1)
patient_x_ray = x_ray_ds['train'][random_index]['image']
patient_x_ray
type(patient_x_ray)
# If needed, upgrade the dependencies first (quote ">=" so the shell does not
# treat it as an output redirect):
# !pip install --upgrade accelerate==0.31.0
# !pip install --upgrade "huggingface-hub>=0.23.0"
from transformers import pipeline

# Pneumonia classifier: https://huggingface.co/lxyuan/vit-xray-pneumonia-classification
# Load the pipeline and classify the patient's X-ray once; both reports below
# reuse the same predictions (previously the model was loaded and run twice).
classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
patient_x_ray_results = classifier(patient_x_ray)
patient_x_ray_results
# The prediction with the highest score
highest = max(patient_x_ray_results, key=lambda x: x['score'])
patient_x_ray_label = highest['label']
print(patient_x_ray_label)
highest_score_label = highest['label']
highest_score = highest['score'] * 100  # convert to percentage
# Phrase the result as "is NORMAL" or "has <label>"
verb = "is" if highest_score_label == "NORMAL" else "has"
print(f"Patient {verb} {highest_score_label} with Probability of ca. {highest_score:.0f}%")
# ## 1.2 Generate Synthetic Data with LLMs
# Generate culturally appropriate Dutch names and unique alphanumeric IDs for each patient.
# ### 1.2.1 Generate Random Names and IDs for Patients
# This code is slow with Llama 3.3 70B, which is much bigger than Llama 3.2 11B;
# switch to model_3_2_small when running it.
# === Step 1: Define Response Schemas ===
# One schema per field of the JSON object the LLM must return for each patient.
first_name_schema = ResponseSchema(
    name="First_Name",
    description="The first name of the patient."
)
last_name_schema = ResponseSchema(
    name="Last_Name",
    description="The last name of the patient."
)
patient_id_schema = ResponseSchema(
    name="Patient_ID",
    description="A unique 13-character alphanumeric patient identifier."
)
# Gender the generated first name belongs to. Named "G_Gender" to avoid a clash
# with the dataset's own 'Gender' column; it is renamed to 'Gender' later.
gender_schema = ResponseSchema(
    name="G_Gender",
    description="Indicate the first name you generate belong which Gender: Male or Female"
)
# All four fields, in output order
response_schemas = [
    first_name_schema,
    last_name_schema,
    patient_id_schema,
    gender_schema,
]
# === Step 2: Set Up the Output Parser ===
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
# Text describing the expected JSON format; it is injected into the prompt
format_instructions = output_parser.get_format_instructions()
# === Step 3: Craft the Prompt ===
# The prompt asks the LLM for only the structured JSON data: one row per patient
# with the four schema fields (First_Name, Last_Name, Patient_ID, G_Gender).
# The previous text asked for "3 columns", which contradicted the 4-field schema.
prompt_template = ChatPromptTemplate.from_template("""
you MUST Generate a list of {n} Dutch names along with a unique 13-character alphanumeric Patient_ID for each gender provided.
Always Use {genders} to generate a First_Name which belong to the right Gender, two category is possible: 'Male' or 'Female'.
Ensure the names are culturally appropriate for the Netherlands.
Generate unique names, no repetitions, and ensure diversity.
The ratio of Female to Male is {ratio}:1
{format_instructions}
Genders:
{genders}
**IMPORTANT:** Do not include any explanations, code, or additional text.
you MUST ALWAYS generate Dutch names and Patient_ID according {format_instructions}
and NEVER return empty values.
YOU MUST Provide only the JSON array as specified.
JSON array Should have exactly {n} rows and 4 columns
""")
# Number of patients: one generated name/ID is needed per dataset row
n_patients = len(patient_df)
# Count the patients of each gender
female_count = patient_df[patient_df['Gender'] == 'Female'].shape[0]
male_count = patient_df[patient_df['Gender'] == 'Male'].shape[0]
# Female-to-male ratio, passed to the prompt as a hint; guarded against a
# dataset without male patients (ZeroDivisionError)
ratio = female_count / male_count if male_count else float('inf')
# Gender of every patient, in row order
genders = patient_df['Gender'].tolist()
# === Step 4: Generate the Prompt ===
# Fill the template with the patient count, ratio, genders and format rules.
formatted_prompt = prompt_template.format(
    n=n_patients,
    ratio=ratio,
    genders=', '.join(genders),
    format_instructions=format_instructions,
)
# Invoke a smaller Llama model for speed. This rebinds the global `llm`, so any
# later code that uses `llm` talks to the small model.
model_3_2_small = 'llama-3.1-8b-instant'  # if you need speed
llm = ChatGroq(
    model=model_3_2_small,
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)
output = llm.invoke(formatted_prompt, timeout=1000)
display(Markdown(output.content))
# Parse the JSON in the reply into Python objects and tabulate them
output_parser = JsonOutputParser()
json_output = output_parser.invoke(output)
json_output
generated_patients = pd.DataFrame(json_output)
generated_patients.head(5)
generated_patients.shape
# Collect generated names/IDs until there is at least one unique row per patient.
# A higher temperature (if supported) makes repeated calls produce different names.
llm.temperature = 0.9
all_patients_name_id = pd.DataFrame()
output_parser = JsonOutputParser()
# Upper bound on LLM calls, so the loop cannot spin forever if the model keeps
# returning duplicates or too few rows.
max_generation_rounds = 50
generation_round = 0
while all_patients_name_id.shape[0] < n_patients:
    if generation_round >= max_generation_rounds:
        raise RuntimeError(
            f"Only {all_patients_name_id.shape[0]} unique patients generated after "
            f"{max_generation_rounds} LLM calls; {n_patients} are needed."
        )
    generation_round += 1
    output = llm.invoke(formatted_prompt)
    json_output = output_parser.invoke(output)
    generated_patients = pd.DataFrame(json_output)
    all_patients_name_id = pd.concat([generated_patients, all_patients_name_id], axis=0)
    print(f"len all_patients_name_id: {len(all_patients_name_id)}")
    all_patients_name_id = all_patients_name_id.drop_duplicates()
    print(f"len all_patients_name_id after droping duplicates: {len(all_patients_name_id)}")
# The schema field "G_Gender" becomes the regular 'Gender' column
all_patients_name_id.rename(columns={"G_Gender": "Gender"}, inplace=True)
all_patients_name_id.head(10)
# Gender distribution of the real patients vs. the generated names
gender_counts = patient_df['Gender'].value_counts()
gender_counts
all_patients_name_id['Gender'].value_counts()
# Step 1: take as many unique generated males/females as there are male/female patients
unique_males = all_patients_name_id[all_patients_name_id['Gender'] == 'Male'].drop_duplicates().head(gender_counts['Male'])
unique_females = all_patients_name_id[all_patients_name_id['Gender'] == 'Female'].drop_duplicates().head(gender_counts['Female'])
patient_male = patient_df[patient_df['Gender'] == 'Male'].reset_index(drop=True)
patient_female = patient_df[patient_df['Gender'] == 'Female'].reset_index(drop=True)
# A column-wise concat would silently leave patients without a name/ID (NaN)
# when the LLM produced too few names of a gender, so say so explicitly.
for _gender_label, _names, _patients in (
    ("Male", unique_males, patient_male),
    ("Female", unique_females, patient_female),
):
    if len(_names) < len(_patients):
        print(
            f"Warning: only {len(_names)} generated {_gender_label} names for "
            f"{len(_patients)} {_gender_label} patients; the rest get NaN names."
        )
# Step 2: attach names/IDs side by side to the patients of the same gender
updated_male_patients = pd.concat(
    [patient_male, unique_males.iloc[:patient_male.shape[0]].reset_index(drop=True)],
    axis=1,
)
updated_female_patients = pd.concat(
    [patient_female, unique_females.iloc[:patient_female.shape[0]].reset_index(drop=True)],
    axis=1,
)
# Step 3: stack both genders; ignore_index gives every row a unique index label
# (both parts start at 0, which previously produced duplicate labels)
updated_patient_df = pd.concat([updated_male_patients, updated_female_patients], axis=0, ignore_index=True)
updated_patient_df.shape[0]
updated_patient_df
# Both sides carried a 'Gender' column; keep only the first (the dataset's own)
updated_patient_df = updated_patient_df.loc[:, ~updated_patient_df.columns.duplicated()]
updated_patient_df
updated_patient_df['Gender'].value_counts()
# #### 1.2.1.1 Select a Patient
# Take the first patient who is female, aged 20-29, has difficulty breathing and
# a positive outcome, so that the X-Ray agent can be exercised later.
mask = (
    (updated_patient_df['Gender'] == 'Female')
    & (updated_patient_df["Age"].between(20, 29))
    & (updated_patient_df['Difficulty Breathing'] == 'Yes')
    & (updated_patient_df['Outcome Variable'] == 'Positive')
)
selected_patients = updated_patient_df[mask].reset_index(drop=True)
selected_patients.head()
# Fail with a clear message instead of an IndexError from .iloc[0]
if selected_patients.empty:
    raise ValueError(
        "No patient matches the selection criteria "
        "(female, age 20-29, difficulty breathing, positive outcome)."
    )
selected_patient = selected_patients.iloc[0]
selected_patient
# # Step 2: Create an Identity Photo for the Front Desk Agent
# ## 2.1 Build the Vision Model for Gender Classification (Image Classification Task)
from io import BytesIO

from transformers import AutoModelForImageClassification, AutoProcessor
from PIL import Image
import requests
import matplotlib.pyplot as plt

# Load the gender classifier once (it was previously loaded twice, plus an unused pipeline).
model_name = "rizvandwiki/gender-classification"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# Fetch a synthetic face. The timeout keeps the cell from hanging, and
# raise_for_status() surfaces HTTP errors instead of passing an error page to PIL.
image_url = "https://thispersondoesnotexist.com"
image_response = requests.get(image_url, timeout=30)
image_response.raise_for_status()
image = Image.open(BytesIO(image_response.content))

# Prepare the image and run inference.
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# Logits are the raw, unnormalized scores of the model's final layer (one per
# class); the class with the largest logit is the prediction.
logits = outputs.logits
predicted_class = logits.argmax(-1).item()
# Map the predicted class index to its label
classes = model.config.id2label
gender_label = classes[predicted_class]
print(f"Predicted Gender: {gender_label}")

# Show the image with the prediction
plt.imshow(image)
plt.axis('off')  # hide axes
plt.title(f"Predicted Gender: {gender_label}")
plt.show()
# ## 2.2 Build the Vision Model for Age Classification (Image Classification Task)
# Age-range classifier applied to the same face image as the gender model.
age_model_name = "nateraw/vit-age-classifier"
age_processor = AutoProcessor.from_pretrained(age_model_name)
age_model = AutoModelForImageClassification.from_pretrained(age_model_name)

# Predict the age class: the highest-scoring logit, mapped through id2label.
age_inputs = age_processor(images=image, return_tensors="pt")
age_outputs = age_model(**age_inputs)
age_logits = age_outputs.logits
age_prediction = age_logits.argmax(-1).item()
age_label = age_model.config.id2label[age_prediction]
age_label

# Show the face with both predictions in the title.
plt.imshow(image)
plt.axis('off')
plt.title(f"Predicted Gender: {gender_label}, Predicted Age: {age_label}")
plt.show()
# # Step 3: Start Building Multi-Agents
#
# Define each AI agent. We'll define agents for:
#
# * Administration Front Desk
# * Physician for General Health Examination + Blood Laboratory
# * X-Ray Image Department
# ## 3.1 Hospital Front Desk Agent
#
# **--IMPORTANT NOTE--**
# 1. Save one photo from https://thispersondoesnotexist.com/ as female.jpg.
#    The front-desk graph opens it as "female.jpg", relative to the current
#    working directory. In Google Colab you can instead keep it under
#    "/content/sample_data/" and point image_Path there.
# ---
# 2. Save one image from the X-ray dataset as x-ray-chest.jpg in the same place.
#    The dataset is loaded as follows:
patient_x_ray_path = "keremberke/chest-xray-classification"
x_ray_ds = load_dataset(patient_x_ray_path, name="full")
| from typing import List, Tuple, Dict, Any, Sequence, Annotated, Literal | |
| from typing_extensions import TypedDict | |
| from langchain_core.messages import BaseMessage | |
| import operator | |
| import functools | |
| from langchain_core.messages import HumanMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.graph import END, START, StateGraph, MessagesState | |
| from langgraph.prebuilt import ToolNode, create_react_agent | |
| from langchain_core.tools import tool | |
| from transformers import AutoModelForImageClassification, AutoProcessor | |
| from PIL import Image | |
| from pydantic import BaseModel | |
| from langchain_core.output_parsers import StrOutputParser, JsonOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
# `Annotated` attaches extra metadata (e.g. a reducer) to a type hint;
# `Literal` restricts a value to the exact constants listed.
#----------------- Build Functions that Agents use ----------------------


@functools.lru_cache(maxsize=4)
def _load_image_classifier(model_name: str):
    """Load a Hugging Face image classifier and its processor once per model name."""
    return (
        AutoModelForImageClassification.from_pretrained(model_name),
        AutoProcessor.from_pretrained(model_name),
    )


def _predict_image_label(model_name: str, image) -> str:
    """Return the id2label name of the highest-scoring class of `model_name` for `image`."""
    import torch

    model, processor = _load_image_classifier(model_name)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits
    return model.config.id2label[logits.argmax(-1).item()]


def _parse_age_range(label: str):
    """Convert an age-class label into inclusive (min, max) bounds.

    "20-29" -> (20, 29). An open-ended label with one number and a word such
    as "more"/"over"/"above" or a "+" (e.g. "more than 70") -> (70, inf).
    A lone number -> (n, n).

    Raises:
        ValueError: if the label contains no usable number.
    """
    import re

    numbers = [int(number) for number in re.findall(r"\d+", label)]
    if len(numbers) == 2:
        return numbers[0], numbers[1]
    if len(numbers) == 1:
        lowered = label.lower()
        if any(marker in lowered for marker in ("more", "over", "above", "+")):
            return numbers[0], float("inf")
        return numbers[0], numbers[0]
    raise ValueError(f"Unrecognised age range label: {label!r}")


def patient_verification_tool(image_Path, selected_patient_data, updated_patient_df) -> str:
    """Verify a patient's identity from a photo against the patient table.

    Predicts the gender (rizvandwiki/gender-classification) and the age range
    (nateraw/vit-age-classifier) from the photo. It then looks for a row of
    `updated_patient_df` with the patient's First_Name, Last_Name and
    Patient_ID whose Gender equals the predicted gender (case-insensitive)
    and whose Age lies inside the predicted range.

    Args:
        image_Path: Path to the patient's photo.
        selected_patient_data: Mapping (e.g. ``Series.to_dict()``) with the keys
            "First_Name", "Last_Name" and "Patient_ID".
        updated_patient_df: Patient table with those columns plus "Gender" and "Age".

    Returns:
        A success message naming the patient, or
        "ID not verified. Patient cannot proceed." when no row matches.
    """
    # Imported locally: other cells rebind the global name `Image` to
    # IPython.display.Image.
    from PIL import Image

    print(image_Path)
    # Open the photo once and close the file promptly. Normalise to three
    # channels so grayscale/RGBA files work with the processors.
    with Image.open(image_Path) as photo:
        image = photo.convert("RGB")

    predicted_gender = _predict_image_label("rizvandwiki/gender-classification", image)
    print(f"Predicted Gender Of Patient is : {predicted_gender}")
    predicted_age_range = _predict_image_label("nateraw/vit-age-classifier", image)
    print(f"predicted Age Class: {predicted_age_range}")
    age_min, age_max = _parse_age_range(predicted_age_range)
    print(f"age_min: {age_min}, age_max: {age_max}")

    # Use the patient given by the caller. The previous version read the global
    # `selected_patient` and ignored this argument.
    first_name = selected_patient_data["First_Name"]
    last_name = selected_patient_data["Last_Name"]
    patient_id = selected_patient_data["Patient_ID"]
    matching_row = updated_patient_df[
        (updated_patient_df["First_Name"] == first_name)
        & (updated_patient_df["Last_Name"] == last_name)
        & (updated_patient_df["Patient_ID"] == patient_id)
        # Lower-case both sides so label casing cannot break the match
        & (updated_patient_df["Gender"].str.lower() == predicted_gender.lower())
        & (updated_patient_df["Age"].between(age_min, age_max))
    ]
    print(f"matching_row {matching_row} ")
    if matching_row.empty:
        return "ID not verified. Patient cannot proceed."
    return f'''Verification successful.
Patient is : {first_name} {last_name}
with ID {patient_id}
which is {predicted_gender} in age range of {predicted_age_range} can proceed to the physician.'''
#------------------- Define Agents -----------------------------
class AgentState(TypedDict):
    """State carried through the front-desk graph.

    Keys:
        initial_prompt: instruction text supplied when the graph is invoked.
        messages: message log; `operator.add` concatenates lists returned by nodes.
        patient_verification: verification message produced by the front desk agent.
    """
    initial_prompt: str
    messages: Annotated[List[BaseMessage], operator.add]
    patient_verification: str
def front_desk_agent(state, image_Path, selected_patient_data, updated_patient_df):
    """Graph node: verify the patient's identity at the hospital front desk.

    Only `state` comes from the graph. The other arguments are expected to be
    pre-bound (e.g. with functools.partial) when the node is registered.

    Args:
        state: Current graph state (not read by this node).
        image_Path: Path to the patient's photo.
        selected_patient_data: Mapping with the patient's First_Name, Last_Name and Patient_ID.
        updated_patient_df: Patient table to verify against.

    Returns:
        A state update ``{"patient_verification": <verification message>}``.
    """
    patient_verification = patient_verification_tool(image_Path, selected_patient_data, updated_patient_df)
    print(patient_verification)
    return {"patient_verification": patient_verification}
#-----------------------------------------------------------------
#            Build the LangGraph for Hospital Front Desk
#-----------------------------------------------------------------
# Patient photo (relative to the working directory) and the patient to verify
image_Path = "female.jpg"
selected_patient_data = selected_patient.to_dict()
# Bind the patient-specific arguments so the node only receives the graph state
front_desk_agent_node = functools.partial(
    front_desk_agent,
    image_Path=image_Path,
    selected_patient_data=selected_patient_data,
    updated_patient_df=updated_patient_df,
)
# Single-node graph: START -> front_desk_agent -> END
FrontDeskGraph = StateGraph(AgentState)
FrontDeskGraph.add_node("front_desk_agent", front_desk_agent_node)
FrontDeskGraph.add_edge(START, "front_desk_agent")
FrontDeskGraph.add_edge("front_desk_agent", END)
# Compiled without a checkpointer, so no state is persisted between runs
FrontDeskWorkflow = FrontDeskGraph.compile()
# Note: this import rebinds the global name `Image` to IPython's Image class
from IPython.display import Markdown, display, Image
display(Image(FrontDeskWorkflow.get_graph(xray=True).draw_mermaid_png()))
initial_prompt = "You are Front Desk Administrator in an Hospital in the Netherlands. Start Verification of the following Patient:"
# Run the workflow
inputs = {"initial_prompt": initial_prompt}
output = FrontDeskWorkflow.invoke(inputs)
output
display(Markdown(output['patient_verification']))
# ## 3.2 Physician Agent
def question_patient_symptoms(selected_patient_data) -> str:
    """Simulate the physician's symptom interview for one patient.

    For each symptom the physician asks a question. The patient answers "Yes"
    only when the patient data holds exactly "Yes" for that column; a missing
    column counts as "No". A closing summary restates the patient's profile
    and symptoms.

    Args:
        selected_patient_data: Mapping with symptom columns ("Cough", "Fatigue",
            "Difficulty Breathing") and the profile keys "First_Name",
            "Last_Name", "Patient_ID", "Gender" and "Age".

    Returns:
        The conversation transcript as a single newline-joined string.
    """
    # Keys must match the dataset's column names exactly; otherwise `.get`
    # silently falls back to "No". (The key used to be "\nDifficulty Breathing",
    # so that symptom was always reported as absent.)
    symptoms_questions = {
        "Cough": "\nAre you coughing?\n",
        "Fatigue": "\nDo you feel fatigue?\n",
        "Difficulty Breathing": "\nDo you have difficulty breathing?\n",
    }

    conversation = []
    for symptom, question in symptoms_questions.items():
        conversation.append(f"\nPhysician: {question}")
        answer = "Yes" if selected_patient_data.get(symptom, "No") == "Yes" else "No"
        conversation.append(f"\nPatient: {answer}")

    first_name = selected_patient_data.get("First_Name", "")
    last_name = selected_patient_data.get("Last_Name", "")
    patient_id = selected_patient_data.get("Patient_ID", "")
    gender = selected_patient_data.get("Gender", "")
    age = selected_patient_data.get("Age", "")
    profile = f"\nYou are {first_name} {last_name}, a {age} years old {gender} with Patient ID: {patient_id}."
    # Leading space separates the profile sentence from the summary sentence
    summary = profile + " I gathered that you are experiencing the following: "
    summaries = []
    for symptom in symptoms_questions:
        if selected_patient_data.get(symptom, "No") == "Yes":
            summaries.append(f"you are experiencing {symptom.lower()}")
        else:
            summaries.append(f"\nI am glad you are not experiencing {symptom.lower()}")
    summary += "; ".join(summaries) + "."
    conversation.append(f"\nPhysician: {summary}")
    return "\n".join(conversation)
def perform_examination(selected_patient_data) -> str:
    """Report the examination readings recorded for the patient.

    Reads ``"Fever"``, ``"Blood Pressure"`` and ``"Cholesterol Level"`` from
    ``selected_patient_data``, in that order. A missing key is shown as
    ``"Unknown"``.
    """
    measured = ("Fever", "Blood Pressure", "Cholesterol Level")
    parts = [f"{name} - {selected_patient_data.get(name, 'Unknown')}" for name in measured]
    return "Examination Results: " + ", ".join(parts)
def diagnose_patient(selected_patient_data) -> tuple[str, str]:
    """Derive a diagnosis report and follow-up action from the patient record.

    Args:
        selected_patient_data: Mapping with the patient's record. Reads
            ``"Disease"`` and ``"Outcome Variable"``.

    Returns:
        A ``(report, diagnosis)`` tuple. ``diagnosis`` is
        ``'Make X-Ray from Chest'`` when the outcome is ``'Positive'`` and
        ``'Rest to Recover'`` otherwise. ``report`` is a sentence combining
        the disease, the outcome and that diagnosis.
    """
    disease = selected_patient_data.get("Disease", "Unknown Disease")
    outcome = selected_patient_data.get("Outcome Variable", "Unknown Outcome")
    # Only a positive test result calls for a chest X-ray.
    diagnosis = 'Make X-Ray from Chest' if outcome == 'Positive' else 'Rest to Recover'
    report = f"Diagnosis: {disease}. Test Result: {outcome}. Final Diagnosis: {diagnosis}"
    return report, diagnosis
# Graph state for the stand-alone physician workflow, declared with the
# functional TypedDict syntax (same keys, in the same order as a class body).
# ``messages`` carries an ``operator.add`` annotation so the graph appends new
# messages to the current list instead of replacing it.
AgentState = TypedDict(
    "AgentState",
    {
        "initial_prompt": str,
        "messages": Annotated[List[BaseMessage], operator.add],
        "question_patient_symptoms": str,
        "examination_patient": str,
        "diagnosis_patient": str,
        "diagnosis": str,
    },
)
def physician_agent(state, selected_patient_data):
    """LangGraph node: run the physician's consultation for one patient.

    ``state`` is part of the node signature but is not read. All inputs come
    from ``selected_patient_data``. Returns state updates holding the
    questioning transcript, the examination results, the diagnosis report and
    the short diagnosis label.
    """
    updates = {}
    updates["question_patient_symptoms"] = question_patient_symptoms(selected_patient_data)
    updates["examination_patient"] = perform_examination(selected_patient_data)
    updates["diagnosis_patient"], updates["diagnosis"] = diagnose_patient(selected_patient_data)
    return updates
# Bind the patient record so LangGraph can call the node with the graph state
# alone.
selected_patient_data = selected_patient.to_dict()
physician_agent_node = functools.partial(physician_agent,
                                         selected_patient_data=selected_patient_data)
# Set up the LangGraph state graph: START -> physician_agent -> END
PhysicianGraph = StateGraph(AgentState)
# Define nodes for workflow
PhysicianGraph.add_node("physician_agent", physician_agent_node)
PhysicianGraph.add_edge(START, "physician_agent")
PhysicianGraph.add_edge("physician_agent", END)
# Compile the graph. No checkpointer is passed to compile(), so state is not
# persisted between runs.
PhysicianWorkflow = PhysicianGraph.compile()
display(Image(PhysicianWorkflow.get_graph(xray=True).draw_mermaid_png()))
# The trailing backslash continues the string literal onto the next line.
initial_prompt = "You are a Very Experience Doctor in an Hospital in the Netherlands. Start a conversation with the patient and determine \
symptoms and give diagnosis"
# Run the workflow
inputs = {"initial_prompt" : initial_prompt
          }
output = PhysicianWorkflow.invoke(inputs)
# Bare expression (notebook leftover); no effect in a script.
output
display(Markdown(output['question_patient_symptoms']))
display(Markdown(output['examination_patient']))
display(Markdown(output['diagnosis_patient']))
# ## 3.3 Radiologist
def examine_X_ray_image(patient_x_ray_path) -> str:
    """Classify a chest X-ray as NORMAL or PNEUMONIA with a ViT model.

    ``patient_x_ray_path`` is passed to ``load_dataset`` with config
    ``"full"``. One image is drawn at random from its ``train`` split, so it
    is not tied to a specific patient. That image is classified with the
    Hugging Face model ``lxyuan/vit-xray-pneumonia-classification``
    (https://huggingface.co/lxyuan/vit-xray-pneumonia-classification).

    Returns:
        A sentence with the top-scoring label and its score as a rounded
        percentage.
    """
    train_split = load_dataset(patient_x_ray_path, name="full")['train']
    pick = random.randint(0, train_split.shape[0] - 1)
    image = train_split[pick]['image']

    classify = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
    predictions = classify(image)
    best = max(predictions, key=lambda prediction: prediction['score'])
    label = best['label']
    percent = best['score'] * 100

    # Reads as "Patient is NORMAL" / "Patient has PNEUMONIA".
    verb = "has" if label != "NORMAL" else "is"
    return f"Patient {verb} {label} with Probability of ca. {percent:.0f}%"
# Graph state for the stand-alone radiologist workflow, declared with the
# functional TypedDict syntax (same keys, in the same order as a class body).
# ``messages`` carries an ``operator.add`` annotation so the graph appends new
# messages to the current list instead of replacing it.
AgentState = TypedDict(
    "AgentState",
    {
        "initial_prompt": str,
        "messages": Annotated[List[BaseMessage], operator.add],
        "pneumonia_detection": str,
    },
)
def radiologist_agent(state, patient_x_ray_path):
    """LangGraph node: run the chest X-ray check.

    ``state`` is part of the node signature but is not read. Returns the
    ``pneumonia_detection`` state update produced by ``examine_X_ray_image``.
    """
    finding = examine_X_ray_image(patient_x_ray_path)
    return dict(pneumonia_detection=finding)
# Hugging Face dataset name. examine_X_ray_image loads it with load_dataset
# and classifies one randomly chosen image from its train split.
patient_x_ray_path = "keremberke/chest-xray-classification"
# Bind the dataset name so LangGraph can call the node with the graph state
# alone.
radiologist_agent_node = functools.partial(radiologist_agent,
                                           patient_x_ray_path=patient_x_ray_path)
# Set up the LangGraph state graph: START -> radiologist_agent -> END
RadiologistGraph = StateGraph(AgentState)
# Define nodes for workflow
RadiologistGraph.add_node("radiologist_agent", radiologist_agent_node)
RadiologistGraph.add_edge(START, "radiologist_agent")
RadiologistGraph.add_edge("radiologist_agent", END)
# Compile the graph. No checkpointer is passed to compile(), so state is not
# persisted between runs.
RadiologistWorkflow = RadiologistGraph.compile()
display(Image(RadiologistWorkflow.get_graph(xray=True).draw_mermaid_png()))
initial_prompt = "You are a Very Experienced Radiologist in an Hospital in the Netherlands. Diagnose if the patient has pneumonia"
# Run the workflow
inputs = {"initial_prompt" : initial_prompt
          }
output = RadiologistWorkflow.invoke(inputs)
# Bare expression (notebook leftover); no effect in a script.
output
display(Markdown(output['pneumonia_detection']))
# # Step 4: Putting All Agents in One Graph
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
selected_patient_data = selected_patient.to_dict()
image_Path = "female.jpg"
# NOTE(review): this global is not read anywhere in this section
# (examine_X_ray_image assigns its own local of the same name). Confirm it is
# still needed.
patient_x_ray_image = patient_x_ray
def _classify_image(model_name, image):
    """Return the top label predicted by Hugging Face image classifier ``model_name`` for ``image``."""
    model = AutoModelForImageClassification.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)
    inputs = processor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    predicted_class = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class]


def _parse_age_range(age_label):
    """Turn an age-class label such as ``"20-29"`` or ``"more than 70"`` into inclusive ``(min, max)`` bounds."""
    if "-" in age_label:
        low, high = age_label.split("-", 1)
        return int(low.strip()), int(high.strip())
    # Open-ended class (e.g. "more than 70"): the number is the lower bound
    # and there is no upper bound.
    digits = "".join(ch for ch in age_label if ch.isdigit())
    if not digits:
        raise ValueError(f"Unrecognised age class label: {age_label!r}")
    return int(digits), float("inf")


def patient_verification_tool(image_Path, selected_patient_data, updated_patient_df) -> str:
    """Verify a patient's identity from a photo against the patient table.

    Predicts the patient's gender (``rizvandwiki/gender-classification``) and
    age class (``nateraw/vit-age-classifier``) from the image at
    ``image_Path``. It then looks for a row in ``updated_patient_df`` that
    meets both conditions:
    - first name, last name and patient ID match ``selected_patient_data``;
    - gender and age agree with the predictions.

    Args:
        image_Path: File path of the patient photo.
        selected_patient_data: Mapping with at least ``"First_Name"``,
            ``"Last_Name"`` and ``"Patient_ID"``.
        updated_patient_df: DataFrame with ``First_Name``, ``Last_Name``,
            ``Patient_ID``, ``Gender`` and ``Age`` columns.

    Returns:
        A success message naming the patient, or
        ``"ID not verified. Patient cannot proceed."`` when no row matches.

    Raises:
        ValueError: If the age classifier returns a label containing no number.
    """
    from PIL import Image

    print(image_Path)
    # Open the photo once and close it when both models are done with it.
    with Image.open(image_Path) as image:
        predicted_gender = _classify_image("rizvandwiki/gender-classification", image)
        print(f"Predicted Gender Of Patient is : {predicted_gender}")
        predicted_age_range = _classify_image("nateraw/vit-age-classifier", image)
        print(f"predicted Age Class: {predicted_age_range}")

    age_min, age_max = _parse_age_range(predicted_age_range)
    print(f"age_min: {age_min}, age_max: {age_max}")

    # Use the record passed in, not a module-level global.
    first_name = selected_patient_data["First_Name"]
    last_name = selected_patient_data["Last_Name"]
    patient_id = selected_patient_data["Patient_ID"]

    # Verify against the DataFrame (case-insensitive on both gender sides).
    matching_row = updated_patient_df[
        (updated_patient_df["First_Name"] == first_name)
        & (updated_patient_df["Last_Name"] == last_name)
        & (updated_patient_df["Patient_ID"] == patient_id)
        & (updated_patient_df["Gender"].str.lower() == predicted_gender.lower())
        & (updated_patient_df["Age"].between(age_min, age_max))
    ]
    print(f"matching_row {matching_row} ")

    if matching_row.empty:
        return "ID not verified. Patient cannot proceed."
    return f'''Verification successful.
Patient is : {first_name} {last_name}
with ID {patient_id}
which is {predicted_gender} in age range of {predicted_age_range} can proceed to the physician.'''
def question_patient_symptoms(selected_patient_data) -> str:
    """Simulate a physician asking about symptoms and summarize the answers.

    For each tracked symptom the physician asks a question, and the patient's
    answer is read from ``selected_patient_data``. Any value other than
    ``"Yes"`` (including a missing key) is treated as ``"No"``. The
    conversation ends with a summary that restates the patient profile and
    the status of every symptom.

    Args:
        selected_patient_data: Mapping with the patient's record. The keys
            ``"Cough"``, ``"Fatigue"``, ``"Difficulty Breathing"``,
            ``"First_Name"``, ``"Last_Name"``, ``"Patient_ID"``, ``"Gender"``
            and ``"Age"`` are read when present.

    Returns:
        The conversation transcript as a single newline-joined string.
    """
    # Keys must match the patient-record column names exactly. A stray
    # character in a key makes ``.get`` silently fall back to "No".
    symptoms_questions = {
        "Cough": "\nAre you coughing?\n",
        "Fatigue": "\nDo you feel fatigue?\n",
        "Difficulty Breathing": "\nDo you have difficulty breathing?\n",
    }

    def has_symptom(symptom):
        return selected_patient_data.get(symptom, "No") == "Yes"

    conversation = []
    for symptom, question in symptoms_questions.items():
        conversation.append(f"\nPhysician: {question}")
        answer = "Yes" if has_symptom(symptom) else "No"
        conversation.append(f"\nPatient: {answer}")

    first_name = selected_patient_data.get("First_Name", "")
    last_name = selected_patient_data.get("Last_Name", "")
    patient_id = selected_patient_data.get("Patient_ID", "")
    gender = selected_patient_data.get("Gender", "")
    age = selected_patient_data.get("Age", "")
    profile = f"\nYou are {first_name} {last_name}, a {age} years old {gender} with Patient ID: {patient_id}."

    summaries = [
        f"you are experiencing {symptom.lower()}"
        if has_symptom(symptom)
        else f"\nI am glad you are not experiencing {symptom.lower()}"
        for symptom in symptoms_questions
    ]
    # Leading space separates the summary from the profile's closing period.
    summary = profile + " I gathered that you are experiencing the following: " + "; ".join(summaries) + "."
    conversation.append(f"\nPhysician: {summary}")
    return "\n".join(conversation)
def perform_examination(selected_patient_data) -> str:
    """Report the examination readings recorded for the patient.

    Reads ``"Fever"``, ``"Blood Pressure"`` and ``"Cholesterol Level"`` from
    ``selected_patient_data``, in that order. A missing key is shown as
    ``"Unknown"``.
    """
    measured = ("Fever", "Blood Pressure", "Cholesterol Level")
    parts = [f"{name} - {selected_patient_data.get(name, 'Unknown')}" for name in measured]
    return "Examination Results: " + ", ".join(parts)
def diagnose_patient(selected_patient_data) -> tuple[str, str]:
    """Derive a diagnosis report and follow-up action from the patient record.

    Args:
        selected_patient_data: Mapping with the patient's record. Reads
            ``"Disease"`` and ``"Outcome Variable"``.

    Returns:
        A ``(report, diagnosis)`` tuple. ``diagnosis`` is
        ``'Make X-Ray from Chest'`` when the outcome is ``'Positive'`` and
        ``'Rest to Recover'`` otherwise. ``report`` is a sentence combining
        the disease, the outcome and that diagnosis.
    """
    disease = selected_patient_data.get("Disease", "Unknown Disease")
    outcome = selected_patient_data.get("Outcome Variable", "Unknown Outcome")
    # Only a positive test result calls for a chest X-ray.
    diagnosis = 'Make X-Ray from Chest' if outcome == 'Positive' else 'Rest to Recover'
    report = f"Diagnosis: {disease}. Test Result: {outcome}. Final Diagnosis: {diagnosis}"
    return report, diagnosis
def examine_X_ray_image(patient_x_ray_path) -> str:
    """Classify a chest X-ray as NORMAL or PNEUMONIA with a ViT model.

    ``patient_x_ray_path`` is passed to ``load_dataset`` with config
    ``"full"``. One image is drawn at random from its ``train`` split, so it
    is not tied to a specific patient. That image is classified with the
    Hugging Face model ``lxyuan/vit-xray-pneumonia-classification``
    (https://huggingface.co/lxyuan/vit-xray-pneumonia-classification).

    Returns:
        A sentence with the top-scoring label and its score as a rounded
        percentage.
    """
    train_split = load_dataset(patient_x_ray_path, name="full")['train']
    pick = random.randint(0, train_split.shape[0] - 1)
    image = train_split[pick]['image']

    classify = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
    predictions = classify(image)
    best = max(predictions, key=lambda prediction: prediction['score'])
    label = best['label']
    percent = best['score'] * 100

    # Reads as "Patient is NORMAL" / "Patient has PNEUMONIA".
    verb = "has" if label != "NORMAL" else "is"
    return f"Patient {verb} {label} with Probability of ca. {percent:.0f}%"
# Shared state for the combined hospital graph; every node reads it and returns
# partial updates to it. Declared with the functional TypedDict syntax (same
# keys, in the same order as a class body). ``messages`` carries an
# ``operator.add`` annotation so the graph appends new messages to the current
# list instead of replacing it.
AgentState = TypedDict(
    "AgentState",
    {
        "initial_prompt": str,
        "messages": Annotated[List[BaseMessage], operator.add],
        "patient_verification": str,
        "question_patient_symptoms": str,
        "examination_patient": str,
        "diagnosis_patient": str,
        "diagnosis": str,
        "pneumonia_detection": str,
    },
)
def front_desk_agent(state, image_Path, selected_patient_data, updated_patient_df):
    """LangGraph node: verify the patient's identity at the front desk.

    ``state`` is part of the node signature but is not read. The other
    arguments are bound with ``functools.partial`` and forwarded to
    ``patient_verification_tool``. The verification message is printed and
    returned as the ``patient_verification`` state update.
    """
    patient_verification = patient_verification_tool(image_Path, selected_patient_data, updated_patient_df)
    print(patient_verification)
    return {"patient_verification": patient_verification}
def physician_agent(state, selected_patient_data):
    """LangGraph node: run the physician's consultation for one patient.

    ``state`` is part of the node signature but is not read. Returns state
    updates holding the questioning transcript, the examination results, the
    diagnosis report and the short diagnosis label. The graph's conditional
    edge uses that label to decide whether the radiologist runs.

    The X-ray classification is deliberately not run here. It belongs to the
    radiologist node; running it here would download the dataset and load the
    model on every consultation only to discard the result.
    """
    question_patient = question_patient_symptoms(selected_patient_data)
    examination = perform_examination(selected_patient_data)
    diagnosis_report, diagnosis = diagnose_patient(selected_patient_data)
    return {"question_patient_symptoms": question_patient,
            "examination_patient": examination,
            "diagnosis_patient": diagnosis_report,
            "diagnosis": diagnosis}
def radiologist_agent(state, patient_x_ray_path):
    """LangGraph node: run the chest X-ray check.

    ``state`` is part of the node signature but is not read. Returns the
    ``pneumonia_detection`` state update produced by ``examine_X_ray_image``.
    """
    finding = examine_X_ray_image(patient_x_ray_path)
    return dict(pneumonia_detection=finding)
def decide_on_radiologist(state):
    """Route after the physician: ``'radiologist'`` when a chest X-ray was ordered, else ``'end'``.

    The returned value must be a key of the path map passed to
    ``add_conditional_edges`` (``'radiologist'`` / ``'end'``). An empty
    string would match no branch.
    """
    return 'radiologist' if state["diagnosis"] == 'Make X-Ray from Chest' else 'end'
# Patient inputs bound into the combined hospital graph's nodes.
image_Path = "female.jpg"
selected_patient_data = selected_patient.to_dict()
# Bare expression (notebook leftover): it only evaluates the name and has no
# effect in a script beyond raising NameError if the name is undefined.
updated_patient_df
# Hugging Face dataset name that examine_X_ray_image passes to load_dataset.
patient_x_ray_path = "keremberke/chest-xray-classification"
# Bind each agent's non-state arguments so LangGraph can call the nodes with
# the graph state alone.
front_desk_agent_node = functools.partial(front_desk_agent,
                                          image_Path = image_Path,
                                          selected_patient_data=selected_patient_data,
                                          updated_patient_df =updated_patient_df)
physician_agent_node = functools.partial(physician_agent,
                                         selected_patient_data=selected_patient_data)
radiologist_agent_node = functools.partial(radiologist_agent,
                                           patient_x_ray_path=patient_x_ray_path)
def decide_on_radiologist(state):
    """Pick the next hop after the physician node.

    Returns ``'radiologist'`` when the diagnosis orders a chest X-ray and
    ``'end'`` otherwise. Both values are keys of the path map given to
    ``add_conditional_edges``.
    """
    needs_xray = state["diagnosis"] == 'Make X-Ray from Chest'
    if not needs_xray:
        return 'end'
    return 'radiologist'
# Set up the combined hospital graph:
#   START -> front_desk_agent -> physician_agent -> (radiologist_agent | END)
HospitalGraph = StateGraph(AgentState)
# Define nodes for workflow
HospitalGraph.add_node("front_desk_agent", front_desk_agent_node)
HospitalGraph.add_node("physician_agent", physician_agent_node)
HospitalGraph.add_node("radiologist_agent", radiologist_agent_node)
HospitalGraph.add_edge(START, "front_desk_agent")
HospitalGraph.add_edge("front_desk_agent", "physician_agent")
# decide_on_radiologist returns one of the keys of this path map.
HospitalGraph.add_conditional_edges("physician_agent",
                                    decide_on_radiologist,
                                    {'radiologist': "radiologist_agent",
                                     'end': END})
# Make the radiologist's exit explicit instead of leaving it as a dead-end node.
HospitalGraph.add_edge("radiologist_agent", END)
# Compile the graph. No checkpointer is passed to compile(), so state is not
# persisted between runs.
HospitalWorkflow = HospitalGraph.compile()
display(Image(HospitalWorkflow.get_graph(xray=True).draw_mermaid_png()))
initial_prompt = "Start with the following Patient"
# Run the workflow
inputs = {"initial_prompt": initial_prompt}
output = HospitalWorkflow.invoke(inputs)
display(Markdown(output['patient_verification']))
display(Markdown(output['question_patient_symptoms']))
display(Markdown(output['examination_patient']))
display(Markdown(output['diagnosis_patient']))
# The radiologist only runs when a chest X-ray was ordered. Otherwise the
# graph ends after the physician, this key is never written, and indexing it
# would raise KeyError at import time.
if 'pneumonia_detection' in output:
    display(Markdown(output['pneumonia_detection']))
# # Step 5: Gradio Dashboard
# ## 5.1 Build the Hospital Dashboard APP
# Chest X-ray picture shown (initially hidden) in the dashboard's X-ray step.
x_ray_image_path = 'x-ray-chest.png'
import gradio as gr

# Patient summary rendered as Markdown at the top of the dashboard; one
# bold-labelled field per paragraph.
info = "\n\n".join(
    f"**{label}:** {selected_patient_data[key]}"
    for label, key in (
        ("First Name", "First_Name"),
        ("Last Name", "Last_Name"),
        ("Patient ID", "Patient_ID"),
    )
)
def verify_age_gender():
    """Gradio callback for the "Verify Age and Gender" button.

    Runs the front-desk workflow and returns two values: the verification
    message for the status textbox, and an update that reveals the physician
    button.
    """
    result = FrontDeskWorkflow.invoke(
        {"initial_prompt": "You are Front Desk Administrator in an Hospital in the Netherlands. Start Verification of the following Patient:"}
    )
    # NOTE(review): the prefix below looks like a mis-encoded emoji (UTF-8
    # bytes shown as Thai characters). Restore the intended character from the
    # original source.
    return 'โ ' + result['patient_verification'], gr.update(visible=True)
def physician_examination():
    """Gradio callback for the "Get Examination at Physician" button.

    Runs the physician workflow and returns two values. The first combines
    the questioning transcript, the examination results and the diagnosis
    report, separated by blank lines, for the examination textbox. The second
    is an update that reveals the X-ray button.
    """
    prompt = ("You are a Very Experience Doctor in an Hospital in the Netherlands. "
              "Start a conversation with the patient and determine symptoms and give diagnosis")
    result = PhysicianWorkflow.invoke({"initial_prompt": prompt})
    # NOTE(review): the icon prefixes below look like mis-encoded emoji (UTF-8
    # bytes shown as Thai characters). Restore them from the original source.
    sections = (
        (" ๐ฉบ", result['question_patient_symptoms']),
        ("๐", result['examination_patient']),
        ("๐ฌ๏ธ", result['diagnosis_patient']),
    )
    output_all = "\n\n".join(f"{icon} {text}" for icon, text in sections)
    return output_all, gr.update(visible=True)
def pneumonia_detection():
    """Gradio callback for the "Go to Radiologist" button.

    Runs the radiologist workflow and returns its pneumonia finding, prefixed
    with a picture icon, for the radiologist textbox.
    """
    initial_prompt = "You are a Very Experienced Radiologist in an Hospital in the Netherlands. Diagnose if the patient has pneumonia"
    output = RadiologistWorkflow.invoke({"initial_prompt": initial_prompt})
    # U+1F5BC FRAME WITH PICTURE + U+FE0F, written as the real character
    # rather than its bytes mis-decoded as Thai text ("๐ผ๏ธ").
    finding = 'From X-Ray Image 🖼️ ' + output['pneumonia_detection']
    return finding
def take_xray_image():
    """Gradio callback for the "Take Chest X-Ray Image" button.

    Returns two independent visibility updates: one reveals the X-ray image,
    the other reveals the radiologist button.
    """
    return tuple(gr.update(visible=True) for _ in range(2))
# Dashboard layout. The left column walks through the steps
# (verify -> physician -> X-ray -> radiologist); the right column shows the
# patient photo. Every step after the first starts hidden and is revealed by
# the previous step's callback.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Patient name and ID summary
            gr.Markdown(info)
            # Step 1: identity check (visible from the start)
            verify_button = gr.Button("Verify Age and Gender")
            # Shows the front-desk verification message
            verification_output = gr.Textbox(label="Verification Status", interactive=False, lines=5, max_lines=None)
            # Step 2: hidden until the verification callback reveals it
            physician_button = gr.Button("Get Examination at Physician", visible=False)
            physician_output = gr.Textbox(label="Examination by Physician Placeholder", interactive=False, lines=35, max_lines=None)
            # Step 3: hidden until the physician callback reveals it
            x_ray_button = gr.Button("Take Chest X-Ray Image", visible=False)
            # Display X-Ray Image (Initially Hidden)
            xray_image_display = gr.Image(value=x_ray_image_path, label="X-Ray Image", visible=False)
            # Step 4: hidden until the X-ray callback reveals it
            radiologist_button = gr.Button("Go to Radiologist", visible=False)
            # Shows the radiologist's pneumonia finding
            radiologist_output = gr.Textbox(label="Radiologist Placeholder", interactive=False, lines=5, max_lines=None)
        with gr.Column(scale=1):
            # Patient photo (the same image_Path bound into the front-desk node)
            gr.Image(value=image_Path, label="Static Image", show_label=True)
    # Wire each button to its callback. Every callback except the radiologist's
    # also returns an update that reveals the next step.
    verify_button.click(fn=verify_age_gender, inputs=None, outputs=[verification_output, physician_button])
    physician_button.click(fn=physician_examination, inputs=None, outputs=[physician_output, x_ray_button])
    x_ray_button.click(fn=take_xray_image, inputs=None, outputs=[xray_image_display, radiologist_button])
    radiologist_button.click(fn=pneumonia_detection, inputs=None, outputs=[radiologist_output])
# ## 5.2 Run the App
# Launch the app. The two commented-out calls are earlier launch variants kept
# for reference; the last line is the one that runs.
#demo.launch(share=True, debug=False)
#demo.launch(share=True, debug=False, allowed_paths=[dataDir], ssr_mode=False)
demo.launch(share=True, debug=False, ssr_mode=False)