Spaces:
Sleeping
Sleeping
| from langchain.tools import tool | |
| from langchain.agents import AgentExecutor, create_tool_calling_agent | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.messages import HumanMessage | |
| from langchain_mistralai.chat_models import ChatMistralAI | |
| import torch | |
| import os | |
| import sys | |
| import json | |
| sys.path.append(os.getcwd()) | |
| from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average, save_knee_angle_fig | |
# The API key is not passed explicitly, so ChatMistralAI falls back to its
# default of reading the `MISTRAL_API_KEY` environment variable.
# Never hard-code secrets here: the key that was previously committed in this
# file must be treated as leaked and rotated.
llm = ChatMistralAI(model='mistral-large-latest')
# Inference device handed to the pose model: the GPU when torch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_keypoints_from_keypoints(video_path: str, save_folder: str = 'tmp') -> str:
    """
    Extracts keypoints from a video file and writes them to a JSON file.

    Runs the pose ``model`` over the video. For every frame it records the
    frame index and ``frame.keypoints.xy[0]`` as a plain list. The result is
    written to ``<save_folder>/keypoints.json``, overwriting any previous run.

    Args:
        video_path (str): path to the video file
        save_folder (str): folder that receives ``keypoints.json``; created if
            missing. Defaults to ``'tmp'``.

    Returns:
        file_path (str): path to the JSON file. It holds a list of
            ``{"frame": <index>, "keypoints": <list>}`` records.
    """
    os.makedirs(save_folder, exist_ok=True)
    # NOTE(review): if `model` is an Ultralytics YOLO model, `stream=True`
    # would avoid holding every frame's results in memory for long videos.
    # Confirm before changing.
    results = model(video_path, save=True, show_conf=False, show_boxes=False, device=device)
    # NOTE(review): xy[0] presumably selects the first detected person. A frame
    # with no detection may raise IndexError here; confirm against the model's
    # output format.
    keypoints: list[dict] = [
        {'frame': i, 'keypoints': frame.keypoints.xy[0].tolist()}
        for i, frame in enumerate(results)
    ]
    file_path = os.path.join(save_folder, 'keypoints.json')
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(keypoints, f)
    return file_path
def compute_right_knee_angle_list(json_path: str, window: int = 10) -> list[float]:
    """
    Computes the smoothed right-knee angle for every frame in a keypoints file.

    Reads a JSON list of ``{"frame": ..., "keypoints": ...}`` records. It then
    computes the right-knee angle of each record with
    ``compute_right_knee_angle`` and smooths the sequence with
    ``moving_average``. As a side effect, it saves a figure of the smoothed
    angles via ``save_knee_angle_fig``.

    Args:
        json_path (str): path to the JSON file containing the keypoints
        window (int): window size passed to ``moving_average``. Defaults to 10.

    Returns:
        right_knee_angle_list (list[float]): smoothed knee angles, one
            sequence for the whole video
    """
    # The context manager closes the file handle as soon as it has been read.
    with open(json_path, encoding='utf-8') as f:
        keypoints_list = json.load(f)
    right_knee_angle_list = [
        compute_right_knee_angle(entry['keypoints']) for entry in keypoints_list
    ]
    # NOTE(review): moving_average's concrete return type is not visible here;
    # it may be an array rather than a list. Confirm the return annotation.
    right_knee_angle_list = moving_average(right_knee_angle_list, window)
    save_knee_angle_fig(right_knee_angle_list)
    return right_knee_angle_list
def check_knee_angle(json_path: str, threshold: float = 90) -> bool:
    """
    Checks whether the knee angle drops below a threshold in any frame.

    If the knee angle goes below ``threshold`` degrees (90 by default), the
    squat is considered correct, because it means the user is going deep
    enough. This is equivalent to "the minimum angle is smaller than the
    threshold".

    Args:
        json_path (str): path to the JSON file containing the keypoints
        threshold (float): angle in degrees the knee must go below.
            Defaults to 90.

    Returns:
        is_correct (bool): True if at least one angle is smaller than
            ``threshold``, False otherwise (including when there are no angles)
    """
    angles_list = compute_right_knee_angle_list(json_path)
    return any(angle < threshold for angle in angles_list)
def check_squat(file_name: str) -> str:
    """
    Checks if the squat shown in an uploaded video is correct.

    Looks up ``file_name`` inside the ``uploaded`` folder and extracts the pose
    keypoints. It then uses check_knee_angle to decide whether the user goes
    deep enough, i.e. whether the knee angle drops below 90 degrees.

    Args:
        file_name (str): name of the video file inside the ``uploaded`` folder

    Returns:
        message (str): a human-readable verdict on the squat. It is instead an
            error message when the name points outside ``uploaded`` or the file
            does not exist.
    """
    video_path = os.path.join('uploaded', file_name)
    # file_name is supplied by the agent's LLM, i.e. indirectly by the user.
    # Refuse absolute paths and '..' segments that resolve outside 'uploaded'.
    # commonpath (unlike startswith) also rejects siblings such as
    # 'uploaded_evil'.
    upload_dir = os.path.abspath('uploaded')
    try:
        inside = os.path.commonpath([upload_dir, os.path.abspath(video_path)]) == upload_dir
    except ValueError:  # e.g. paths on different drives on Windows
        inside = False
    if not inside:
        return "Invalid file name: the video must be inside the 'uploaded' folder."
    if not os.path.isfile(video_path):
        return "The video file does not exist."
    json_path = get_keypoints_from_keypoints(video_path)
    if check_knee_angle(json_path):
        return "The squat is correct because your knee angle is smaller than 90 degrees."
    return "The squat is incorrect because your knee angle is greater than 90 degrees."
from langchain_core.tools import BaseTool

# AgentExecutor's `tools` field is typed as BaseTool, so the plain function is
# wrapped with the `tool` decorator imported at the top of the file. The tool
# keeps the function's name (check_squat, as referenced in the system prompt),
# and its docstring becomes the description the LLM sees. The isinstance check
# keeps this correct if check_squat is already decorated with @tool.
_squat_tool = check_squat if isinstance(check_squat, BaseTool) else tool(check_squat)
tools = [_squat_tool]
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response, giving advices",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        # The agent writes its intermediate tool calls/results here.
        ("placeholder", "{agent_scratchpad}"),
    ]
)
# Construct the Tools agent
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)