File size: 10,305 Bytes
b931367
 
 
 
 
 
 
 
c383152
b931367
c383152
 
 
 
 
 
 
 
 
 
 
 
 
b931367
 
c383152
b931367
 
 
c383152
439ab17
 
 
c383152
b931367
 
c383152
b931367
c383152
b931367
 
 
c383152
b931367
c383152
 
b931367
 
 
 
 
c383152
439ab17
b931367
 
 
c383152
b931367
 
c383152
b931367
 
 
 
 
 
 
c383152
b931367
c383152
 
b931367
c383152
b931367
 
 
c383152
 
b931367
 
 
 
 
 
 
 
c383152
 
b931367
 
c383152
b931367
 
 
 
c383152
b931367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c383152
b931367
 
 
c383152
 
 
b931367
 
 
 
 
 
 
 
c383152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b931367
 
 
 
 
 
 
c383152
b931367
 
 
 
 
 
 
 
c383152
b931367
 
 
c383152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import gradio as gr
import uuid
from datetime import datetime
import pandas as pd
from model_handler import ModelHandler
from config import CHAT_MODEL_SPECS, LING_1T
from recommand_config import RECOMMENDED_INPUTS
from ui_components.model_selector import create_model_selector
from i18n import get_text

def get_history_df(history):
    """Convert stored conversations into the two-column table shown in the sidebar.

    Args:
        history: List of conversation dicts (as kept in the browser-side
            conversation store), each expected to carry ``"id"`` and
            ``"title"`` keys. ``None`` or an empty list is accepted.

    Returns:
        A ``pd.DataFrame`` with exactly the columns ``'ID'`` and ``'对话'``.
        An empty frame with those columns is returned when there is no
        history or when the records lack the ``id``/``title`` keys, so every
        path yields the same column layout.
    """
    # Explicit column names/dtypes so the empty table still has a schema.
    empty = pd.DataFrame({'ID': pd.Series(dtype='str'), '对话': pd.Series(dtype='str')})
    if not history:
        return empty
    df = pd.DataFrame(history)
    # Ensure the required columns exist before selecting/renaming them.
    if 'id' not in df.columns or 'title' not in df.columns:
        return empty
    # Rename 'title' -> '对话' so populated and empty tables share one schema.
    return df[['id', 'title']].rename(columns={'id': 'ID', 'title': '对话'})


def create_chat_tab(initial_lang: str, current_lang_state: gr.State):
    """Build the chat tab UI and wire its event handlers.

    Intended to be called while a Gradio Blocks layout is being built, since
    it creates components directly.

    Args:
        initial_lang: Language code used for the initial labels/placeholders.
        current_lang_state: State holding the currently selected language; it
            is passed to handlers that need localized text.

    Returns:
        Dict of named components: the ones ``update_language`` relabels, plus
        the two ``gr.BrowserState`` stores used for persistence.
    """
    model_handler = ModelHandler()

    # Browser-side storage for conversation history and current ID
    conversation_store = gr.BrowserState(default_value=[], storage_key="ling_conversation_history")
    current_conversation_id = gr.BrowserState(default_value=None, storage_key="ling_current_conversation_id")

    def _make_conversation(conv_id, lang):
        """Return a fresh, empty conversation record with the localized placeholder title."""
        return {
            "id": conv_id, "title": get_text('chat_new_conversation_title', lang),
            "messages": [], "timestamp": datetime.now().isoformat()
        }

    def handle_new_chat(history, current_conv_id, lang):
        """Start a new conversation, reusing the current one if it is still empty.

        Returns (conversation id, updated history, cleared chatbot value,
        history-table update).
        """
        current_convo = next((c for c in history if c["id"] == current_conv_id), None) if history else None

        # Avoid piling up empty conversations: keep the active one if it has no messages yet.
        if current_convo and not current_convo.get("messages", []):
            return current_conv_id, history, [], gr.update(value=get_history_df(history))

        conv_id = str(uuid.uuid4())
        updated_history = [_make_conversation(conv_id, lang)] + (history or [])
        return conv_id, updated_history, [], gr.update(value=get_history_df(updated_history))

    def load_conversation_from_df(df: pd.DataFrame, evt: gr.SelectData, history, lang):
        """Switch the chat to the conversation whose row was clicked in the history table.

        Returns (conversation id, messages for the chatbot). If the clicked ID is
        no longer in ``history``, a new conversation ID is returned instead; it is
        not stored here, but ``on_chat_stream_complete`` creates its record once
        the first exchange finishes.
        """
        if evt.index is None or len(df) == 0:
            return None, []
        # evt.index is [row, col] for a DataFrame selection; only the row matters.
        selected_id = df.iloc[evt.index[0]]['ID']
        for convo in history or []:
            if convo["id"] == selected_id:
                return selected_id, convo.get("messages", [])
        new_id, _, new_msgs, _ = handle_new_chat(history, None, lang)
        return new_id, new_msgs

    with gr.Row(equal_height=False, elem_id="indicator-chat-tab"):
        with gr.Column(scale=1):
            new_chat_btn = gr.Button(get_text('chat_new_chat_button', initial_lang))
            history_df = gr.DataFrame(
                value=get_history_df(conversation_store.value),
                headers=["ID", get_text('chat_history_dataframe_header', initial_lang)],
                datatype=["str", "str"],
                interactive=False,
                visible=True,
                column_widths=["0%", "99%"]
            )

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=500, placeholder=get_text('chat_chatbot_placeholder', initial_lang))
            with gr.Row():
                textbox = gr.Textbox(placeholder=get_text('chat_textbox_placeholder', initial_lang), container=False, scale=7)
                submit_btn = gr.Button(get_text('chat_submit_button', initial_lang), scale=1)

            recommended_title = gr.Markdown(get_text('chat_recommended_dialogues_title', initial_lang))
            recommended_dataset = gr.Dataset(
                components=[gr.Textbox(visible=False)],
                samples=[[item["task"]] for item in RECOMMENDED_INPUTS],
                label=get_text('chat_recommended_dataset_label', initial_lang),
                headers=[get_text('chat_recommended_dataset_header', initial_lang)],
            )

        with gr.Column(scale=1):
            model_dropdown, model_description_markdown = create_model_selector(
                model_specs=CHAT_MODEL_SPECS,
                default_model_constant=LING_1T
            )

            system_prompt_textbox = gr.Textbox(label=get_text('chat_system_prompt_label', initial_lang), lines=5, placeholder=get_text('chat_system_prompt_placeholder', initial_lang))
            temperature_slider = gr.Slider(minimum=0, maximum=1.0, value=0.7, step=0.1, label=get_text('chat_temperature_slider_label', initial_lang))

        # --- Event Handlers --- #
        def on_select_recommendation(evt: gr.SelectData, history, current_conv_id, lang):
            """Start a (possibly reused) new chat pre-filled from a recommended preset.

            Returns 8 values matching the outputs wired below; all are no-op
            updates if the selected task is not found.
            """
            selected_task = evt.value[0]
            item = next((i for i in RECOMMENDED_INPUTS if i["task"] == selected_task), None)
            if not item:
                return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

            new_id, new_history, new_messages, history_df_update = handle_new_chat(history, current_conv_id, lang)

            return (
                new_id, new_history,
                gr.update(value=item["model"]),
                gr.update(value=item["system_prompt"]),
                gr.update(value=item["temperature"]),
                gr.update(value=item["user_message"]),
                history_df_update,
                new_messages
            )

        def chat_stream(conv_id, history, model_display_name, message, chat_history, system_prompt, temperature):
            """Stream chatbot updates from the model for ``message``.

            An empty message re-emits the current chat unchanged. An unknown
            display name falls back to ``LING_1T``.
            """
            if not message:
                yield chat_history
                return
            model_constant = next((k for k, v in CHAT_MODEL_SPECS.items() if v["display_name"] == model_display_name), LING_1T)
            response_generator = model_handler.get_response(model_constant, message, chat_history, system_prompt, temperature)
            for history_update in response_generator:
                yield history_update

        def on_chat_stream_complete(conv_id, history, final_chat_history, lang):
            """Persist the finished exchange into the stored history and refresh the table.

            If ``conv_id`` is not in ``history`` (e.g. a message was sent before any
            conversation existed, so ``conv_id`` is None), a record is created for it
            instead of silently dropping the exchange.

            Returns (conversation id, updated history, history-table update).
            """
            history = history or []
            final_chat_history = final_chat_history or []
            current_convo = next((c for c in history if c["id"] == conv_id), None)
            if current_convo is None:
                if not final_chat_history:
                    # Nothing to persist.
                    return conv_id, history, gr.update()
                conv_id = conv_id or str(uuid.uuid4())
                current_convo = _make_conversation(conv_id, lang)

            new_convo_title = get_text('chat_new_conversation_title', lang)
            if len(final_chat_history) > len(current_convo.get("messages", [])) and current_convo["title"] == new_convo_title:
                # Title an untitled conversation after the user turn preceding the reply.
                # NOTE(review): assumes chat entries are dicts with a string "content"; confirm against ModelHandler.
                user_message = final_chat_history[-2]["content"] if len(final_chat_history) > 1 else final_chat_history[0]["content"]
                current_convo["title"] = user_message[:50]

            current_convo["messages"] = final_chat_history
            current_convo["timestamp"] = datetime.now().isoformat()

            # Replace (or insert) this conversation and keep newest-first order.
            history = sorted([c for c in history if c["id"] != conv_id] + [current_convo], key=lambda x: x["timestamp"], reverse=True)
            return conv_id, history, gr.update(value=get_history_df(history))

        # Store all components that need i18n updates
        components = {
            "new_chat_btn": new_chat_btn,
            "history_df": history_df,
            "chatbot": chatbot,
            "textbox": textbox,
            "submit_btn": submit_btn,
            "recommended_title": recommended_title,
            "recommended_dataset": recommended_dataset,
            "system_prompt_textbox": system_prompt_textbox,
            "temperature_slider": temperature_slider,
            "model_dropdown": model_dropdown,
            "model_description_markdown": model_description_markdown,
            # Non-updatable components needed for event handlers and app.py
            "conversation_store": conversation_store,
            "current_conversation_id": current_conversation_id,
        }

        # Wire event handlers
        recommended_dataset.select(on_select_recommendation, inputs=[conversation_store, current_conversation_id, current_lang_state], outputs=[current_conversation_id, conversation_store, model_dropdown, system_prompt_textbox, temperature_slider, textbox, history_df, chatbot], show_progress="none")

        stream_inputs = [current_conversation_id, conversation_store, model_dropdown, textbox, chatbot, system_prompt_textbox, temperature_slider]
        complete_inputs = [current_conversation_id, conversation_store, chatbot, current_lang_state]
        # current_conversation_id is an output so a newly created conversation becomes the active one.
        complete_outputs = [current_conversation_id, conversation_store, history_df]
        # Button click and Enter in the textbox share the same stream -> persist pipeline.
        for trigger in (submit_btn.click, textbox.submit):
            trigger(chat_stream, stream_inputs, [chatbot]).then(on_chat_stream_complete, complete_inputs, complete_outputs)

        new_chat_btn.click(handle_new_chat, inputs=[conversation_store, current_conversation_id, current_lang_state], outputs=[current_conversation_id, conversation_store, chatbot, history_df])
        history_df.select(load_conversation_from_df, inputs=[history_df, conversation_store, current_lang_state], outputs=[current_conversation_id, chatbot])

    return components

def update_language(lang: str, components: dict):
    """Build the relabelling payload for switching the chat tab to ``lang``.

    Args:
        lang: Language code forwarded to ``get_text`` for every lookup.
        components: The mapping returned by ``create_chat_tab``; only the
            text-bearing entries are read here.

    Returns:
        A dict keyed by component object whose values are ``gr.update(...)``
        calls carrying the translated button text, labels, placeholders and
        headers.
    """
    def tr(key):
        # Every string on this tab is looked up in the same target language.
        return get_text(key, lang)

    c = components
    return {
        c["new_chat_btn"]: gr.update(value=tr('chat_new_chat_button')),
        c["history_df"]: gr.update(headers=["ID", tr('chat_history_dataframe_header')]),
        c["chatbot"]: gr.update(placeholder=tr('chat_chatbot_placeholder')),
        c["textbox"]: gr.update(placeholder=tr('chat_textbox_placeholder')),
        c["submit_btn"]: gr.update(value=tr('chat_submit_button')),
        c["recommended_title"]: gr.update(value=tr('chat_recommended_dialogues_title')),
        c["recommended_dataset"]: gr.update(
            label=tr('chat_recommended_dataset_label'),
            headers=[tr('chat_recommended_dataset_header')],
        ),
        c["system_prompt_textbox"]: gr.update(
            label=tr('chat_system_prompt_label'),
            placeholder=tr('chat_system_prompt_placeholder'),
        ),
        c["temperature_slider"]: gr.update(label=tr('chat_temperature_slider_label')),
    }