# canary_aed_streaming / gpu_compute.py
# (Hugging Face Space file; commit 75c9c9a, "impl fastrtc receive")
from app.logger_config import logger as logging
from app.utils import (
debug_current_device,
get_current_device
)
import os
import gradio as gr
import spaces
import torch
@spaces.GPU
def gpu_compute(name):
    """Run compute() under a ZeroGPU allocation and format the result.

    Args:
        name: input string; compute() uses its length as the tensor value.

    Returns:
        Human-readable string with the tensor contents and the device name.
    """
    logging.debug("=== Start of gpu_compute() ===")
    debug_current_device()
    result_tensor, device_label = compute(name)
    logging.debug("=== End of gpu_compute() ===")
    return f"Tensor: {result_tensor.cpu().numpy()} | Device: {device_label}"
def cpu_compute(name):
    """Run compute() on whatever device is current (no GPU decorator).

    Args:
        name: input string; compute() uses its length as the tensor value.

    Returns:
        Human-readable string with the tensor contents and the device name.
    """
    logging.debug("=== Start of cpu_compute() ===")
    debug_current_device()
    result_tensor, device_label = compute(name)
    logging.debug("=== End of cpu_compute() ===")
    return f"Tensor: {result_tensor.cpu().numpy()} | Device: {device_label}"
def compute(name):
    """Build a one-element float32 tensor holding len(name) on the current device.

    Args:
        name: string whose length becomes the tensor's single value.

    Returns:
        Tuple of (tensor, device_name), where device_name comes from
        get_current_device().
    """
    device, device_name = get_current_device()
    created = torch.tensor([len(name)], dtype=torch.float32, device=device)
    logging.debug(f"Tensor created: {created}")
    if torch.cuda.is_available():
        # Best-effort memory hygiene: drop cached CUDA allocations right away.
        torch.cuda.empty_cache()
        logging.debug("GPU cache cleared")
    return created, device_name
# --- Gradio UI --------------------------------------------------------------
# Fix: the original created the layout in one Blocks (`block`), bound it to
# `demo` via `with block as demo:`, then immediately shadowed `demo` with a
# second gr.Blocks() that merely re-rendered `block`. The first binding was
# dead and the wrapper Blocks added nothing, so the UI is built in a single
# Blocks context here.
with gr.Blocks() as demo:
    with gr.Row():
        input_text = gr.Text()
        output_text = gr.Text()
    with gr.Row():
        gpu_button = gr.Button("GPU compute")
        cpu_button = gr.Button("CPU compute")
    # Wire each button to the matching compute path; both read the text box
    # and write the formatted result back out.
    gpu_button.click(fn=gpu_compute, inputs=[input_text], outputs=[output_text])
    cpu_button.click(fn=cpu_compute, inputs=[input_text], outputs=[output_text])

# Backward-compatible alias: `block` previously referred to the same layout.
block = demo

if __name__ == "__main__":
    # Bounded queue; HTTP API surface disabled for this demo Space.
    demo.queue(max_size=10, api_open=False).launch(show_api=False)