File size: 3,738 Bytes
b5bcf5a
 
 
76e1435
 
 
 
 
b5bcf5a
76e1435
b5bcf5a
 
 
6a31985
b5bcf5a
 
6a31985
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45a53c4
76e1435
67d411a
eab0adb
45a53c4
 
 
76e1435
45a53c4
 
 
67d411a
76e1435
 
45a53c4
76e1435
45a53c4
 
 
 
 
 
 
 
76e1435
 
45a53c4
76e1435
 
e4b23f9
76e1435
 
 
 
 
 
 
 
 
 
e4b23f9
 
 
ea5eb99
 
e4b23f9
ea5eb99
 
 
 
 
e4b23f9
 
 
 
ea5eb99
 
 
 
e4b23f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from huggingface_hub import login
from collections.abc import Iterator
from transformers import (
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
    Gemma3Processor,
)
import spaces
import tempfile
from threading import Thread
import gradio as gr
import os
from dotenv import load_dotenv, find_dotenv
import cv2
from loguru import logger
from PIL import Image

# Load environment variables from the .env file located by find_dotenv(),
# so settings such as MODEL_ID can be supplied without exporting them.
dotenv_path = find_dotenv()

load_dotenv(dotenv_path)

# Hugging Face model id to serve; overridable via the MODEL_ID env var and
# defaulting to "google/gemma-3-4b-it".
model_id = os.getenv("MODEL_ID", "google/gemma-3-4b-it")

# Processor loaded from the same checkpoint id as the model below.
input_processor = Gemma3Processor.from_pretrained(model_id)

# Loaded once at import time (this runs on module import, not per request):
# bfloat16 weights, device placement left to device_map="auto", and the
# "eager" attention implementation.
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)


def get_frames(video_path: str, max_images: int) -> list[tuple[Image.Image, float]]:
    """Sample up to ``max_images`` evenly spaced RGB frames from a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        max_images: Upper bound on the number of frames returned. Values
            ``<= 0`` yield an empty list.

    Returns:
        ``(frame, timestamp)`` pairs in playback order. ``frame`` is a PIL
        image converted from OpenCV's BGR channel order to RGB. ``timestamp``
        is the sampled frame position in seconds, rounded to two decimals
        (``0.0`` when the container reports no usable FPS). Frames that fail
        to decode are skipped.

    Raises:
        ValueError: If OpenCV cannot open ``video_path``.
    """
    frames: list[tuple[Image.Image, float]] = []
    capture = cv2.VideoCapture(video_path)
    try:
        if not capture.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")
        if max_images <= 0:
            # Avoid dividing by zero below; nothing to sample.
            return frames

        fps = capture.get(cv2.CAP_PROP_FPS)
        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))

        # Stride chosen so max_images samples span the clip; at least 1 so
        # clips shorter than max_images frames are read frame by frame.
        frame_interval = max(total_frames // max_images, 1)
        max_position = min(total_frames, max_images * frame_interval)

        # max_position <= max_images * frame_interval, so this range yields
        # at most max_images positions.
        for position in range(0, max_position, frame_interval):
            capture.set(cv2.CAP_PROP_POS_FRAMES, position)
            success, image = capture.read()
            if not success:
                continue
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Some containers report FPS as 0 (or NaN); don't divide by it.
            timestamp = round(position / fps, 2) if fps > 0 else 0.0
            frames.append((Image.fromarray(rgb_image), timestamp))
    finally:
        # Release the capture even if decoding/conversion raises.
        capture.release()
    return frames


def process_video(video_path: str, max_images: int) -> list[dict]:
    """Turn a video into interleaved timestamp-label / image content parts.

    Each frame sampled by ``get_frames`` is written to its own temporary PNG
    file. The frame is then represented by a text part ``"Frame <ts>:"``
    followed by an image part whose ``"url"`` is that file's path.

    Args:
        video_path: Path to the video to sample.
        max_images: Maximum number of frames to sample.

    Returns:
        A flat list of ``{"type": "text", ...}`` / ``{"type": "image", ...}``
        dicts, two per sampled frame, in playback order.

    Raises:
        ValueError: Propagated from ``get_frames`` if the video can't be opened.
    """
    result_content: list[dict] = []
    # TODO: expose max_images as a UI slider.
    frames = get_frames(video_path, max_images)
    for image, timestamp in frames:
        # delete=False: the returned paths are used after this function
        # returns, so the files must outlive the handle. Nothing here removes
        # them afterwards.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            # Write through the open handle rather than reopening the file by
            # name; tempfile documents reopening as platform-dependent.
            image.save(temp_file, format="PNG")
        result_content.append({"type": "text", "text": f"Frame {timestamp}:"})
        result_content.append({"type": "image", "url": temp_file.name})
    logger.debug(
        f"Processed {len(frames)} frames from video {video_path} with frames {result_content}"
    )
    return result_content


def process_user_input(message: dict, max_images: int) -> list[dict]:
    """Convert a multimodal chat message into chat-template content parts.

    Args:
        message: Mapping with a ``"text"`` string and a ``"files"`` list of
            local file paths (may be empty or ``None``).
        max_images: Maximum number of frames sampled from each ``.mp4``
            attachment (forwarded to ``process_video``).

    Returns:
        A list whose first element is the text part. It is followed by the
        parts for each attachment, in upload order: a ``.mp4`` file (any
        case) expands to its sampled, timestamp-labelled frames, and any
        other file becomes a single image part.
    """
    content: list[dict] = [{"type": "text", "text": message["text"]}]
    for path in message["files"] or []:
        # Dispatch every attachment by type (not only the first one) and
        # match the extension case-insensitively, so "clip.MP4" or a video
        # that isn't the first upload is not handed over as an image.
        if path.lower().endswith(".mp4"):
            content.extend(process_video(path, max_images))
        else:
            content.append({"type": "image", "url": path})
    return content


def process_history(history: list[dict]) -> list[dict]:
    """Rebuild chat-template messages from a role/content chat history.

    Runs of consecutive non-assistant entries are merged into a single
    ``"user"`` message whose content lists each entry in order. String
    content becomes a text part; any other content is treated as a file
    entry and contributes an image part built from its first element.
    Every assistant entry becomes its own ``"assistant"`` message with one
    text part. A trailing run of user entries is emitted as well.
    """
    merged: list[dict] = []
    pending: list[dict] = []

    def flush_pending() -> None:
        # Emit the accumulated user parts (if any) as one user message.
        nonlocal pending
        if pending:
            merged.append({"role": "user", "content": pending})
            pending = []

    for entry in history:
        body = entry["content"]
        if entry["role"] != "assistant":
            if isinstance(body, str):
                part = {"type": "text", "text": body}
            else:
                part = {"type": "image", "url": body[0]}
            pending.append(part)
            continue

        flush_pending()
        merged.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": body}],
            }
        )

    flush_pending()
    return merged