sitammeur commited on
Commit
435ab6f
·
verified ·
1 Parent(s): 63ee80c

Update src/paligemma/model.py

Browse files
Files changed (1) hide show
  1. src/paligemma/model.py +53 -53
src/paligemma/model.py CHANGED
@@ -1,53 +1,53 @@
1
- # Necessary imports
2
- import os
3
- import sys
4
- from dotenv import load_dotenv
5
- from typing import Any
6
- import torch
7
- from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
8
-
9
- # Local imports
10
- from src.logger import logging
11
- from src.exception import CustomExceptionHandling
12
-
13
-
14
- # Load the Environment Variables from .env file
15
- load_dotenv()
16
-
17
- # Access token for using the model
18
- access_token = os.environ.get("ACCESS_TOKEN")
19
-
20
-
21
- def load_model_and_processor(model_name: str, device: str) -> Any:
22
- """
23
- Load the model and processor.
24
-
25
- Args:
26
- - model_name (str): The name of the model to load.
27
- - device (str): The device to load the model onto.
28
-
29
- Returns:
30
- - model: The loaded model.
31
- - processor: The loaded processor.
32
- """
33
- try:
34
- # Load the model and processor
35
- model = (
36
- PaliGemmaForConditionalGeneration.from_pretrained(
37
- model_name, torch_dtype=torch.bfloat16, token=access_token
38
- )
39
- .eval()
40
- .to(device)
41
- )
42
- processor = PaliGemmaProcessor.from_pretrained(model_name, token=access_token)
43
-
44
- # Log the successful loading of the model and processor
45
- logging.info("Model and processor loaded successfully.")
46
-
47
- # Return the model and processor
48
- return model, processor
49
-
50
- # Handle exceptions that may occur during model and processor loading
51
- except Exception as e:
52
- # Custom exception handling
53
- raise CustomExceptionHandling(e, sys) from e
 
1
+ # Necessary imports
2
+ import os
3
+ import sys
4
+ from dotenv import load_dotenv
5
+ from typing import Any
6
+ import torch
7
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
8
+
9
+ # Local imports
10
+ from src.logger import logging
11
+ from src.exception import CustomExceptionHandling
12
+
13
+
14
+ # Load the Environment Variables from .env file
15
+ load_dotenv()
16
+
17
+ # Access token for using the model
18
+ access_token = os.environ.get("ACCESS_TOKEN")
19
+
20
+
21
+ def load_model_and_processor(model_name: str, device: str) -> Any:
22
+ """
23
+ Load the model and processor.
24
+
25
+ Args:
26
+ - model_name (str): The name of the model to load.
27
+ - device (str): The device to load the model onto.
28
+
29
+ Returns:
30
+ - model: The loaded model.
31
+ - processor: The loaded processor.
32
+ """
33
+ try:
34
+ # Load the model and processor
35
+ model = (
36
+ PaliGemmaForConditionalGeneration.from_pretrained(
37
+ model_name, dtype=torch.bfloat16, token=access_token
38
+ )
39
+ .eval()
40
+ .to(device)
41
+ )
42
+ processor = PaliGemmaProcessor.from_pretrained(model_name, use_fast=True, token=access_token)
43
+
44
+ # Log the successful loading of the model and processor
45
+ logging.info("Model and processor loaded successfully.")
46
+
47
+ # Return the model and processor
48
+ return model, processor
49
+
50
+ # Handle exceptions that may occur during model and processor loading
51
+ except Exception as e:
52
+ # Custom exception handling
53
+ raise CustomExceptionHandling(e, sys) from e