Spaces:
Runtime error
Runtime error
| # Copyright (c) 2024 Microsoft Corporation. | |
| # Licensed under the MIT License | |
| import json | |
| import os | |
| import unittest | |
| from pathlib import Path | |
| from typing import Any | |
| from unittest import mock | |
| from graphrag.config import create_graphrag_config | |
| from graphrag.index import ( | |
| PipelineConfig, | |
| create_pipeline_config, | |
| load_pipeline_config, | |
| ) | |
# Directory holding this test module; the fixture YAML files used by the
# tests below are resolved relative to it.
# os.path.split(p)[0] is, by definition, the same value as os.path.dirname(p).
current_dir = os.path.split(__file__)[0]
class TestLoadPipelineConfig(unittest.TestCase):
    """Tests for ``load_pipeline_config`` and how it merges YAML overrides."""

    def test_config_passed_in_returns_config(self):
        """Passing a ``PipelineConfig`` instance returns an equal config."""
        config = PipelineConfig()
        result = load_pipeline_config(config)
        assert result == config

    def test_loading_default_config_returns_config(self):
        """Loading the ``"default"`` config matches the generated default config."""
        result = load_pipeline_config("default")
        self.assert_is_default_config(result)

    def test_loading_default_config_with_input_overridden(self):
        """A YAML file overriding ``input`` keeps the other sections at their defaults."""
        config = load_pipeline_config(
            str(Path(current_dir) / "default_config_with_overridden_input.yml")
        )

        # Everything except the overridden input section must match the default.
        self.assert_is_default_config(config, check_input=False)

        if config.input is None:
            # self.fail raises the test-failure exception (an error would be
            # reported otherwise) and is typed NoReturn, so type checkers still
            # narrow config.input to non-None below.
            self.fail("Input should not be none")

        # The values from the YAML file must be applied to the input section.
        assert config.input.file_pattern == "test.txt"
        assert config.input.file_type == "text"
        assert config.input.base_dir == "/some/overridden/dir"

    def test_loading_default_config_with_workflows_overridden(self):
        """A YAML file overriding ``workflows`` keeps the other sections at their defaults."""
        config = load_pipeline_config(
            str(Path(current_dir) / "default_config_with_overridden_workflows.yml")
        )

        # Everything except the overridden workflows section must match the default.
        self.assert_is_default_config(config, check_workflows=False)

        # The workflows list must be exactly the one defined in the YAML file.
        assert len(config.workflows) == 1
        assert config.workflows[0].name == "TEST_WORKFLOW"
        assert config.workflows[0].steps is not None
        assert len(config.workflows[0].steps) == 1  # type: ignore
        assert config.workflows[0].steps[0]["verb"] == "TEST_VERB"  # type: ignore

    def assert_is_default_config(
        self,
        config: Any,
        check_input: bool = True,
        check_storage: bool = True,
        check_reporting: bool = True,
        check_cache: bool = True,
        check_workflows: bool = True,
    ):
        """Assert that ``config`` agrees with the generated default pipeline config.

        Both configs are dumped with ``exclude_defaults=True`` and
        ``exclude_unset=True`` and compared as plain dicts. The comparison is a
        subset check: every top-level key remaining in ``config``'s dump must
        exist in the default dump with an equal value, while the default dump
        may contain additional keys.

        Args:
            config: The object under test; asserted to be a ``PipelineConfig``.
            check_input, check_storage, check_reporting, check_cache,
            check_workflows: When False, that top-level section is removed from
                both dumps before comparing (for sections a test overrides).
        """
        assert config is not None
        assert isinstance(config, PipelineConfig)

        checked_config = json.loads(
            config.model_dump_json(exclude_defaults=True, exclude_unset=True)
        )
        actual_default_config = json.loads(
            create_pipeline_config(
                create_graphrag_config(root_dir=".")
            ).model_dump_json(exclude_defaults=True, exclude_unset=True)
        )

        # "root_dir" and "extends" are never compared; each other section is
        # dropped from the comparison only when its check flag is False.
        section_checks = {
            "workflows": check_workflows,
            "input": check_input,
            "storage": check_storage,
            "reporting": check_reporting,
            "cache": check_cache,
        }
        props_to_ignore = ["root_dir", "extends"]
        props_to_ignore.extend(
            prop for prop, should_check in section_checks.items() if not should_check
        )
        for prop in props_to_ignore:
            checked_config.pop(prop, None)
            actual_default_config.pop(prop, None)

        # Subset check: merging checked_config into the default must leave the
        # default unchanged, i.e. every checked key exists there with the same value.
        assert actual_default_config == actual_default_config | checked_config

    def setUp(self) -> None:
        """Set placeholder API-key environment variables for each test.

        The keys are presumably required by ``create_graphrag_config`` (TODO
        confirm). They are applied with ``mock.patch.dict`` and restored by a
        registered cleanup, so they no longer leak into other tests in the
        same process.
        """
        env_patcher = mock.patch.dict(
            os.environ,
            {
                "GRAPHRAG_OPENAI_API_KEY": "test",
                "GRAPHRAG_OPENAI_EMBEDDING_API_KEY": "test",
            },
        )
        env_patcher.start()
        # Cleanups run even when the test fails, restoring os.environ.
        self.addCleanup(env_patcher.stop)
        return super().setUp()