"""
Running of shell commands
"""
import os
import abc
import sys
import json
import time
import copy
import fcntl
import shlex
import signal
import atexit
import asyncio
import pathlib
import inspect
import tempfile
import platform
import functools
import contextlib
import subprocess
import http.server
from typing import IO, Any, Dict, List, Union, Optional
import httptest
from .... import plugins
DFFML_ROOT = pathlib.Path(__file__).parents[4]
[docs]class ConsoletestCommand(abc.ABC):
def __init__(self):
self.poll_until = False
self.compare_output = None
self.compare_output_imports = None
self.ignore_errors = False
self.daemon = None
def __repr__(self):
return (
self.__class__.__qualname__
+ "("
+ str(
{
k: v
for k, v in self.__dict__.items()
if not k.startswith("_")
}
)
+ ")"
)
def str(self):
return repr(self)
async def __aenter__(self):
return self
async def __aexit__(self, _exc_type, _exc_value, _traceback):
pass
[docs]class CDCommand(ConsoletestCommand):
def __init__(self, directory: str):
super().__init__()
self.directory = directory
def __eq__(self, other: "CDCommand"):
return bool(
hasattr(other, "directory") and self.directory == other.directory
)
async def run(self, ctx):
ctx["cwd"] = os.path.abspath(os.path.join(ctx["cwd"], self.directory))
[docs]class ActivateVirtualEnvCommand(ConsoletestCommand):
def __init__(self, directory: str):
super().__init__()
self.directory = directory
self.old_virtual_env = None
self.old_virtual_env_dir = None
self.old_path = None
self.old_pythonpath = None
self.old_sys_path = []
def __eq__(self, other: "ActivateVirtualEnvCommand"):
return bool(
hasattr(other, "directory") and self.directory == other.directory
)
async def run(self, ctx):
tempdir = ctx["stack"].enter_context(tempfile.TemporaryDirectory())
self.old_virtual_env = os.environ.get("VIRTUAL_ENV", None)
self.old_virtual_env_dir = os.environ.get("VIRTUAL_ENV_DIR", None)
self.old_path = os.environ.get("PATH", None)
self.old_pythonpath = os.environ.get("PYTHONPATH", None)
env_path = os.path.abspath(os.path.join(ctx["cwd"], self.directory))
os.environ["PATH"] = ":".join(
[os.path.abspath(tempdir), os.path.join(env_path, "bin")]
+ os.environ.get("PATH", "").split(":")
)
os.environ["PYTHONPATH"] = ":".join(
os.environ.get("PYTHONPATH", "").split(":")
+ [
os.path.join(
env_path,
"lib",
f"python{sys.version_info.major}.{sys.version_info.minor}",
"site-packages",
)
],
)
# conda
if "CONDA_PREFIX" in os.environ:
print("CONDA", env_path)
# Bump all prefixes up
for key, value in filter(
lambda i: i[0].startswith("CONDA_PREFIX_"),
list(os.environ.items()),
):
prefix = int(key[len("CONDA_PREFIX_") :])
os.environ[f"CONDA_PREFIX_{prefix + 1}"] = value
# Add new prefix
old_shlvl = int(os.environ["CONDA_SHLVL"])
os.environ["CONDA_SHLVL"] = str(old_shlvl + 1)
os.environ["CONDA_PREFIX_1"] = os.environ["CONDA_PREFIX"]
os.environ["CONDA_PREFIX"] = env_path
os.environ["CONDA_DEFAULT_ENV"] = env_path
else:
print("VIRTUAL_ENV", env_path)
os.environ["VIRTUAL_ENV"] = env_path
os.environ["VIRTUAL_ENV_DIR"] = env_path
for env_var in ["VIRTUAL_ENV", "CONDA_PREFIX"]:
if env_var in os.environ:
python_path = os.path.abspath(
os.path.join(os.environ[env_var], "bin", "python")
)
# Prepend a dffml command to the path to ensure the correct
# version of dffml always runs
# Write out the file
dffml_path = pathlib.Path(os.path.abspath(tempdir), "dffml")
dffml_path.write_text(
inspect.cleandoc(
f"""
#!{python_path}
import os
import sys
os.execv("{python_path}", ["{python_path}", "-m", "dffml", *sys.argv[1:]])
"""
)
)
dffml_path.chmod(0o755)
async def __aexit__(self, _exc_type, _exc_value, _traceback):
if self.old_virtual_env is not None:
os.environ["VIRTUAL_ENV"] = self.old_virtual_env
if self.old_virtual_env_dir is not None:
os.environ["VIRTUAL_ENV_DIR"] = self.old_virtual_env_dir
if self.old_path is not None:
os.environ["PATH"] = self.old_path
if self.old_pythonpath is not None:
os.environ["PYTHONPATH"] = self.old_pythonpath
# conda
if "CONDA_PREFIX" in os.environ:
# Decrement shell level
os.environ["CONDA_SHLVL"] = str(int(os.environ["CONDA_SHLVL"]) - 1)
if int(os.environ["CONDA_SHLVL"]) == 0:
del os.environ["CONDA_SHLVL"]
# Bump all prefixes down
for key, value in filter(
lambda i: i[0].startswith("CONDA_PREFIX_"),
list(os.environ.items()),
):
del os.environ[key]
prefix = int(key[len("CONDA_PREFIX_") :])
if prefix == 1:
lower_key = "CONDA_PREFIX"
os.environ["CONDA_PREFIX"] = value
os.environ["CONDA_DEFAULT_ENV"] = value
else:
os.environ[f"CONDA_PREFIX_{prefix - 1}"] = value
[docs]class HTTPServerCMDDoesNotHavePortFlag(Exception):
pass
async def run_dffml_command(cmd, ctx, kwargs):
# Run the DFFML command if its not the http server
if cmd[:4] != ["dffml", "service", "http", "server"]:
# Run the command
print()
print("Running", cmd)
print()
proc = subprocess.Popen(
cmd, start_new_session=True, cwd=ctx["cwd"], **kwargs
)
proc.cmd = cmd
else:
# Windows won't let two processes open a file at the same time
with tempfile.TemporaryDirectory() as tempdir:
# Ensure that the HTTP server is being started with an explicit port
if "-port" not in cmd:
raise HTTPServerCMDDoesNotHavePortFlag(cmd)
# Add logging
cmd.insert(cmd.index("server") + 1, "debug")
cmd.insert(cmd.index("server") + 1, "-log")
# Add the -portfile flag to make the server write out the bound port
# number
portfile_path = pathlib.Path(tempdir, "portfile.int").resolve()
cmd.insert(cmd.index("server") + 1, str(portfile_path))
cmd.insert(cmd.index("server") + 1, "-portfile")
# Save the port the command gave
ctx.setdefault("HTTP_SERVER", {})
given_port = cmd[cmd.index("-port") + 1]
ctx["HTTP_SERVER"][given_port] = 0
# Replace the port that was given with port 0 to bind on any free
# port
cmd[cmd.index("-port") + 1] = "0"
# Run the command
print()
print("Running", cmd)
print()
proc = subprocess.Popen(
cmd, start_new_session=True, cwd=ctx["cwd"], **kwargs
)
proc.cmd = cmd
# Read the file containing the port number
while proc.returncode is None:
if portfile_path.is_file():
port = int(portfile_path.read_text())
break
await asyncio.sleep(0.01)
# Map the port that was given to the port that was used
ctx["HTTP_SERVER"][given_port] = port
# Return the newly created process
return proc
[docs]@contextlib.contextmanager
def tmpenv(cmd: List[str]) -> List[str]:
"""
Handle temporary environment variables prepended to command
"""
oldvars = {}
tmpvars = {}
for var in cmd:
if "=" not in var:
break
cmd.pop(0)
key, value = var.split("=", maxsplit=1)
tmpvars[key] = value
if key in os.environ:
oldvars[key] = os.environ[key]
os.environ[key] = value
try:
yield cmd
finally:
for key in tmpvars.keys():
del os.environ[key]
for key, value in oldvars.items():
os.environ[key] = value
async def run_commands(
cmds,
ctx,
*,
stdin: Union[IO] = None,
stdout: Union[IO] = None,
ignore_errors: bool = False,
daemon: bool = False,
):
proc = None
procs = []
cmds = list(map(sub_env_vars, cmds))
for i, cmd in enumerate(cmds):
# Keyword arguments for Popen
kwargs = {}
# Set stdout to system stdout so it doesn't go to the pty
kwargs["stdout"] = stdout if stdout is not None else sys.stdout
# Check if there is a previous command
kwargs["stdin"] = stdin if stdin is not None else subprocess.DEVNULL
if i != 0:
# NOTE asyncio.create_subprocess_exec doesn't work for piping output
# from one process to the next. It will complain about stdin not
# having a fileno()
kwargs["stdin"] = proc.stdout
# Check if there is a next command
if i + 1 < len(cmds):
kwargs["stdout"] = subprocess.PIPE
# Check if we redirect stderr to stdout
if "2>&1" in cmd:
kwargs["stderr"] = subprocess.STDOUT
cmd.remove("2>&1")
# If not in venv ensure correct Python
if (
"VIRTUAL_ENV" not in os.environ
and "CONDA_PREFIX" not in os.environ
and cmd[0].startswith("python")
):
cmd[0] = sys.executable
# Handle temporary environment variables prepended to command
with tmpenv(cmd) as cmd:
# Run the command
if cmd[0] == "dffml":
# Run dffml command through Python so that we capture coverage info
proc = await run_dffml_command(cmd, ctx, kwargs)
else:
# Run the command
print()
print("Running", cmd)
print()
proc = subprocess.Popen(
cmd, start_new_session=True, cwd=ctx["cwd"], **kwargs
)
proc.cmd = cmd
procs.append(proc)
# Parent (this Python process) close stdout of previous command so that
# the command we just created has exclusive access to the output.
if i != 0:
kwargs["stdin"].close()
# Wait for all processes to complete
errors = []
for i, proc in enumerate(procs):
# Do not wait for last process to complete if running in daemon mode
if daemon and (i + 1) == len(procs):
break
proc.wait()
if proc.returncode != 0:
errors.append(f"Failed to run: {proc.cmd!r}: {proc.returncode}")
if errors and not ignore_errors:
raise RuntimeError("\n".join(errors))
if daemon:
return procs[-1]
def sub_env_vars(cmd):
for env_var_name, env_var_value in os.environ.items():
for i, arg in enumerate(cmd):
for check in ["$" + env_var_name, "${" + env_var_name + "}"]:
if check in arg:
cmd[i] = arg.replace(check, env_var_value)
return cmd
def pipes(cmd):
if not "|" in cmd:
return [cmd]
cmds = []
j = 0
for i, arg in enumerate(cmd):
if arg == "|":
cmds.append(cmd[j:i])
j = i + 1
cmds.append(cmd[j:])
return cmds
async def stop_daemon(proc):
if platform.system() != "Windows":
# Kill the whole process group (for problematic processes)
os.killpg(proc.pid, signal.SIGINT)
proc.send_signal(signal.SIGINT)
proc.wait()
[docs]class OutputComparisionError(Exception):
"""
Raised when the output of a command was incorrect
"""
[docs]@contextlib.contextmanager
def buf_to_fileobj(buf: Union[str, bytes]):
"""
Given a buffer, create a temporary file and write the contents of the string
of bytes buffer to the file. Seek to the beginning of the file. Yield the
file object.
"""
if isinstance(buf, str):
buf = buf.encode()
with tempfile.TemporaryFile() as fileobj:
fileobj.write(buf)
fileobj.seek(0)
yield fileobj
[docs]class ConsoleCommand(ConsoletestCommand):
def __init__(self, cmd: List[str]):
super().__init__()
self.cmd = cmd
self.daemon_proc = None
self.replace = None
self.stdin = None
self.stdin_fileobj = None
self.stack = contextlib.ExitStack()
async def run(self, ctx):
if self.daemon is not None and self.daemon in ctx["daemons"]:
await stop_daemon(ctx["daemons"][self.daemon].daemon_proc)
if self.compare_output is None:
with contextlib.ExitStack() as stack:
self.daemon_proc = await run_commands(
pipes(self.cmd),
ctx,
stdin=None
if self.stdin is None
else stack.enter_context(buf_to_fileobj(self.stdin)),
ignore_errors=self.ignore_errors,
daemon=bool(self.daemon),
)
if self.daemon is not None:
ctx["daemons"][self.daemon] = self
else:
while True:
with contextlib.ExitStack() as stack:
stdout = stack.enter_context(tempfile.TemporaryFile())
await run_commands(
pipes(self.cmd),
ctx,
stdin=None
if self.stdin is None
else stack.enter_context(buf_to_fileobj(self.stdin)),
stdout=stdout,
ignore_errors=self.ignore_errors,
)
stdout.seek(0)
stdout = stdout.read()
if call_compare_output(
self.compare_output,
stdout,
imports=self.compare_output_imports,
):
return
if not self.poll_until:
raise OutputComparisionError(
f"{self.cmd}: {self.compare_output}: {stdout.decode()}"
)
time.sleep(0.1)
async def __aenter__(self):
self.stack.__enter__()
return self
async def __aexit__(self, _exc_type, _exc_value, _traceback):
if self.daemon_proc is not None:
await stop_daemon(self.daemon_proc)
self.stack.__exit__(None, None, None)
[docs]class CreateVirtualEnvCommand(ConsoleCommand):
def __init__(self, directory: str):
super().__init__([])
self.directory = directory
def __eq__(self, other: "CreateVirtualEnvCommand"):
return bool(
hasattr(other, "directory") and self.directory == other.directory
)
async def run(self, ctx):
if "CONDA_PREFIX" in os.environ:
self.cmd = [
"conda",
"create",
f"python={sys.version_info.major}.{sys.version_info.minor}",
"-y",
"-p",
self.directory,
]
else:
self.cmd = ["python", "-m", "venv", self.directory]
await super().run(ctx)
[docs]class PipNotRunAsModule(Exception):
"""
Raised when a pip install command was not prefixed with python -m to run pip
as a module. Pip sometimes complains when this is not done.
"""
[docs]class PipInstallCommand(ConsoleCommand):
def __init__(self, cmd: List[str]):
super().__init__(cmd)
self.directories: List[str] = []
# Ensure that we are running pip using it's module invocation
if tuple(self.cmd[:2]) not in (("python", "-m"), ("python3", "-m")):
raise PipNotRunAsModule(cmd)
[docs] def fix_dffml_packages(self, ctx):
"""
If a piece of the documentation says to install dffml or one of the
packages, we need to make sure that the version from the current branch
gets installed instead, since we don't want to test the released
version, we want to test the version of the codebase as it is.
"""
package_names_to_directory = copy.copy(
plugins.PACKAGE_NAMES_TO_DIRECTORY
)
package_names_to_directory["dffml"] = "."
for i, pkg in enumerate(self.cmd):
if "[" in pkg and "]" in pkg:
for package_name in package_names_to_directory.keys():
if pkg.startswith(package_name + "["):
pkg, extras = pkg.split("[", maxsplit=1)
directory = package_names_to_directory[pkg]
directory = os.path.join(DFFML_ROOT, *directory)
directory = os.path.abspath(directory)
self.cmd[i] = directory + "[" + extras
if self.cmd[i - 1] != "-e":
self.cmd.insert(i, "-e")
self.directories.append(directory)
elif pkg in package_names_to_directory:
directory = package_names_to_directory[pkg]
directory = os.path.join(DFFML_ROOT, *directory)
directory = os.path.abspath(directory)
self.cmd[i] = directory
if self.cmd[i - 1] != "-e":
self.cmd.insert(i, "-e")
self.directories.append(directory)
async def run(self, ctx):
# In case a replace command changed something
self.fix_dffml_packages(ctx)
await super().run(ctx)
async def __aexit__(self, _exc_type, _exc_value, _traceback):
return
[docs]class NPMInstallCommand(ConsoleCommand):
async def run(self, ctx):
await super().run(ctx)
if platform.system() != "Windows":
flags = fcntl.fcntl(sys.stdout, fcntl.F_GETFL)
fcntl.fcntl(sys.stdout, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
[docs]class YarnInstallCommand(NPMInstallCommand):
pass
[docs]class DockerRunCommand(ConsoleCommand):
def __init__(self, cmd: List[str]):
name, needs_removal, cmd = self.find_name(cmd)
super().__init__(cmd)
self.name = name
self.needs_removal = needs_removal
self.stopped = False
[docs] @staticmethod
def find_name(cmd):
"""
Find the name of the container we are starting (if starting as daemon)
"""
name = None
needs_removal = bool("--rm" not in cmd)
for i, arg in enumerate(cmd):
if arg.startswith("--name="):
name = arg[len("--name=") :]
elif arg == "--name" and (i + 1) < len(cmd):
name = cmd[i + 1]
return name, needs_removal, cmd
def cleanup(self):
if self.name and not self.stopped:
subprocess.check_call(["docker", "stop", self.name])
if self.needs_removal:
subprocess.check_call(["docker", "rm", self.name])
self.stopped = True
async def __aenter__(self):
atexit.register(self.cleanup)
return self
async def __aexit__(self, _exc_type, _exc_value, _traceback):
self.cleanup()
[docs]class SimpleHTTPServerCommand(ConsoleCommand):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ts = None
async def __aenter__(self) -> "OperationImplementationContext":
self.ts = None
async def run(self, ctx):
# Default the port to 8000
given_port = "8000"
# Grab port number if given
if self.cmd[-1].isdigit():
given_port = self.cmd[-1]
if "--cgi" in self.cmd:
# Start CGI server if requseted
handler_class = http.server.CGIHTTPRequestHandler
else:
# Default to simple http server
handler_class = http.server.SimpleHTTPRequestHandler
# Specify directory if given
directory = ctx["cwd"]
if "--directory" in self.cmd:
directory = self.cmd[self.cmd.index("--directory") + 1]
# Ensure handler is relative to directory
handler_class = functools.partial(handler_class, directory=directory)
# Start a server with a random port
self.ts = httptest.Server(handler_class).__enter__()
# Map the port that was given to the port that was used
ctx.setdefault("HTTP_SERVER", {})
ctx["HTTP_SERVER"][given_port] = self.ts.server_port
return self
async def __aexit__(self, exc_type, exc_value, traceback):
self.ts.__exit__(None, None, None)
self.ts = None
def within_qoute(current, qoute=('"', "'")):
within = False
for i, char in enumerate(current):
context = current[i - 1 : i]
if char in qoute and not context.startswith("\\"):
within = not within
return within
def parse_commands(content):
commands = []
current = ""
for line in content:
line = line.rstrip()
if line.startswith("$ "):
if line.endswith("\\"):
current = line[2:-1]
else:
current = line[2:]
if within_qoute(current):
continue
commands.append(current)
current = ""
elif current and line.endswith("\\"):
current += line[:-1]
elif current and not line.endswith("\\"):
current += line
if within_qoute(current):
continue
commands.append(current)
current = ""
# Raise NotImplementedError if command substitution is attempted
for command in commands:
for check in ("`", "$("):
index = 0
while index != -1:
index = command.find(check, index + 1)
if index == -1:
continue
if not within_qoute(command[:index], qoute=("'")):
raise NotImplementedError(
f"Command substitution was attempted: {command}"
)
try:
commands = list(map(shlex.split, commands))
except ValueError as error:
print(commands)
raise
return commands
def build_command(cmd):
if not cmd:
raise ValueError("Empty command")
# Handle virtualenv creation
if (
"-m" in cmd and "venv" in cmd and cmd[cmd.index("-m") + 1] == "venv"
) or (cmd[:2] == ["conda", "create"]):
return CreateVirtualEnvCommand(cmd[-1])
# Handle virtualenv activation
if ".\\.venv\\Scripts\\activate" in cmd or (
len(cmd) == 2
and cmd[0] in ("source", ".")
and ".venv/bin/activate" == cmd[1]
):
return ActivateVirtualEnvCommand(".venv")
# Handle cd
if "cd" == cmd[0]:
return CDCommand(cmd[1])
# Handle pip installs
if (
"pip" in cmd
and "install" in cmd
and cmd[cmd.index("pip") + 1] == "install"
):
return PipInstallCommand(cmd)
# Handle yarn and npm install command
if cmd[:2] == ["npm", "install"]:
return NPMInstallCommand(cmd)
if cmd[:2] == ["yarn", "install"]:
return YarnInstallCommand(cmd)
# Handle simple http server
if cmd[1:3] == ["-m", "http.server"]:
return SimpleHTTPServerCommand(cmd)
# Regular console command
return ConsoleCommand(cmd)
MAKE_POLL_UNTIL_TEMPLATE = """
import sys
{imports}
func = lambda stdout: {func}
sys.exit(int(not func(sys.stdin.buffer.read())))
"""
def call_compare_output(func, stdout, *, imports: Optional[str] = None):
with tempfile.NamedTemporaryFile() as fileobj, tempfile.NamedTemporaryFile() as stdin:
fileobj.write(
MAKE_POLL_UNTIL_TEMPLATE.format(
func=func,
imports="" if imports is None else "import " + imports,
).encode()
)
fileobj.seek(0)
stdin.write(stdout.encode() if isinstance(stdout, str) else stdout)
stdin.seek(0)
return_code = subprocess.call(["python", fileobj.name], stdin=stdin)
return bool(return_code == 0)
MAKE_REPLACE_UNTIL_TEMPLATE = """
import sys
import json
import pathlib
cmds = json.loads(pathlib.Path(sys.argv[1]).read_text())
ctx = json.loads(pathlib.Path(sys.argv[2]).read_text())
{func}
print(json.dumps(cmds))
"""
def call_replace(
func: str, cmds: List[List[str]], ctx: Dict[str, Any]
) -> List[List[str]]:
with contextlib.ExitStack() as stack:
# Write out Python script
python_fileobj = stack.enter_context(tempfile.NamedTemporaryFile())
python_fileobj.write(
MAKE_REPLACE_UNTIL_TEMPLATE.format(func=func).encode()
)
python_fileobj.seek(0)
# Write out command
cmd_fileobj = stack.enter_context(tempfile.NamedTemporaryFile())
cmd_fileobj.write(json.dumps(cmds).encode())
cmd_fileobj.seek(0)
# Write out context
ctx_fileobj = stack.enter_context(tempfile.NamedTemporaryFile())
ctx_serializable = ctx.copy()
for remove in list(ctx["no_serialize"]) + ["no_serialize"]:
if remove in ctx_serializable:
del ctx_serializable[remove]
ctx_fileobj.write(json.dumps(ctx_serializable).encode())
ctx_fileobj.seek(0)
# Python file modifies command and json.dumps result to stdout
return json.loads(
subprocess.check_output(
[
"python",
python_fileobj.name,
cmd_fileobj.name,
ctx_fileobj.name,
],
)
)