Coverage for birdplan/plugin.py: 80%

108 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-23 03:27 +0000

1# 

2# SPDX-License-Identifier: GPL-3.0-or-later 

3# 

4# Copyright (C) 2019-2024, AllWorldIT. 

5# 

6# This program is free software: you can redistribute it and/or modify 

7# it under the terms of the GNU General Public License as published by 

8# the Free Software Foundation, either version 3 of the License, or 

9# (at your option) any later version. 

10# 

11# This program is distributed in the hope that it will be useful, 

12# but WITHOUT ANY WARRANTY; without even the implied warranty of 

13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

14# GNU General Public License for more details. 

15# 

16# You should have received a copy of the GNU General Public License 

17# along with this program. If not, see <https://www.gnu.org/licenses/>. 

18 

19"""Plugin handler.""" 

20 

21import inspect 

22import logging 

23import os 

24import pkgutil 

25from typing import Any, Dict, List, Optional 

26 

27__all__ = ["PluginMethodException", "PluginNotFoundException", "Plugin", "PluginCollection"] 

28 

29 

30class PluginMethodException(RuntimeError): 

31 """Plugin method exception raised when a method is called that does not exist.""" 

32 

33 

34class PluginNotFoundException(RuntimeError): 

35 """Plugin not found exception raised when a plugin is referenced by name and not found.""" 

36 

37 

38class Plugin: # pylint: disable=too-few-public-methods 

39 """Base plugin class, used as the parent for all plugins we define.""" 

40 

41 plugin_description: str 

42 plugin_order: int 

43 

44 def __init__(self) -> None: 

45 """Plugin __init__ method.""" 

46 

47 # Set defaults 

48 self.plugin_description = type(self).__name__ 

49 self.plugin_order = 10 

50 

51 

52class PluginCollection: 

53 """ 

54 Initialize PluginCollection using a plugin base package. 

55 

56 Apon loading each plugin will be instantiated as an object. 

57 

58 Parameters 

59 ---------- 

60 plugin_package : str 

61 Source plan file to generate configuration from. 

62 

63 """ 

64 

65 # The package name we will be loading plugins from 

66 _plugin_packages: List[str] 

67 # List of plugins we've loaded 

68 _plugins: Dict[str, Plugin] 

69 # List of paths we've seen during processing 

70 _seen_paths: List[str] 

71 # Plugin statuses 

72 _plugin_status: Dict[str, str] 

73 

74 def __init__(self, plugin_packages: List[str]): 

75 """ 

76 Initialize Plugincollection using a plugin base package. 

77 

78 Classes with a name ending in 'Base' will not be loaded. 

79 

80 Parameters 

81 ---------- 

82 plugin_packages : List[str] 

83 Package names to load plugins from. 

84 

85 """ 

86 

87 # Setup object 

88 self._plugin_packages = plugin_packages 

89 self._plugins = {} 

90 self._seen_paths = [] 

91 self._plugin_status = {} 

92 

93 # Load plugins 

94 self._load_plugins() 

95 

96 def call_if_exists(self, method_name: str, args: Any = None) -> Dict[str, Any]: 

97 """ 

98 Call a plugin method, but do not raise an exception if it does not exist. 

99 

100 Parameters 

101 ---------- 

102 method_name : str 

103 Method name to call. 

104 

105 args : Any 

106 Method argument(s). 

107 

108 Returns 

109 ------- 

110 Dict containing the module name and its result. 

111 

112 """ 

113 

114 logging.debug("Calling method '%s' if exists", method_name) 

115 

116 return self.call(method_name, args, skip_not_found=True) 

117 

118 def call(self, method_name: str, args: Any = None, skip_not_found: bool = False) -> Dict[str, Any]: 

119 """ 

120 Call a plugin method. 

121 

122 Parameters 

123 ---------- 

124 method_name : str 

125 Method name to call. 

126 

127 kwargs : Any 

128 Method arguments. 

129 

130 args : Any 

131 Method argument(s). 

132 

133 skip_not_found : 

134 If the method is not found return None. 

135 

136 Returns 

137 ------- 

138 Dict containing the module name and its result. 

139 

140 """ 

141 

142 # Loop with plugins, if they have overridden the method, then call it 

143 results = {} 

144 # Loop through plugins sorted 

145 for plugin_name, plugin in sorted(self.plugins.items(), key=lambda kv: kv[1].plugin_order): 

146 # Check if we're going to raise an exception or just skip 

147 if not hasattr(plugin, method_name): 

148 if skip_not_found: 

149 logging.debug("Method '%s' does not exist in plugin '%s'", method_name, plugin_name) 

150 continue 

151 raise PluginMethodException(f'Plugin "{plugin_name}" has no method "{method_name}"') 

152 # Save the result 

153 results[plugin_name] = self.call_plugin(plugin_name, method_name, args) 

154 

155 return results 

156 

157 def get_first(self, method_name: str) -> Optional[str]: 

158 """ 

159 Get the first plugin method found that matches a specific method name. 

160 

161 Parameters 

162 ---------- 

163 method_name : str 

164 Method name to call. 

165 

166 Returns 

167 ------- 

168 Any containing the result. 

169 

170 """ 

171 

172 # Loop through plugins sorted 

173 for plugin_name, plugin in sorted(self.plugins.items(), key=lambda kv: kv[1].plugin_order): 

174 # Check if we're skipping this one if the method is not found 

175 if not hasattr(plugin, method_name): 

176 continue 

177 # Return the first result we get 

178 return plugin_name 

179 

180 return None 

181 

182 def call_first(self, method_name: str, args: Any = None) -> Any: 

183 """ 

184 Call the first plugin method found. 

185 

186 Parameters 

187 ---------- 

188 method_name : str 

189 Method name to call. 

190 

191 kwargs : Any 

192 Method arguments. 

193 

194 args : Any 

195 Method argument(s). 

196 

197 Returns 

198 ------- 

199 Any containing the result. 

200 

201 """ 

202 

203 # Get first plugin which has our method 

204 plugin_name = self.get_first(method_name) 

205 

206 # Make sure we got a plugin back 

207 if not plugin_name: 

208 raise PluginNotFoundException(f"No plugin found for method name '{method_name}'") 

209 

210 # Return the result of the method call on the first plugin 

211 return self.call_plugin(plugin_name, method_name, args) 

212 

213 def call_plugin(self, plugin_name: str, method_name: str, args: Any = None) -> Any: 

214 """ 

215 Call a specific plugin and its method. 

216 

217 Parameters 

218 ---------- 

219 plugin_name : str 

220 Plugin to call the method in. 

221 

222 method_name : str 

223 Method name to call. 

224 

225 args : Any 

226 Method argument(s). 

227 

228 Returns 

229 ------- 

230 Any containing the plugin call result. 

231 

232 """ 

233 

234 # Check if plugin exists 

235 if plugin_name not in self.plugins: 

236 raise PluginNotFoundException(f'Plugin "{plugin_name}"" not found') 

237 # If it does then grab it 

238 plugin = self.plugins[plugin_name] 

239 

240 # Check if we're going to raise an exception or just skip 

241 if not hasattr(plugin, method_name): 

242 raise PluginMethodException(f'Plugin "{plugin_name}" has no method "{method_name}"') 

243 

244 # Grab the method 

245 method = getattr(plugin, method_name) 

246 

247 # Call it 

248 logging.debug("Calling method '%s' from plugin '%s'", method_name, plugin_name) 

249 return method(args) 

250 

251 def get(self, plugin_name: str) -> Plugin: 

252 """ 

253 Get a specific plugin object. 

254 

255 Parameters 

256 ---------- 

257 plugin_name : str 

258 Plugin to call the method in. 

259 

260 Returns 

261 ------- 

262 Plugin object. 

263 

264 """ 

265 

266 if plugin_name not in self.plugins: 

267 raise PluginNotFoundException(f'Plugin "{plugin_name}" not found') 

268 

269 return self.plugins[plugin_name] 

270 

271 # 

272 # Internals 

273 # 

274 

275 def _load_plugins(self) -> None: 

276 """Load plugins from the plugin_package we were provided.""" 

277 

278 # Load plugin packages 

279 for plugin_package in self._plugin_packages: 

280 self._find_plugins(plugin_package) 

281 

282 def _find_plugins(self, package_name: str) -> None: # pylint: disable=too-many-branches 

283 """ 

284 Recursively search the plugin_package and retrieve all plugins. 

285 

286 Parameters 

287 ---------- 

288 package_name : str 

289 Package to load plugins from. 

290 

291 """ 

292 

293 logging.debug("Finding plugins from '%s'", package_name) 

294 

295 imported_package = __import__(package_name, fromlist=["__VERSION__"]) 

296 

297 # Iterate through the modules 

298 for _, plugin_name, _ in pkgutil.iter_modules(imported_package.__path__, imported_package.__name__ + "."): 

299 # Try import 

300 try: 

301 plugin_module = __import__(plugin_name, fromlist=["__VERSION__"]) 

302 except ModuleNotFoundError as err: 

303 self._plugin_status[plugin_name] = f"cannot load module: {err}" 

304 continue 

305 

306 # Grab object members 

307 object_members = inspect.getmembers(plugin_module, inspect.isclass) 

308 

309 # Loop with class names 

310 for _, class_name in object_members: 

311 # Only add classes that are a sub class of Plugin 

312 if not issubclass(class_name, Plugin) or (class_name is Plugin) or class_name.__name__.endswith("Base"): 

313 continue 

314 # Save plugin and record that it was loaded 

315 self._plugins[plugin_name] = class_name() 

316 self._plugin_status[plugin_name] = "loaded" 

317 logging.debug("Plugin loaded '%s' [class=%s]", plugin_name, class_name) 

318 

319 # Look for modules in sub packages 

320 all_current_paths: List[str] = [] 

321 

322 if isinstance(imported_package.__path__, str): 

323 all_current_paths.append(imported_package.__path__) 

324 else: 

325 all_current_paths.extend(imported_package.__path__) 

326 

327 # Loop with package path 

328 for pkg_path in all_current_paths: 

329 # Make sure its not seen in our seen_paths 

330 if pkg_path in self._seen_paths: 

331 continue 

332 # If not add it so we don't process it again 

333 self._seen_paths.append(pkg_path) 

334 

335 # Grab all the sub directories of the current package path directory 

336 sub_dirs = [] 

337 for sub_dir in os.listdir(pkg_path): 

338 # If the subdir starts with a ., ignore it 

339 if sub_dir.startswith("."): 

340 continue 

341 # If the subdir is __pycache__, ignore it 

342 if sub_dir == "__pycache__": 

343 continue 

344 # If this is not a sub dir, then move onto the next one 

345 if not os.path.isdir(os.path.join(pkg_path, sub_dir)): 

346 continue 

347 # Add sub-directory 

348 sub_dirs.append(sub_dir) 

349 

350 # Find packages in sub directory 

351 for sub_dir in sub_dirs: 

352 module = f"{package_name}.{sub_dir}" 

353 self._find_plugins(module) 

354 

355 @property 

356 def plugins(self) -> Dict[str, Plugin]: 

357 """ 

358 Property containing the dictionary of plugins loaded. 

359 

360 Returns 

361 ------- 

362 Dict[str, Plugin], keyed by plugin name. 

363 

364 """ 

365 

366 return self._plugins 

367 

368 @property 

369 def plugin_status(self) -> Dict[str, str]: 

370 """ 

371 Property containing the plugin load status. 

372 

373 Returns 

374 ------- 

375 Dict[str, str], keyed by plugin name. 

376 

377 """ 

378 

379 return self._plugin_status