Spaces:
Sleeping
Sleeping
import streamlit as st
# NOTE(review): st_audiorec, transcribe, base64, json and pose_estimator are not
# referenced in the code shown here; importing them may still have side effects,
# so verify before removing any of them.
from st_audiorec import st_audiorec
from Modules.Speech2Text.transcribe import transcribe
import base64
from langchain_mistralai import ChatMistralAI
from langchain_core.prompts import ChatPromptTemplate
import pandas as pd
import json
from dotenv import load_dotenv
# Populate os.environ from a .env file. This call is placed before the
# Modules.rag / router / workout_plan / pose_agent imports below, presumably
# because those modules read API keys at import time -- keep this ordering
# unless that is verified otherwise. (Speech2Text.transcribe above is imported
# before it.)
load_dotenv() # load .env api keys
import os
import shutil
from Modules.rag import rag_chain
from Modules.router import router_chain
from Modules.workout_plan import workout_chain
from Modules.PoseEstimation.pose_agent import agent_executor
# Mistral key for the small-talk LLM built below; None when the variable is unset.
mistral_api_key = os.getenv("MISTRAL_API_KEY")
from Modules.PoseEstimation import pose_estimator
from utils import save_uploaded_file, encode_video_H264
# Page layout: must be the first Streamlit call of the script.
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
st.markdown("# Welcome to the AI Bros playground! 🏋️‍♂️🤖🏋️‍♂️")
# Create two columns: coach chat (left) and technique analysis (right).
col1, col2 = st.columns(2)
# Path of the saved upload; filled in by the technique-analysis column.
video_uploaded = None

llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
# Persona prompt used for small talk; {question} receives the user's message.
# Named coach_prompt so it is not confused with the chat input handled later.
coach_prompt = ChatPromptTemplate.from_template(
    template=""" You are a personal coach with an expertise in nutrition and fitness.
You are having a conversation with your client, who is either a beginner or an advanced athlete.
You are motivating and have a very good vibe.
Call your client buddy or sweetheart.
You are a human, not a robot.
Always try to answer the queries concisely.
User: {question}
AI Coach:"""
)
base_chain = coach_prompt | llm
# Set to True by the chat column when a freshly generated workout table must be shown.
display_workout = False
# First column: chat with the AI coach. Each message is sent to router_chain,
# whose answer selects the chain that produces the reply.
with col1:
    st.subheader("Coach chat")
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "file_name" not in st.session_state:
        st.session_state.file_name = None
    # NOTE: every exchange is appended to st.session_state.messages, but past
    # messages are not re-rendered, so only the latest exchange is visible.
    # `user_input` (not `prompt`) avoids shadowing the module-level prompt template.
    if user_input := st.chat_input("What's up?"):
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.chat_message("assistant", avatar="data/AI_Bro.png"):
            # Assumes router_chain returns one of the plain route strings compared below.
            direction = router_chain.invoke({"question": user_input})
            if direction == "fitness_advices":
                with st.spinner("Retrieving relevant data..."):
                    response = rag_chain.invoke(user_input)
            elif direction == "smalltalk":
                with st.spinner("Thinking..."):
                    response = base_chain.invoke({"question": user_input}).content
            elif direction == "movement_analysis":
                agent_input = user_input
                if st.session_state.file_name is not None:
                    # Tell the agent which uploaded video to analyse, separated
                    # from the user's question so the two do not run together.
                    agent_input = f"{user_input}\nThe file name is {st.session_state.file_name}"
                with st.spinner("Analyzing movement..."):
                    response = agent_executor.invoke({"input": agent_input})["output"]
            else:
                # Every other route (including unexpected router output) builds a workout.
                with st.spinner("Creating workout program..."):
                    try:
                        json_output = workout_chain.invoke({"query": user_input})
                        workout_df = pd.DataFrame(json_output["exercises"])
                        # Raises ValueError if the LLM returned a different number of fields.
                        workout_df.columns = ["exercice", "sets", "reps", "rest"]
                    except (KeyError, TypeError, ValueError):
                        # Malformed LLM output: answer gracefully instead of crashing the page.
                        response = "Sorry buddy, I couldn't put that workout together. Could you rephrase your request?"
                    else:
                        response = "Sure! I just made a workout for you. Check on the table I just provided you."
                        display_workout = True
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.markdown(response)
    if display_workout:
        st.subheader("Workout")
        st.data_editor(workout_df)
# Second column: upload a video, show the latest pose-estimation output for it
# (or the raw upload when there is none) plus any generated graphs.
# Directory the pose-estimation runs are read from and wiped on a new upload.
# NOTE(review): both directories are fixed paths, so concurrent sessions share them.
RUNS_DIR = "/home/user/.pyenv/runs"
# Directory whose images are shown under "Graph Displayer".
FIG_DIR = "fig"


def _latest_pose_output(runs_dir):
    """Return the first file of the newest run folder under ``<runs_dir>/pose``.

    Returns None (instead of raising) when the ``pose`` folder is missing, holds
    no run folder, or the newest run folder is empty.
    """
    pose_dir = os.path.join(runs_dir, "pose")
    if not os.path.isdir(pose_dir):
        return None
    run_dirs = [
        path
        for path in (os.path.join(pose_dir, name) for name in os.listdir(pose_dir))
        if os.path.isdir(path)
    ]
    if not run_dirs:
        return None
    # Newest by modification time: a plain string sort ranks e.g. "run9" above "run10".
    latest_run = max(run_dirs, key=os.path.getmtime)
    outputs = sorted(os.listdir(latest_run))
    if not outputs:
        return None
    return os.path.join(latest_run, outputs[0])


def _encoded_video(pose_output):
    """Return an H.264 version of *pose_output*, encoding at most once per run folder.

    Streamlit re-executes the whole script on every interaction; without this
    cache the latest run's video would be re-encoded on each rerun. Assumes
    encode_video_H264 returns a filesystem path (its result is passed to st.video).
    """
    if "encoded_videos" not in st.session_state:
        st.session_state.encoded_videos = {}
    run_dir = os.path.dirname(pose_output)
    encoded = st.session_state.encoded_videos.get(run_dir)
    if encoded is None or not os.path.exists(encoded):
        encoded = encode_video_H264(pose_output, remove_original=True)
        st.session_state.encoded_videos[run_dir] = encoded
    return encoded


with col2:
    # st.subheader("Sports Agenda")
    # TODO: sports agenda section.
    st.subheader("Technique Analysis")
    video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
    if video_uploaded:
        video_uploaded = save_uploaded_file(video_uploaded)
        uploaded_name = os.path.basename(video_uploaded)
        # .get(): do not rely on another section having initialised file_name.
        if uploaded_name != st.session_state.get("file_name"):
            # A different video: discard graphs, pose runs and encodings of the previous one.
            shutil.rmtree(FIG_DIR, ignore_errors=True)
            shutil.rmtree(RUNS_DIR, ignore_errors=True)
            st.session_state.encoded_videos = {}
        st.session_state.file_name = uploaded_name
        _left, mid, _right = st.columns([1, 2, 1])
        with mid:
            pose_output = _latest_pose_output(RUNS_DIR)
            if pose_output is None:
                st.video(video_uploaded)
            else:
                st.video(_encoded_video(pose_output), loop=True)
        if os.path.exists(FIG_DIR):
            st.subheader("Graph Displayer")
            # Sorted so graphs appear in a stable order across reruns.
            for figure_name in sorted(os.listdir(FIG_DIR)):
                st.image(os.path.join(FIG_DIR, figure_name))