import time

from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
from tools.server.agent.generation_utils import (
    initialize_decode_buffers,
    process_response_tokens,
    send_reset_buffer,
)
from tools.server.agent.pre_generation_utils import (
    create_generation_request,
    send_generation_request,
)


def generate_responses(
    input_queue, tokenizer, config, request, prompt, im_end_id, device
):
    """
    Main generation entry point: encodes the request, sends it to the model
    worker, and handles decoding/streaming of the results.
    Yields ServeResponse (non-streaming) or ServeStreamResponse (streaming) objects.
    """
    stats = {}
    start = time.time()
    stats["start_time"] = start
    stats["tokens_count"] = 0

    # Prepare and send the generation request to the model worker
    req = create_generation_request(prompt, request, im_end_id, device)
    response_queue = send_generation_request(input_queue, req)

    # Per-sample decode state: text buffers, accumulated parts, finished flags
    decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)

    while True:
        response = response_queue.get()

        # The worker signals an abnormal finish or an error with a sentinel string
        if response in ["stop", "error"]:
            finish_reason = response
            break

        # Decode the tokens carried by this response
        is_first_token = stats["tokens_count"] == 0
        responses = process_response_tokens(
            response,
            tokenizer,
            config,
            request,
            decode_buffer,
            parts,
            finished,
            im_end_id,
            stats,
            start,
            is_first_token,
        )

        # Yield any decoded chunks immediately if streaming
        if request.streaming and responses:
            for r in responses:
                yield r
        stats["tokens_count"] += 1

        # Stop once every sample has finished
        if all(finished):
            finish_reason = "stop"
            break

    # Flush remaining buffers and emit the final response(s)
    final_responses = finalize_response(
        request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
    )
    for fr in final_responses:
        yield fr
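

# Illustrative consumption sketch (an assumption, not part of the module
# above): generate_responses is a generator, so the API layer typically
# iterates it and branches on the response type. The two handler names are
# hypothetical placeholders for whatever the server does with each payload.
#
#     for resp in generate_responses(
#         input_queue, tokenizer, config, request, prompt, im_end_id, device
#     ):
#         if isinstance(resp, ServeStreamResponse):
#             forward_chunk(resp)       # hypothetical: push chunk to the client
#         else:                         # ServeResponse: the complete reply
#             return_full_result(resp)  # hypothetical: send the buffered result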


def finalize_response(
    request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
):
    """
    Build the final responses: flush the remaining text buffers and
    attach the final stats.
    """
    responses = []

    # Flush any text still sitting in the per-sample decode buffers
    for sample_id in range(request.num_samples):
        responses.extend(
            send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
        )

    # Compute the final stats
    stats["total_time"] = (time.time() - stats["start_time"]) * 1000
    stats["total_tokens"] = stats["tokens_count"]

    if request.streaming:
        # Send a closing chunk for each sample that was not already
        # closed out during decoding
        for sample_id in range(request.num_samples):
            if finished[sample_id]:
                continue
            responses.append(
                ServeStreamResponse(
                    finish_reason=finish_reason, stats=stats, sample_id=sample_id
                )
            )
    else:
        # Non-streaming: send one full assistant message per sample
        full_messages = [
            ServeMessage(role="assistant", parts=parts[i])
            for i in range(request.num_samples)
        ]
        responses.append(
            ServeResponse(
                messages=full_messages,
                finish_reason=finish_reason,
                stats=stats,
            )
        )

    return responses
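

# For reference, the stats dict attached to the final payload(s) carries at
# least the keys written in this module (process_response_tokens may add more
# upstream); the values below are made-up illustrations:
#
#     {
#         "start_time": 1700000000.0,   # time.time() at generation start
#         "tokens_count": 128,          # worker responses processed
#         "total_time": 2350.4,         # wall-clock duration in ms
#         "total_tokens": 128,          # alias of tokens_count
#     }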