Source code for dffml.util.cli.cmd

# SPDX-License-Identifier: MIT
# Copyright (c) 2019 Intel Corporation
import sys
import json
import uuid
import enum
import logging
import inspect
import asyncio
import argparse
from typing import Dict, Any
import dataclasses

from ...record import Record
from ...feature import Feature

from ..data import export_dict
from .arg import Arg, parse_unknown
from ...base import config, mkarg, field
from ...configloader.configloader import ConfigLoaders

# Sentinel returned by CMD.cli() after it has printed a help message.
# CMD.main() compares results against it by identity so that the sentinel is
# never dumped to stdout as a JSON result.
DisplayHelp = "Display help message"


class CMDOutputOverride:
    """
    Override dumping of results

    Marker class: ``CMD.main`` skips its JSON dump of the command's result
    when that result is this class itself or the list
    ``[CMDOutputOverride]``.
    """
# On Windows, install a ProactorEventLoop as the current event loop as soon as
# this module is imported.
# NOTE(review): presumably required for asyncio subprocess support on
# Windows -- confirm before changing or removing.
if sys.platform == "win32":  # pragma: no cov
    asyncio.set_event_loop(asyncio.ProactorEventLoop())
class ParseLoggingAction(argparse.Action):
    """
    argparse action which turns a logging level name (for example ``debug``)
    into the numeric level from the ``logging`` module, stores it on the
    namespace and passes it to ``logging.basicConfig``.
    """

    def __call__(self, parser, namespace, value, option_string=None):
        """
        Store the numeric logging level named by ``value`` and configure
        logging with it.

        Parameters
        ----------
        parser
            The parser invoking this action (unused).
        namespace
            Namespace on which the level is stored under ``self.dest``.
        value : str
            Case insensitive name of a logging level.
        option_string
            Option string used on the command line (unused).
        """
        level = getattr(logging, value.upper(), None)
        # The logging module has upper case attributes which are not levels
        # (BASIC_FORMAT is a format string). Only integers are valid levels,
        # anything else falls back to INFO just like an unknown name does.
        if not isinstance(level, int):
            level = logging.INFO
        setattr(namespace, self.dest, level)
        logging.basicConfig(level=level)
class JSONEncoder(json.JSONEncoder):
    """
    Encodes dffml types to JSON representation.
    """

    def default(self, obj):
        """
        Return a JSON serializable stand-in for ``obj``.

        Called by the ``json`` machinery for objects it cannot serialize
        itself. Handles Record (via ``obj.dict()``), UUID, Feature (its
        name), Enum (its value as a string), classes (their qualified name),
        numpy integers, floats and arrays, and objects whose string form
        starts with ``typing.``. Anything else is passed to the base class,
        which raises ``TypeError``.
        """
        # Looks like "<class 'numpy.float32'>". numpy is detected by name so
        # that it does not have to be imported here.
        typename_lower = str(type(obj)).lower()
        if isinstance(obj, Record):
            return obj.dict()
        elif isinstance(obj, uuid.UUID):
            return str(obj)
        elif isinstance(obj, Feature):
            return obj.name
        elif isinstance(obj, enum.Enum):
            return str(obj.value)
        elif isinstance(obj, type):
            return str(obj.__qualname__)
        elif "numpy." in typename_lower:
            if ".int" in typename_lower or ".uint" in typename_lower:
                return int(obj)
            elif ".float" in typename_lower:
                # The type name never *starts* with "float" (it starts with
                # "<class"), so match the ".float" part of "numpy.float32".
                return float(obj)
            elif "ndarray" in typename_lower:
                return obj.tolist()
        elif str(obj).startswith("typing."):
            return str(obj).split(".")[-1]
        return json.JSONEncoder.default(self, obj)
# Definition of the `-log` command line argument. Parser.add_subs() tries to
# add it to each parser it populates and ignores the argparse error raised
# when that parser already has a `-log` argument.
log_cmd = Arg(
    "-log",
    help="Logging level",
    action=ParseLoggingAction,
    required=False,
    default=logging.INFO,
)
class Parser(argparse.ArgumentParser):
    """
    ArgumentParser which populates itself from a CMD class: nested CMD
    classes become sub commands, fields of its CONFIG become arguments.
    """

    def add_subs(self, add_from: "CMD"):
        """
        Add sub commands and arguments recursively
        """
        # Only one subparser should be created even if multiple sub commands
        subparsers = None
        for member_name, member in inspect.getmembers(add_from):
            if not (inspect.isclass(member) and issubclass(member, CMD)):
                continue
            if subparsers is None:  # pragma: no cover
                subparsers = self.add_subparsers()  # pragma: no cover
            formatter_class = getattr(
                member,
                "CLI_FORMATTER_CLASS",
                argparse.ArgumentDefaultsHelpFormatter,
            )
            # Sub command names are the lower cased attribute name without
            # underscores
            sub_parser = subparsers.add_parser(
                member_name.lower().replace("_", ""),
                description=member.__doc__,
                formatter_class=formatter_class,
            )
            sub_parser.set_defaults(cmd=member, parser=sub_parser)
            sub_parser.add_subs(member)  # type: ignore

        # Add arguments to the Parser. Fields with neither a default nor
        # required=True become positional arguments, which are added after
        # the optional ones, in field order.
        positionals = []
        for config_field in dataclasses.fields(add_from.CONFIG):
            arg = mkarg(config_field)
            if not isinstance(arg, Arg):
                continue
            if "default" not in arg and not arg.get("required", False):
                positionals.append((config_field.name, arg))
                continue
            flag = "-" + config_field.name.replace("_", "-")
            try:
                self.add_argument(flag, **arg)
            except argparse.ArgumentError as error:
                raise Exception(repr(add_from)) from error

        for positional_name, positional_arg in positionals:
            self.add_argument(
                positional_name.replace("_", "-"), **positional_arg
            )

        # Add `-log` argument if it's not already added
        try:
            self.add_argument(log_cmd.name, **log_cmd)
        except argparse.ArgumentError:
            pass
# Default configuration for CMD (see CMD.CONFIG): it declares only the
# logging level.
@config
class CMDConfig:
    # Logging level. Defaults to logging.INFO; ParseLoggingAction is the
    # argparse action attached to this field.
    log: str = field(
        "Logging Level",
        default=logging.INFO,
        required=False,
        action=ParseLoggingAction,
    )
# Base class for command line commands. Nested classes which subclass CMD
# become sub commands (see Parser.add_subs) and the fields of CONFIG become
# command line arguments. A command implements ``run`` as a coroutine or an
# async generator.
#
# NOTE: documented with comments rather than a class docstring on purpose,
# parse_args() passes ``cls.__doc__`` to the parser as its description.
class CMD(object):

    JSONEncoder = JSONEncoder
    EXTRA_CONFIG_ARGS = {}
    CONFIG = CMDConfig
    ENTRY_POINT_NAME = ["service"]

    def __init__(self, extra_config=None, **kwargs) -> None:
        """
        Set the fields of ``CONFIG`` as attributes on the instance.

        For each field, the value comes from ``kwargs`` or else from the
        default of the field's argument. Attributes which already exist on
        the instance (or class) are not overwritten.
        """
        if not hasattr(self, "logger"):
            self.logger = logging.getLogger(
                "%s.%s"
                % (self.__class__.__module__, self.__class__.__qualname__)
            )
        if extra_config is None:
            extra_config = {}
        self.extra_config = extra_config

        for config_field in dataclasses.fields(self.CONFIG):
            arg = mkarg(config_field)
            if not isinstance(arg, Arg):
                continue
            name = config_field.name
            if name not in kwargs and "default" in arg:
                kwargs[name] = arg["default"]
            if name in kwargs and not hasattr(self, name):
                self.logger.debug("Setting %s = %r", name, kwargs[name])
                setattr(self, name, kwargs[name])
            else:
                self.logger.debug("Ignored %s", name)

    async def __aenter__(self):
        pass

    async def __aexit__(self, exc_type, exc_value, traceback):
        pass

    @classmethod
    async def parse_args(cls, *args):
        """
        Build a Parser for this command and parse ``args`` with it.

        Returns the parser and the ``(namespace, unknown)`` tuple produced
        by ``parse_known_args``.
        """
        parser = Parser(
            description=cls.__doc__,
            formatter_class=getattr(
                cls,
                "CLI_FORMATTER_CLASS",
                argparse.ArgumentDefaultsHelpFormatter,
            ),
        )
        parser.add_subs(cls)
        return parser, parser.parse_known_args(args)

    async def do_run(self):
        """
        Enter this command as an async context and run it.

        Results of an async generator ``run`` are collected into a list,
        otherwise the awaited return value of ``run`` is returned.
        """
        async with self:
            if inspect.isasyncgenfunction(self.run):
                return [res async for res in self.run()]
            return await self.run()

    @classmethod
    async def cli(cls, *args):
        """
        Parse ``args``, instantiate the selected command and run it.

        Returns ``DisplayHelp`` when no runnable command was selected and
        help was printed instead, otherwise the result of ``do_run``.
        """
        parser, (namespace, unknown) = await cls.parse_args(*args)
        # Arguments the parser did not recognize become the extra config
        async with ConfigLoaders() as configloaders:
            namespace.extra_config = await parse_unknown(
                *unknown, configloaders=configloaders
            )
        # No sub command selected, but this class itself can be run
        if (
            getattr(cls, "run", None) is not None
            and getattr(namespace, "cmd", None) is None
        ):
            namespace.cmd = cls
        if getattr(namespace, "cmd", None) is None:
            parser.print_help()
            return DisplayHelp
        if not inspect.isfunction(getattr(namespace.cmd, "run", None)):
            namespace.parser.print_help()
            return DisplayHelp
        cmd = namespace.cmd(**cls.sanitize_args(vars(namespace)))
        return await cmd.do_run()

    @classmethod
    def sanitize_args(cls, args):
        """
        Remove CMD internals from arguments passed to subclasses of CMD.

        The given dictionary is modified in place and returned.
        """
        for rm in ["cmd", "parser", "log"]:
            args.pop(rm, None)
        return args

    @classmethod
    async def _main(cls, *args):
        return await cls.cli(*args)

    @classmethod
    def main(cls, loop=None, argv=None):
        """
        Runs cli commands in asyncio loop and outputs in appropriate format

        ``argv`` defaults to the value of ``sys.argv`` at call time, its
        first element is skipped. The event loop is shut down and closed
        before returning, also when the command raises.
        """
        if argv is None:
            # Looked up at call time rather than bound once at definition
            # time, so a later replacement of sys.argv is honoured.
            argv = sys.argv
        if loop is None:
            py37_not_windows = (
                sys.version_info.major == 3
                and sys.version_info.minor == 7
                and sys.platform != "win32"
            )
            # In order to use asyncio.subprocess_create_exec from event loops
            # in non-main threads we have to call asyncio.get_child_watcher().
            # This is only for Python 3.7
            if py37_not_windows:
                asyncio.get_child_watcher()
            # Use the current event loop
            loop = asyncio.get_event_loop()
            # In Python 3.8 ThreadedChildWatcher becomes the default which
            # should work fine for us. However, in Python 3.7 SafeChildWatcher
            # is the default and may cause BlockingIOErrors when many
            # subprocesses are created
            # https://docs.python.org/3/library/asyncio-policy.html#asyncio.FastChildWatcher
            if py37_not_windows:
                watcher = asyncio.FastChildWatcher()
                asyncio.set_child_watcher(watcher)
                watcher.attach_loop(loop)
        try:
            result = loop.run_until_complete(cls._main(*argv[1:]))
            if (
                result is not None
                and result is not DisplayHelp
                and result is not CMDOutputOverride
                and result != [CMDOutputOverride]
            ):
                json.dump(
                    export_dict(result=result)["result"],
                    sys.stdout,
                    sort_keys=True,
                    indent=4,
                    separators=(",", ": "),
                    cls=cls.JSONEncoder,
                )
                print()
        except KeyboardInterrupt:  # pragma: no cover
            pass  # pragma: no cover
        finally:
            # Always release the loop, even if the command raised
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    def __call__(self):
        """
        Run this command to completion with ``asyncio.run``.
        """
        return asyncio.run(self.do_run())

    @classmethod
    def args(cls, args, *above) -> Dict[str, Any]:
        """
        For compatibility with scripts/docs.py. Nothing else at the moment so
        if it doesn't work with other things that's why.
        """
        return args