AUXteam committed on
Commit
2dea9e7
·
verified ·
1 Parent(s): 68249b4

Upload ai_scientist/llm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai_scientist/llm.py +11 -2
ai_scientist/llm.py CHANGED
@@ -59,6 +59,9 @@ AVAILABLE_LLMS = [
59
  "gemini-2.0-flash-thinking-exp-01-21",
60
  "gemini-2.5-pro-preview-03-25",
61
  "gemini-2.5-pro-exp-03-25",
 
 
 
62
  ]
63
 
64
 
@@ -77,7 +80,7 @@ def get_batch_responses_from_llm(
77
  if msg_history is None:
78
  msg_history = []
79
 
80
- if 'gpt' in model:
81
  new_msg_history = msg_history + [{"role": "user", "content": msg}]
82
  response = client.chat.completions.create(
83
  model=model,
@@ -183,7 +186,7 @@ def get_response_from_llm(
183
  ],
184
  }
185
  ]
186
- elif 'gpt' in model:
187
  new_msg_history = msg_history + [{"role": "user", "content": msg}]
188
  response = client.chat.completions.create(
189
  model=model,
@@ -347,5 +350,11 @@ def create_client(model):
347
  api_key=os.environ["GEMINI_API_KEY"],
348
  base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
349
  ), model
 
 
 
 
 
 
350
  else:
351
  raise ValueError(f"Model {model} not supported.")
 
59
  "gemini-2.0-flash-thinking-exp-01-21",
60
  "gemini-2.5-pro-preview-03-25",
61
  "gemini-2.5-pro-exp-03-25",
62
+ # Helmholtz-Blablador models
63
+ "alias-large",
64
+ "alias-fast",
65
  ]
66
 
67
 
 
80
  if msg_history is None:
81
  msg_history = []
82
 
83
+ if 'gpt' in model or model in ["alias-large", "alias-fast"]:
84
  new_msg_history = msg_history + [{"role": "user", "content": msg}]
85
  response = client.chat.completions.create(
86
  model=model,
 
186
  ],
187
  }
188
  ]
189
+ elif 'gpt' in model or model in ["alias-large", "alias-fast"]:
190
  new_msg_history = msg_history + [{"role": "user", "content": msg}]
191
  response = client.chat.completions.create(
192
  model=model,
 
350
  api_key=os.environ["GEMINI_API_KEY"],
351
  base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
352
  ), model
353
+ elif model in ["alias-large", "alias-fast"]:
354
+ print(f"Using Helmholtz-Blablador with model {model}.")
355
+ return openai.OpenAI(
356
+ api_key=os.environ["BLABLADOR_API_KEY"],
357
+ base_url="https://api.helmholtz-blablador.fz-juelich.de/v1"
358
+ ), model
359
  else:
360
  raise ValueError(f"Model {model} not supported.")