Update configuration_deepseekocr.py
Browse files
- configuration_deepseekocr.py +14 -24
configuration_deepseekocr.py
CHANGED
|
@@ -2,8 +2,6 @@
|
|
| 2 |
# ------------------------------------------------------------
|
| 3 |
# Configuration class for the Deepseek-OCR model
|
| 4 |
# ------------------------------------------------------------
|
| 5 |
-
|
| 6 |
-
from transformers.configuration_utils import PretrainedConfig
|
| 7 |
from transformers.utils import logging
|
| 8 |
from .configuration_deepseek_v2 import DeepseekV2Config
|
| 9 |
|
|
@@ -15,13 +13,9 @@ class DeepseekOCRConfig(DeepseekV2Config):
|
|
| 15 |
"""
|
| 16 |
Config for Deepseek-OCR.
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
This lets DeepseekOCRModel (which subclasses DeepseekV2Model)
|
| 23 |
-
see ALL the attributes it expects (hidden_act, attention_bias, etc.)
|
| 24 |
-
while still letting us keep multimodal metadata.
|
| 25 |
"""
|
| 26 |
|
| 27 |
model_type = "deepseekocr"
|
|
@@ -35,34 +29,30 @@ class DeepseekOCRConfig(DeepseekV2Config):
|
|
| 35 |
projector_config=None,
|
| 36 |
vision_config=None,
|
| 37 |
language_config=None,
|
| 38 |
-
torch_dtype="bfloat16",
|
| 39 |
**kwargs,
|
| 40 |
):
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
num_hidden_layers, hidden_act, attention_bias, etc.) are passed via
|
| 44 |
-
**kwargs and handled by DeepseekV2Config.__init__.
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
# If a nested language_config is provided (like in your config.json),
|
| 48 |
-
# use it as a base and let top-level kwargs override it.
|
| 49 |
if language_config is not None and isinstance(language_config, dict):
|
| 50 |
-
base = dict(language_config)
|
| 51 |
-
base.update(kwargs)
|
| 52 |
kwargs = base
|
| 53 |
|
| 54 |
-
# Let DeepseekV2Config
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
# OCR-specific
|
| 58 |
self.candidate_resolutions = candidate_resolutions or [[1024, 1024]]
|
| 59 |
self.global_view_pos = global_view_pos
|
| 60 |
self.tile_tag = tile_tag
|
| 61 |
|
| 62 |
-
# Keep
|
| 63 |
self.projector_config = projector_config
|
| 64 |
self.vision_config = vision_config
|
| 65 |
self.language_config = language_config
|
| 66 |
|
| 67 |
logger.info("✅ DeepseekOCRConfig initialized (inherits DeepseekV2Config).")
|
| 68 |
|
|
|
|
|
|
| 2 |
# ------------------------------------------------------------
|
| 3 |
# Configuration class for the Deepseek-OCR model
|
| 4 |
# ------------------------------------------------------------
|
|
|
|
|
|
|
| 5 |
from transformers.utils import logging
|
| 6 |
from .configuration_deepseek_v2 import DeepseekV2Config
|
| 7 |
|
|
|
|
| 13 |
"""
|
| 14 |
Config for Deepseek-OCR.
|
| 15 |
|
| 16 |
+
Inherits all language-model fields from DeepseekV2Config
|
| 17 |
+
(hidden_size, hidden_act, attention_bias, etc.) and adds
|
| 18 |
+
OCR / vision specific metadata.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"""
|
| 20 |
|
| 21 |
model_type = "deepseekocr"
|
|
|
|
| 29 |
projector_config=None,
|
| 30 |
vision_config=None,
|
| 31 |
language_config=None,
|
|
|
|
| 32 |
**kwargs,
|
| 33 |
):
|
| 34 |
+
# If a nested language_config dict is provided in config.json,
|
| 35 |
+
# merge it into kwargs so DeepseekV2Config sees all LM params.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
if language_config is not None and isinstance(language_config, dict):
|
| 37 |
+
base = dict(language_config) # copy
|
| 38 |
+
base.update(kwargs) # top-level overrides nested
|
| 39 |
kwargs = base
|
| 40 |
|
| 41 |
+
# Let DeepseekV2Config handle all core model parameters.
|
| 42 |
+
# NOTE: we do NOT pass torch_dtype explicitly here; it will be
|
| 43 |
+
# picked up from kwargs if present, so there is no "multiple values" error.
|
| 44 |
+
super().__init__(**kwargs)
|
| 45 |
|
| 46 |
+
# Store OCR-specific attributes
|
| 47 |
self.candidate_resolutions = candidate_resolutions or [[1024, 1024]]
|
| 48 |
self.global_view_pos = global_view_pos
|
| 49 |
self.tile_tag = tile_tag
|
| 50 |
|
| 51 |
+
# Keep sub-configs around for the modeling code
|
| 52 |
self.projector_config = projector_config
|
| 53 |
self.vision_config = vision_config
|
| 54 |
self.language_config = language_config
|
| 55 |
|
| 56 |
logger.info("✅ DeepseekOCRConfig initialized (inherits DeepseekV2Config).")
|
| 57 |
|
| 58 |
+
|