Coverage for birdplan/plugin.py: 80%
108 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-23 03:27 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-23 03:27 +0000
1#
2# SPDX-License-Identifier: GPL-3.0-or-later
3#
4# Copyright (C) 2019-2024, AllWorldIT.
5#
6# This program is free software: you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation, either version 3 of the License, or
9# (at your option) any later version.
10#
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with this program. If not, see <https://www.gnu.org/licenses/>.
19"""Plugin handler."""
21import inspect
22import logging
23import os
24import pkgutil
25from typing import Any, Dict, List, Optional
27__all__ = ["PluginMethodException", "PluginNotFoundException", "Plugin", "PluginCollection"]
class PluginMethodException(RuntimeError):
    """Error signalling that a plugin does not provide the method being called."""
class PluginNotFoundException(RuntimeError):
    """Error signalling that no plugin could be found for the requested plugin name or method."""
class Plugin:  # pylint: disable=too-few-public-methods
    """Common ancestor of every plugin; plugin classes derive from this class."""

    # Human readable description of the plugin
    plugin_description: str
    # Ordering key; PluginCollection visits plugins in ascending plugin_order
    plugin_order: int

    def __init__(self) -> None:
        """Give the plugin its default description (the class name) and default order (10)."""

        plugin_class = type(self)

        # Defaults, subclasses may overwrite these after calling super().__init__()
        self.plugin_order = 10
        self.plugin_description = plugin_class.__name__
class PluginCollection:
    """
    Collection of plugins discovered in one or more plugin packages.

    Upon loading, each plugin class found is instantiated as an object and stored under the
    fully qualified name of the module it was found in.

    Parameters
    ----------
    plugin_packages : List[str]
        Package names to load plugins from.

    """

    # The package names we will be loading plugins from
    _plugin_packages: List[str]
    # Plugins we've loaded, keyed by the fully qualified module name
    _plugins: Dict[str, Plugin]
    # List of paths we've seen during processing
    _seen_paths: List[str]
    # Plugin statuses, keyed by the fully qualified module name
    _plugin_status: Dict[str, str]

    def __init__(self, plugin_packages: List[str]):
        """
        Initialize PluginCollection using a list of plugin base packages.

        Classes with a name ending in 'Base' will not be loaded.

        Parameters
        ----------
        plugin_packages : List[str]
            Package names to load plugins from.

        """

        # Setup object
        self._plugin_packages = plugin_packages
        self._plugins = {}
        self._seen_paths = []
        self._plugin_status = {}

        # Load plugins
        self._load_plugins()

    def call_if_exists(self, method_name: str, args: Any = None) -> Dict[str, Any]:
        """
        Call a plugin method, but do not raise an exception if it does not exist.

        Parameters
        ----------
        method_name : str
            Method name to call.

        args : Any
            Method argument(s).

        Returns
        -------
        Dict containing the plugin name and its result. Plugins without the method are omitted.

        """

        logging.debug("Calling method '%s' if exists", method_name)

        return self.call(method_name, args, skip_not_found=True)

    def call(self, method_name: str, args: Any = None, skip_not_found: bool = False) -> Dict[str, Any]:
        """
        Call a plugin method on every plugin, in ascending plugin_order.

        Parameters
        ----------
        method_name : str
            Method name to call.

        args : Any
            Method argument(s), passed to the method as a single positional argument.

        skip_not_found : bool
            If True, plugins which do not have the method are skipped instead of raising.

        Returns
        -------
        Dict containing the plugin name and its result.

        Raises
        ------
        PluginMethodException
            If a plugin does not have the method and skip_not_found is False.

        """

        results = {}
        # Loop through plugins sorted
        for plugin_name in self._sorted_plugin_names():
            plugin = self.plugins[plugin_name]
            # Check if we're going to raise an exception or just skip
            if not hasattr(plugin, method_name):
                if skip_not_found:
                    logging.debug("Method '%s' does not exist in plugin '%s'", method_name, plugin_name)
                    continue
                raise PluginMethodException(f'Plugin "{plugin_name}" has no method "{method_name}"')
            # Save the result
            results[plugin_name] = self.call_plugin(plugin_name, method_name, args)

        return results

    def get_first(self, method_name: str) -> Optional[str]:
        """
        Get the name of the first plugin, in ascending plugin_order, that has a specific method.

        Parameters
        ----------
        method_name : str
            Method name to look for.

        Returns
        -------
        Name of the first plugin which has the method, or None if no plugin has it.

        """

        # Loop through plugins sorted
        for plugin_name in self._sorted_plugin_names():
            # Return the first plugin which has the method
            if hasattr(self.plugins[plugin_name], method_name):
                return plugin_name

        return None

    def call_first(self, method_name: str, args: Any = None) -> Any:
        """
        Call the method on the first plugin, in ascending plugin_order, that has it.

        Parameters
        ----------
        method_name : str
            Method name to call.

        args : Any
            Method argument(s), passed to the method as a single positional argument.

        Returns
        -------
        Any containing the result.

        Raises
        ------
        PluginNotFoundException
            If no plugin has the method.

        """

        # Get first plugin which has our method
        plugin_name = self.get_first(method_name)

        # Make sure we got a plugin back
        if not plugin_name:
            raise PluginNotFoundException(f"No plugin found for method name '{method_name}'")

        # Return the result of the method call on the first plugin
        return self.call_plugin(plugin_name, method_name, args)

    def call_plugin(self, plugin_name: str, method_name: str, args: Any = None) -> Any:
        """
        Call a specific plugin and its method.

        Parameters
        ----------
        plugin_name : str
            Plugin to call the method in.

        method_name : str
            Method name to call.

        args : Any
            Method argument(s), passed to the method as a single positional argument.

        Returns
        -------
        Any containing the plugin call result.

        Raises
        ------
        PluginNotFoundException
            If the plugin is not loaded.

        PluginMethodException
            If the plugin does not have the method.

        """

        # Grab the plugin, this raises PluginNotFoundException with a consistent message if it does not exist
        plugin = self.get(plugin_name)

        # Grab the method, a missing attribute means the plugin does not implement it
        try:
            method = getattr(plugin, method_name)
        except AttributeError as err:
            raise PluginMethodException(f'Plugin "{plugin_name}" has no method "{method_name}"') from err

        # Call it
        logging.debug("Calling method '%s' from plugin '%s'", method_name, plugin_name)
        return method(args)

    def get(self, plugin_name: str) -> Plugin:
        """
        Get a specific plugin object.

        Parameters
        ----------
        plugin_name : str
            Name of the plugin to return.

        Returns
        -------
        Plugin object.

        Raises
        ------
        PluginNotFoundException
            If the plugin is not loaded.

        """

        if plugin_name not in self.plugins:
            raise PluginNotFoundException(f'Plugin "{plugin_name}" not found')

        return self.plugins[plugin_name]

    #
    # Internals
    #

    def _sorted_plugin_names(self) -> List[str]:
        """Return the loaded plugin names in ascending plugin_order; ties keep their load order."""

        plugins = self.plugins
        return sorted(plugins, key=lambda plugin_name: plugins[plugin_name].plugin_order)

    def _load_plugins(self) -> None:
        """Load plugins from the plugin packages we were provided."""

        # Load plugin packages
        for plugin_package in self._plugin_packages:
            self._find_plugins(plugin_package)

    def _find_plugins(self, package_name: str) -> None:  # pylint: disable=too-many-branches
        """
        Recursively search a package and load all plugins found in it.

        Parameters
        ----------
        package_name : str
            Package to load plugins from.

        """

        logging.debug("Finding plugins from '%s'", package_name)

        # A non-empty fromlist makes __import__ return the named package itself rather than the top-level package
        imported_package = __import__(package_name, fromlist=["__VERSION__"])

        # Iterate through the modules, names are prefixed with the package name so they are fully qualified
        for _, plugin_name, _ in pkgutil.iter_modules(imported_package.__path__, imported_package.__name__ + "."):
            # Try import, a module that cannot be found is recorded in the status and skipped
            try:
                plugin_module = __import__(plugin_name, fromlist=["__VERSION__"])
            except ModuleNotFoundError as err:
                self._plugin_status[plugin_name] = f"cannot load module: {err}"
                continue

            # Loop with the classes visible in the module
            for _, plugin_class in inspect.getmembers(plugin_module, inspect.isclass):
                # Only add classes that are a sub class of Plugin, skipping Plugin itself and '*Base' classes
                if (
                    not issubclass(plugin_class, Plugin)
                    or (plugin_class is Plugin)
                    or plugin_class.__name__.endswith("Base")
                ):
                    continue
                # Save plugin and record that it was loaded
                # NOTE: plugins are keyed by module name, so a later matching class in the same module replaces an
                # earlier one
                self._plugins[plugin_name] = plugin_class()
                self._plugin_status[plugin_name] = "loaded"
                logging.debug("Plugin loaded '%s' [class=%s]", plugin_name, plugin_class)

        # Look for modules in sub packages
        all_current_paths: List[str] = []

        if isinstance(imported_package.__path__, str):
            all_current_paths.append(imported_package.__path__)
        else:
            all_current_paths.extend(imported_package.__path__)

        # Loop with package path
        for pkg_path in all_current_paths:
            # Make sure its not seen in our seen_paths
            if pkg_path in self._seen_paths:
                continue
            # If not add it so we don't process it again
            self._seen_paths.append(pkg_path)

            # Grab all the sub directories of the current package path directory, ignoring hidden entries and
            # __pycache__. os.listdir() returns entries in arbitrary order, so sort them to keep the discovery
            # order deterministic.
            sub_dirs = [
                entry
                for entry in sorted(os.listdir(pkg_path))
                if not entry.startswith(".")
                and entry != "__pycache__"
                and os.path.isdir(os.path.join(pkg_path, entry))
            ]

            # Find packages in sub directory
            for sub_dir in sub_dirs:
                self._find_plugins(f"{package_name}.{sub_dir}")

    @property
    def plugins(self) -> Dict[str, Plugin]:
        """
        Property containing the dictionary of plugins loaded.

        Returns
        -------
        Dict[str, Plugin], keyed by plugin name.

        """

        return self._plugins

    @property
    def plugin_status(self) -> Dict[str, str]:
        """
        Property containing the plugin load status.

        Returns
        -------
        Dict[str, str], keyed by plugin name.

        """

        return self._plugin_status