import ast
import copy
import json
import os

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from typing import List, Dict, Any, Optional
from transformers import AutoProcessor, AutoModel
from rapidfuzz import process, fuzz
from pypinyin import pinyin, Style


def correct_sentence_with_pinyin(user_input_sentence, location_names, score_cutoff=50):
    """Replace a fuzzy location mention in the sentence with the closest known
    POI name, matching on pinyin so that homophone ASR errors (same sound,
    wrong characters) get corrected."""
    # Map each location's pinyin rendering back to the original name.
    pinyin_to_name = {}
    for location in location_names:
        pinyin_name = ''.join(item[0] for item in pinyin(location, style=Style.NORMAL))
        pinyin_to_name[pinyin_name] = location

    user_pinyin_sentence = ''.join(item[0] for item in pinyin(user_input_sentence, style=Style.NORMAL))
    best_match_pinyin = process.extractOne(
        query=user_pinyin_sentence,
        choices=list(pinyin_to_name.keys()),  # search over the pinyin renderings
        scorer=fuzz.token_set_ratio,
        score_cutoff=score_cutoff,
    )
    if best_match_pinyin:
        best_pinyin_name = best_match_pinyin[0]
        corrected_location_name = pinyin_to_name[best_pinyin_name]
        # Find the substring of the user sentence (2-15 characters) that best
        # matches the corrected location name, then swap it in.
        best_user_substring = None
        max_substring_score = 0
        for i in range(len(user_input_sentence)):
            for j in range(i + 2, min(i + 16, len(user_input_sentence) + 1)):
                substring = user_input_sentence[i:j]
                score = fuzz.ratio(substring, corrected_location_name)
                if score > max_substring_score:
                    max_substring_score = score
                    best_user_substring = substring
        if best_user_substring and max_substring_score > score_cutoff:
            return user_input_sentence.replace(best_user_substring, corrected_location_name, 1)
    return user_input_sentence


class InferenceClass:
    def __init__(self, model_id):
        self.model = AutoModel.from_pretrained(
            model_id,
            device_map="cuda",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="eager",
        ).eval()
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    @staticmethod
    def remove_words_signs(text):
        """Strip the transcribe/output labels plus full- and half-width punctuation."""
        for token in ('User transcribe is :', 'GPT output is :', '\n', ' ',
                      '？', '?', '！', '!', '。'):
            text = text.replace(token, '')
        return text

    def call_gpt(self, inputs_tensor):
        """Run greedy decoding and return only the newly generated text."""
        with torch.inference_mode():
            inputs = {k: inputs_tensor[k].to('cuda') for k in inputs_tensor}
            generate_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
            generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]  # drop the prompt tokens
            model_output = self.processor.batch_decode(
                generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]
        return model_output

    def call_function_fake(self, messages=None, obs=""):
        """Append a fake tool observation (None default avoids the mutable-default pitfall)."""
        messages = [] if messages is None else messages
        messages.append({'from': 'observation', 'value': obs})
        return messages

    def generate(self, chat_history, tools="", audio_path=None):
        '''
        input:
            chat_history : list of {'from': ..., 'value': ...} dicts
            tools : JSON string describing the available tools
            audio_path : optional path to the user's audio file
        return:
            chat_history : the updated history with the model's turn appended
        '''
        chat_history = copy.deepcopy(chat_history)
        if audio_path is not None:
            chat_history.append({'from': 'human', 'value': [{'type': 'audio', 'audio': audio_path}]})

        # Collect POI names from successful location-query observations; they
        # are used to correct homophone errors in transcripts and keywords.
        words_from_poi = []
        for hist in chat_history:
            if hist['from'] == 'observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:
                tmp = json.loads(hist['value'])
                for poi in tmp['poi']:
                    words_from_poi.append(poi['name'])

        for hist in chat_history:
            if hist['from'] == 'human' and isinstance(hist['value'], str):
                hist['value'] = correct_sentence_with_pinyin(hist['value'], words_from_poi)
            elif hist['from'] == 'function_call' and 'arguments' in hist['value'] and 'keyword' in hist['value']['arguments']:
                # Arguments are stored as the str() of a Python dict, so parse
                # them with ast.literal_eval rather than the unsafe eval().
                arguments = ast.literal_eval(hist['value']['arguments'])
                if 'keyword' in arguments:
                    arguments['keyword'] = correct_sentence_with_pinyin(arguments['keyword'], words_from_poi)
                hist['value']['arguments'] = str(arguments)

        # model_input_history = copy.deepcopy(chat_history)
        # num2ch = {1: '一', 2: '二', 3: '三', 4: '四', 5: '五', 6: '六'}
        # for hist in model_input_history:
        #     if hist['from'] == 'observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:
        #         tmp = json.loads(hist['value'])
        #         new_poi = []
        #         for i, poi in enumerate(tmp['poi']):
        #             new_poi.append('第{}個 : '.format(num2ch[i + 1]) + str(poi))
        #         tmp['poi'] = new_poi
        #         hist['value'] = json.dumps(tmp, ensure_ascii=False)

        inputs_text = self.processor.apply_chat_template(
            chat_history,
            add_generation_prompt=True,
            tokenize=False,
            return_dict=True,
            return_tensors="pt",
            tools=json.loads(tools),
        )
        inputs_tensor = self.processor(
            text=inputs_text,
            audio=[torchaudio.load(audio_path)[0]] if audio_path is not None else None,
            add_special_tokens=False,
            return_tensors='pt',
        )
        model_output = self.call_gpt(inputs_tensor)

        # After a tool observation, the model replies directly.
        if chat_history[-1]['from'] == 'observation':
            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(model_output, words_from_poi)})
            return chat_history

        # Otherwise the output must contain a transcript segment and a response
        # segment separated by ';\n'; anything else is treated as an
        # unrecognized utterance.
        if ((';\n' not in model_output)
                or ('User transcribe is :' not in model_output)
                or ('GPT output is :' not in model_output)
                or len(model_output.split(';\n')) < 2):
            if chat_history[-1]['value'] != "抱歉我聽不清楚 能麻煩您再說一次嗎":
                chat_history.append({'from': 'human', 'value': 'HUMAN_VOICE_IS_NOT_RECOGNIZED'})
                chat_history.append({'from': 'gpt', 'value': '抱歉我聽不清楚 能麻煩您再說一次嗎'})
            return chat_history

        output_t, output_o = model_output.split(';\n')[:2]
        output_t, output_o = self.remove_words_signs(output_t), self.remove_words_signs(output_o)
        # Overwrite the audio placeholder with the corrected transcript.
        chat_history[-1]['value'] = correct_sentence_with_pinyin(output_t, words_from_poi)

        if 'Action:' in output_o and 'ActionInput:' in output_o:
            # Function-calling turn: split the segment into name and arguments.
            function_name, function_arg = output_o.split('ActionInput:')
            function_name = function_name.replace('Action:', '')
            if 'keyword' in function_arg:
                function_arg = json.loads(function_arg)
                if 'keyword' in function_arg:
                    function_arg['keyword'] = correct_sentence_with_pinyin(function_arg['keyword'], words_from_poi)
            chat_history.append({'from': 'function_call', 'value': {'name': function_name, 'arguments': str(function_arg)}})
        else:
            # Plain text reply.
            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(output_o, words_from_poi)})
        return chat_history


model_id = "/home/jeff/jeff/codes/llm/InCar/gemma-3-4b-it-omni"
pipeline = InferenceClass(model_id)

app = FastAPI(
    title="Audio LLM API",
    description="An API that accepts an audio file and a list of chat-history dictionaries.",
)

with open('/home/jeff/jeff/codes/llm/InCar/data/test_data/nav_0730_noisy.json') as f:
    dataset = json.load(f)
tools = dataset[0]['tools']


@app.post("/audio_llm/")
async def process_audio_and_data(
    audio_file: Optional[UploadFile] = File(None, description="The audio file to be processed."),
    data: str = Form(..., description="A JSON string representing a list of chat history dictionaries."),
) -> List[Dict[str, Any]]:
    try:
        input_data_list = json.loads(data)
        if not isinstance(input_data_list, list) or not all(isinstance(item, dict) for item in input_data_list):
            raise ValueError("The provided data is not a list of dictionaries.")
    except json.JSONDecodeError:
        raise HTTPException(
            status_code=422,
            detail="Invalid JSON format for 'data' field. Please provide a valid JSON string.",
        )
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))

    temp_file_path = None
    if audio_file:
        os.makedirs('./audio_path', exist_ok=True)  # the upload directory may not exist yet
        temp_file_path = f"./audio_path/temp_{audio_file.filename}"
        with open(temp_file_path, "wb") as buffer:
            buffer.write(await audio_file.read())
        print(f"Audio file saved to {temp_file_path}")

    output_data = pipeline.generate(input_data_list, tools=tools, audio_path=temp_file_path)
    print(output_data)
    return output_data


# uvicorn main:app --host 0.0.0.0 --port 8087 --log-level info --workers 1 >> ./log.txt
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
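
# Example client call against the endpoint above (a minimal sketch: the route
# and port come from the __main__ block, but the history turn, the file name
# 'sample.wav', and the use of the `requests` library are assumptions made for
# illustration):
#
#   import json
#   import requests
#
#   history = [{'from': 'human', 'value': '幫我找附近的加油站'}]  # hypothetical turn
#   with open('sample.wav', 'rb') as f:
#       resp = requests.post(
#           'http://localhost:8000/audio_llm/',
#           data={'data': json.dumps(history, ensure_ascii=False)},
#           files={'audio_file': ('sample.wav', f, 'audio/wav')},
#       )
#   print(resp.json())  # the updated chat history returned by generate()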