Update palmapi.py
palmapi.py  CHANGED  (+42 -8)
@@ -17,7 +17,39 @@ if palm_api_token is None:
 else:
     palm_api.configure(api_key=palm_api_token)
 
-class GradioPaLMChatPPManager(PPManager):
+class PaLMChatPromptFmt(PromptFmt):
+    @classmethod
+    def ctx(cls, context):
+        pass
+
+    @classmethod
+    def prompt(cls, pingpong, truncate_size):
+        ping = pingpong.ping[:truncate_size]
+        pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
+        return [
+            {
+                "author": "USER",
+                "content": ping
+            },
+            {
+                "author": "AI",
+                "content": pong
+            },
+        ]
+
+class PaLMChatPPManager(PPManager):
+    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=PaLMChatPromptFmt, truncate_size: int=None):
+        results = []
+
+        if to_idx == -1 or to_idx >= len(self.pingpongs):
+            to_idx = len(self.pingpongs)
+
+        for idx, pingpong in enumerate(self.pingpongs[from_idx:to_idx]):
+            results += fmt.prompt(pingpong, truncate_size=truncate_size)
+
+        return results
+
+class GradioPaLMChatPPManager(PaLMChatPPManager):
     def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
         if to_idx == -1 or to_idx >= len(self.pingpongs):
             to_idx = len(self.pingpongs)
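The new PaLMChatPPManager.build_prompts flattens the stored ping/pong pairs into the author/content message list that the PaLM chat endpoint expects, one USER and one AI entry per pair. A rough usage sketch follows; PingPong and add_pingpong are assumed from the surrounding pingpong package and are not part of this diff:

    # Sketch only: PingPong(ping, pong) and add_pingpong() are assumed helpers
    # from the surrounding pingpong package; they are not defined in this diff.
    ppm = PaLMChatPPManager()
    ppm.add_pingpong(PingPong("hello", "Hi! How can I help?"))
    ppm.add_pingpong(PingPong("what is PaLM?", None))   # reply not generated yet

    # Two pairs expand to four messages; a missing pong becomes an empty AI turn.
    messages = ppm.build_prompts(truncate_size=1024)
    # [{'author': 'USER', 'content': 'hello'},
    #  {'author': 'AI',   'content': 'Hi! How can I help?'},
    #  {'author': 'USER', 'content': 'what is PaLM?'},
    #  {'author': 'AI',   'content': ''}]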
@@ -47,15 +79,17 @@ def gen_text(
         'top_p': top_p,
     }
 
-    if palm is None:
-        response = palm_api.chat(**parameters, messages=[prompt])
-    else:
-        palm.temperature = parameters['temperature']
-        palm.top_k = parameters['top_k']
-        palm.top_p = parameters['top_p']
+    # if palm is None:
+    #     response = palm_api.chat(**parameters, messages=[prompt])
+    # else:
+    #     palm.temperature = parameters['temperature']
+    #     palm.top_k = parameters['top_k']
+    #     palm.top_p = parameters['top_p']
 
-        response = palm.reply(prompt)
+    #     response = palm.reply(prompt)
 
+    response = palm_api.chat(**parameters, messages=prompt)
+
     if len(response.filters) > 0 and \
         response.filters[0]['reason'] == 2:
         response_txt = "your request is blocked for some reasons"
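With the stateful palm.reply path commented out, gen_text now sends the accumulated message list straight to palm_api.chat, so the prompt argument is expected to be the list produced by build_prompts rather than a single string. A minimal sketch of the resulting call path, assuming palm_api is google.generativeai and parameters carries the sampling settings assembled earlier in gen_text; the non-blocked branch is an assumption about how the reply text would be read, not part of this diff:

    # Sketch of the new call path; `ppm` is a PaLMChatPPManager as above.
    prompt = ppm.build_prompts()
    response = palm_api.chat(**parameters, messages=prompt)

    if len(response.filters) > 0 and response.filters[0]['reason'] == 2:
        # The response carries a filter entry, so a placeholder is returned.
        response_txt = "your request is blocked for some reasons"
    else:
        # Assumption: read the latest model reply from the chat response.
        response_txt = response.last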