| | import argparse |
| | import os |
| | import importlib |
| |
|
| | def check_plugins(loaded_plugins): |
| | print("Loaded plugins:") |
| | for plugin in loaded_plugins: |
| | print(f"- {plugin}") |
| |
|
| |
|
| | def train_model(dataset_name, plugins): |
| | dataset = {'train': []} |
| |
|
| | model = "FlowModel" |
| |
|
| | for plugin in plugins: |
| | if hasattr(plugin, 'modify_model'): |
| | model = plugin.modify_model(model) |
| |
|
| | for plugin in plugins: |
| | if hasattr(plugin, 'on_train_start'): |
| | plugin.on_train_start() |
| |
|
| | print(f"Training started on dataset: {dataset_name}") |
| |
|
| | for plugin in plugins: |
| | if hasattr(plugin, 'on_train_end'): |
| | plugin.on_train_end() |
| |
|
| | print("Training finished.") |
| |
|
| |
|
| | def load_plugins(): |
| | plugins_dir = './plugins' |
| | plugins = [] |
| |
|
| | if not os.path.exists(plugins_dir): |
| | os.makedirs(plugins_dir) |
| | print(f"Plugins directory created at {plugins_dir}. Add your plugins there!") |
| |
|
| | for filename in os.listdir(plugins_dir): |
| | if filename.endswith('.py') and filename != '__init__.py': |
| | plugin_name = filename[:-3] |
| | try: |
| | plugin_module = importlib.import_module(f'plugins.{plugin_name}') |
| | plugin_class = getattr(plugin_module, plugin_name.title().replace('_', ''), None) |
| | if plugin_class: |
| | plugins.append(plugin_class()) |
| | print(f"Plugin {plugin_name} loaded.") |
| | else: |
| | print(f"No class found in plugin {plugin_name}.") |
| | except Exception as e: |
| | print(f"Failed to load plugin {plugin_name}: {e}") |
| |
|
| | return plugins |
| |
|
| |
|
| | def predict_model(plugins): |
| | print("Prediction started.") |
| | for plugin in plugins: |
| | if hasattr(plugin, 'on_predict'): |
| | plugin.on_predict() |
| | print("Prediction finished.") |
| |
|
| |
|
| | def main(): |
| | parser = argparse.ArgumentParser(description="FlowModel CLI") |
| | parser.add_argument('command', choices=['train', 'predict', 'check_plugins'], help="Command to run") |
| |
|
| | args = parser.parse_args() |
| |
|
| | plugins, loaded_plugins = load_plugins() |
| |
|
| | if args.command == 'train': |
| | plugins = load_plugins() |
| | train_model("mnist", plugins) |
| | elif args.command == 'predict': |
| | predict_model(plugins) |
| | elif args.command == 'check_plugins': |
| | check_plugins(loaded_plugins) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|