shmuelamar committed on
Commit
9126461
·
unverified ·
1 Parent(s): 220010b

use the GPU only inside the @spaces.GPU-decorated function

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -68,7 +68,15 @@ MAX_PROMPT_TOKENS = 256
68
 
69
 
70
  @spaces.GPU
71
- def completion(prompt: str, model, tokenizer):
 
 
 
 
 
 
 
 
72
  # tokenize
73
  input_ids = tokenizer.apply_chat_template(
74
  [
@@ -93,6 +101,12 @@ def completion(prompt: str, model, tokenizer):
93
  top_p=None,
94
  temperature=None,
95
  )
 
 
 
 
 
 
96
  return tokenizer.decode(outputs[0][input_ids.shape[-1] :], skip_special_tokens=True)
97
 
98
 
@@ -107,16 +121,6 @@ def completion_openrouter(prompt: str, model_id: str):
107
  return resp.choices[0].message.content
108
 
109
 
110
- # @functools.cache
111
- def load_model_and_tokenizer(model_id: str):
112
- logger.info(f"loading local model and tokenizer for {model_id}")
113
- tokenizer = AutoTokenizer.from_pretrained(model_id)
114
- dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
115
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="auto")
116
- logger.info(f"done loading {model_id}")
117
- return model, tokenizer
118
-
119
-
120
  def load_openrouter_client():
121
  logger.info(f"connecting to OpenRouter")
122
  return OpenAI(
@@ -135,13 +139,7 @@ def get_completion(*, prompt: str, model_id: str):
135
  if model_id.startswith("api:"):
136
  return completion_openrouter(prompt, model_id.removeprefix("api:"))
137
  else:
138
- model, tokenizer = load_model_and_tokenizer(model_id)
139
- resp = completion(prompt, model, tokenizer)
140
-
141
- # cleanup memory
142
- del model, tokenizer
143
- torch.cuda.empty_cache()
144
- gc.collect()
145
  return resp
146
 
147
 
 
68
 
69
 
70
  @spaces.GPU
71
+ def completion(prompt: str, model_id: str):
72
+ # load model and tokenizer
73
+ logger.info(f"loading local model and tokenizer for {model_id}")
74
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
75
+
76
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
77
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="auto")
78
+ logger.info(f"done loading {model_id}")
79
+
80
  # tokenize
81
  input_ids = tokenizer.apply_chat_template(
82
  [
 
101
  top_p=None,
102
  temperature=None,
103
  )
104
+
105
+ # cleanup memory
106
+ del model, tokenizer
107
+ torch.cuda.empty_cache()
108
+ gc.collect()
109
+
110
  return tokenizer.decode(outputs[0][input_ids.shape[-1] :], skip_special_tokens=True)
111
 
112
 
 
121
  return resp.choices[0].message.content
122
 
123
 
 
 
 
 
 
 
 
 
 
 
124
  def load_openrouter_client():
125
  logger.info(f"connecting to OpenRouter")
126
  return OpenAI(
 
139
  if model_id.startswith("api:"):
140
  return completion_openrouter(prompt, model_id.removeprefix("api:"))
141
  else:
142
+ resp = completion(prompt, model_id)
 
 
 
 
 
 
143
  return resp
144
 
145