#!/usr/bin/python3
#
# Copyright (c) 2023-2024, Oracle and/or its affiliates.
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
#
# This code is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 2 only, as
# published by the Free Software Foundation.
#
# This code is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
# version 2 for more details (a copy is included in the LICENSE file that
# accompanied this code).
#
# You should have received a copy of the GNU General Public License version
# 2 along with this work; if not, see <https://www.gnu.org/licenses/>.
#
# Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
# or visit www.oracle.com if you need additional information or have any
# questions.
#
# Author: Jose Lombera <jose.lombera@oracle.com>

"""
Tool to watch CPU utilization and execute user provided commands when
configured thresholds are reached.
"""
import argparse
import datetime
import logging
import logging.handlers
import os
import shlex
import signal
import subprocess  # nosec
import sys
import time
from collections import deque
from enum import Enum
from typing import Dict, List, Mapping, Optional, Sequence

# A CPU stats snapshot, as returned by take_snapshot().
#
# Keys are "all" for the aggregate "cpu" line of /proc/stat and the bare CPU
# number (e.g. "0") for each per-CPU "cpuN" line.  Values are the integer
# counters of that line, in the order they appear in the file.
#
# {
#   "all": [s1, s2, ...],
#   "0": [...],
#    ...
#   "N": [...]
# }
Snapshot = Dict[str, List[int]]


def take_snapshot(stat_path: str = "/proc/stat") -> Snapshot:
    """Take a CPU stats snapshot from a /proc/stat formatted file.

    Only lines starting with "cpu" are considered; see procfs(5) for a
    description of the format of the file.

    Args:
        stat_path: File to read.  Defaults to the live kernel statistics in
            "/proc/stat"; a different path can be given to parse a saved copy.

    Returns:
        Mapping { "all": [s1, s2, ...], "0": [s1, s2, ...], ... } where the
        key "all" holds the counters of the aggregate "cpu" line, every other
        key is the CPU number taken from a "cpuN" line, and s1, s2, ... are
        integers.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a counter on a "cpu" line is not an integer.
    """
    snapshot = {}

    with open(stat_path, "r") as desc:
        # Stream the file line by line; there is no need to hold all of it
        # in memory just to keep the handful of "cpu" lines.
        for line in desc:
            if not line.startswith("cpu"):
                continue

            # First token is the CPU label, the rest are the counters.
            label, *counters = line.split()
            cpu_id = "all" if label == "cpu" else label[3:]
            snapshot[cpu_id] = [int(counter) for counter in counters]

    return snapshot


def compute_percent_util(
        ticks1: List[int],
        ticks2: List[int],
        idle_index: int = 3) -> List[float]:
    """Compute CPU utilization (in %) from two CPU stats snapshots.

    Args:
        ticks1: Counters of one CPU from the older snapshot.
        ticks2: Counters of the same CPU from the newer snapshot.
        idle_index: Position of the idle counter in the lists.  The default
            matches the column order of the "cpu" lines in /proc/stat (see
            procfs(5)) and is only used when no ticks elapsed.

    Returns:
        One percentage per counter: the share of the ticks elapsed between
        both snapshots that was accounted to that counter.  If no ticks
        elapsed (or the counters went backwards) the CPU is reported as fully
        idle, i.e. 100.0 at @idle_index (when within range) and 0.0 elsewhere,
        so that neither "busy" nor "idle" thresholds are met spuriously.
    """
    deltas = [new - old for old, new in zip(ticks1, ticks2)]
    elapsed_ticks = sum(ticks2) - sum(ticks1)

    if elapsed_ticks <= 0:
        # Nothing to divide by: avoid ZeroDivisionError (and meaningless
        # negative ratios) by reporting an idle CPU.
        percents = [0.0] * len(deltas)
        if 0 <= idle_index < len(percents):
            percents[idle_index] = 100.0
        return percents

    return [(delta * 100) / elapsed_ticks for delta in deltas]


# CPU utilization computed from two snapshots: maps each CPU key (same keys as
# Snapshot) to its list of utilization percentages, indexed by CpuStat value.
CpuUtil = Dict[str, List[float]]


def snap_percent_util(snap1: Snapshot, snap2: Snapshot) -> CpuUtil:
    """Compute CPU utilization (in %) between two snapshots.

    CPUs that are missing from either snapshot are left out of the result;
    the remaining ones keep the order they have in @snap1.
    """
    utilization = {}

    for cpu, old_ticks in snap1.items():
        if cpu not in snap2:
            # CPU is not present in the newer snapshot, nothing to compare
            continue

        utilization[cpu] = compute_percent_util(old_ticks, snap2[cpu])

    return utilization


class CpuStat(Enum):
    """CPU stat IDs.

    The value of each member is the index of that stat in the per-CPU lists
    of a Snapshot/CpuUtil, i.e. the position of the column that follows the
    CPU label in a "cpu" line of /proc/stat.  The member names are the STAT
    names accepted by the "-s" command line option.
    """
    # See procfs(5) for the meaning of each /proc/stat column.
    usr = 0
    nice = 1
    sys = 2
    idle = 3
    iowait = 4
    irq = 5
    soft = 6
    steal = 7
    guest = 8
    gnice = 9


# Stats whose threshold is reached when the utilization is greater than or
# equal to the configured percentage.
stat_fire_high = [CpuStat.usr, CpuStat.nice, CpuStat.sys,
                  CpuStat.iowait, CpuStat.irq, CpuStat.soft,
                  CpuStat.steal, CpuStat.guest, CpuStat.gnice]
# Stats whose threshold is reached when the utilization is less than or equal
# to the configured percentage.
stat_fire_low = [CpuStat.idle]


def cpu_thresholds_reached(
        cpu_utils: Sequence[Sequence[float]],
        stat_thresholds: Mapping[CpuStat, int],
        nr_cpus: int) -> bool:
    """Return True if at least @nr_cpus CPUs reached every threshold.

    A CPU counts only when all the stats in @stat_thresholds reached their
    threshold on that same CPU: "fire low" stats must be <= threshold, any
    other stat must be >= threshold.
    """
    low_ids = [member.value for member in stat_fire_low]
    high_ids = [member.value for member in stat_fire_high]

    def reached_all(cpu_stats):
        """Check every configured threshold against the stats of one CPU."""
        for stat, limit in stat_thresholds.items():
            stat_id = stat.value

            if stat_id in low_ids:
                # "fire low" stat: a value above the threshold doesn't fire
                if cpu_stats[stat_id] > limit:
                    return False
                continue

            if stat_id not in high_ids:
                # Stat is in neither list; complain and handle it like a
                # "fire high" stat.
                logging.error("CPU Stat %d is not in either list.",
                              stat_id)

            # "fire high" stat: a value below the threshold doesn't fire
            if cpu_stats[stat_id] < limit:
                return False

        return True

    matching_cpus = sum(1 for cpu_stats in cpu_utils if reached_all(cpu_stats))

    return nr_cpus - matching_cpus <= 0


def stat_threshold_reached(
        cpu_utils: Sequence[Sequence[float]],
        stat: CpuStat,
        threshold: int,
        nr_cpus: int) -> bool:
    """Tell whether @stat reached @threshold on at least @nr_cpus CPUs.

    "Fire low" stats (e.g. idle%) reach the threshold when their value is
    <= @threshold; every other stat (e.g. user%, sys%) reaches it when its
    value is >= @threshold.
    """
    column = stat.value
    fires_low = column in [member.value for member in stat_fire_low]

    matched = 0
    for cpu_stats in cpu_utils:
        reading = cpu_stats[column]

        if fires_low:
            hit = reading <= threshold
        else:
            hit = reading >= threshold

        if hit:
            matched += 1

        # stop scanning as soon as enough CPUs matched
        if matched >= nr_cpus:
            return True

    return False


def check_fs_util(dir_path: str, max_fs_util: int) -> None:
    """Check that there is enough space in a filesystem.

    Exit program if the filesystem containing dir_path surpasses max_fs_util
    disk usage.

    Args:
        dir_path: Any path inside the filesystem to check.
        max_fs_util: Maximum allowed disk usage, in percent.

    Raises:
        OSError: if the filesystem information cannot be retrieved.
        SystemExit: (exit status 1) if the usage is above @max_fs_util.
    """
    fs_info = os.statvfs(dir_path)

    if fs_info.f_blocks <= 0:
        # Filesystems that report no blocks at all (e.g. pseudo filesystems)
        # have no usage to compute; skip the check rather than dividing by
        # zero.
        logging.warning(
            ("Filesystem containing '%s' reports no data blocks; skipping "
             "disk space usage check"),
            dir_path)
        return

    space_used = (fs_info.f_blocks - fs_info.f_bfree) * 100 / fs_info.f_blocks

    if space_used > max_fs_util:
        logging.error(
            ("Disk space usage in filesystem containing '%s' has surpassed "
             "the maximum allowed.  current=%.2f%%; max_allowed=%d%%\n"
             "Aborting"),
            dir_path, space_used, max_fs_util)
        sys.exit(1)


def exec_action_cmds(cmds: Sequence[str], base_dir: str) -> None:
    """Run all the action commands and wait for them to finish.

    All commands are started first, so they run in parallel, and only then
    waited for.  Every command gets its own working directory under
    @base_dir, named after the command with each space and slash (/) replaced
    by two underscores; e.g. "echo hello" runs in "echo__hello/".  If that
    directory already exists (repeated commands) underscores are appended
    until the name is unique.  stdout/stderr of the command go to the file
    "output" in that directory.

    NOTE: The commands' exit status is not checked/logged/processed.
    """
    logging.info("Executing action commands in directory '%s'...", base_dir)

    dir_name_table = str.maketrans({" ": "__", "/": "__"})

    def launch(command):
        """Start one command in a fresh directory; return its Popen."""
        run_dir = os.path.join(base_dir, command.translate(dir_name_table))

        # Repeated action commands map to the same name; keep appending
        # underscores until the directory doesn't exist yet.
        while os.path.isdir(run_dir):
            run_dir += "_"

        os.makedirs(run_dir)
        output_path = os.path.join(run_dir, "output")

        with open(output_path, "x") as sink:
            # User-provided commands have to go through a shell.  That is
            # open to shell-injection, but running arbitrary commands is the
            # whole point of syswatch.
            return subprocess.Popen(
                command, shell=True, cwd=run_dir,  # nosec
                stdin=subprocess.DEVNULL, stdout=sink, stderr=sink,
                close_fds=True)

    running = []
    for command in cmds:
        running.append(launch(command))

    for child in running:
        child.wait()

    logging.info("done executing actions")


def format_cpu_util(cpu_util: CpuUtil) -> str:
    """Render a CPU utilization snapshot as a text table.

    The result is a header line followed by one line per CPU in @cpu_util,
    with every percentage printed with two decimals.
    """
    header = "   CPU   %usr  %nice   %sys %iowait   %irq  %soft %steal %guest"\
             " %gnice  %idle"

    # Stats in the order of the header columns, with the column width.
    layout = (
        (CpuStat.usr, 6), (CpuStat.nice, 6), (CpuStat.sys, 6),
        (CpuStat.iowait, 7), (CpuStat.irq, 6), (CpuStat.soft, 6),
        (CpuStat.steal, 6), (CpuStat.guest, 6), (CpuStat.gnice, 6),
        (CpuStat.idle, 6),
    )

    rows = []
    for cpu, values in cpu_util.items():
        cells = [f"{cpu:>6}"]
        cells.extend(
            f"{values[stat.value]:>{width}.2f}" for stat, width in layout)
        rows.append(" ".join(cells))

    return header + "\n" + "\n".join(rows)


def watch_cpu_utilization(
        interval: int,
        intervals: int,
        stat_thresholds: Mapping[CpuStat, int],
        actions: Sequence[str],
        nr_cpus: int,
        working_dir: str,
        max_fs_util: int,
        only_one_event: bool) -> None:
    """Watch CPU utilization and execute actions when thresholds are reached.

    Takes a /proc/stat snapshot every @interval seconds, computes the CPU
    utilization between consecutive snapshots and, when the thresholds are
    reached, logs the utilization, checks the disk usage of @working_dir and
    runs the action commands in a timestamped sub-directory of @working_dir.

    Args:
        interval: Seconds to sleep between snapshots.
        intervals: Number of consecutive intervals in which the thresholds
            must be reached before firing.  Values <= 1 fire on the first
            interval that reaches the thresholds.
        stat_thresholds: Threshold (in %) for each watched CPU stat.
        actions: Commands to execute when the thresholds are reached.
        nr_cpus: Number of individual CPUs that must reach the thresholds.
            Values <= 0 mean the cumulative stats of all CPUs are watched
            instead.
        working_dir: Directory under which the command directories are
            created.
        max_fs_util: Max disk usage (in %) of the filesystem containing
            @working_dir; the program exits if it is surpassed when an event
            fires.
        only_one_event: Return after the first event instead of watching
            indefinitely.
    """
    # pylint: disable=too-many-arguments

    logging.info("CPU utilization watch started...")
    old_timestamp = datetime.datetime.now()
    old_snapshot = take_snapshot()
    # Per-stat record of whether the threshold was met in each of the last
    # @intervals samples.  Only needed when several consecutive intervals are
    # required.
    history = None
    if intervals > 1:
        history = {stat: deque(maxlen=intervals) for stat in stat_thresholds}

    while True:
        time.sleep(interval)
        new_timestamp = datetime.datetime.now()
        new_snapshot = take_snapshot()

        # compute CPU utilization between snapshots
        snapshot_util = snap_percent_util(old_snapshot, new_snapshot)

        # When no specific # of CPUs is provided (i.e. nr_cpus <= 0), we watch
        # the cumulative stats of all CPUs in the system (stats for key "all"
        # in snapshot_util); otherwise we watch stats of all individual CPUs,
        # discarding the cumulative stats.
        if nr_cpus > 0:
            cpu_utils = tuple(
                stats for cpu, stats in snapshot_util.items() if cpu != "all")
        else:
            cpu_utils = (snapshot_util["all"], )

        # execute actions if CPU util thresholds were reached
        # (system wide mode has a single "all" entry, so one match suffices)
        required_cpus = 1 if nr_cpus <= 0 else nr_cpus

        if intervals <= 1:
            should_fire = cpu_thresholds_reached(
                cpu_utils, stat_thresholds, required_cpus)
        else:
            # Track the last N samples per stat so we can require N consecutive
            # intervals.
            # NOTE(review): here every stat is counted over the CPUs on its
            # own, whereas cpu_thresholds_reached() requires a single CPU to
            # reach all the thresholds.  With several thresholds and
            # nr_cpus > 0 the two modes can disagree -- confirm this is
            # intended.
            for stat, threshold in stat_thresholds.items():
                met = stat_threshold_reached(
                    cpu_utils, stat, threshold, required_cpus)
                history[stat].append(met)

            # Fire only once every stat has a full window of N samples and
            # all of them met the threshold.
            should_fire = all(
                len(history[stat]) == intervals and all(history[stat])
                for stat in stat_thresholds
            )

        if should_fire:
            # CPU utilization reached
            logging.info(
                ("Reached CPU utilization thresholds:\n"
                 "Snapshot: %s - %s\n"
                 "Stats:\n%s"),
                old_timestamp.strftime("%Y-%m-%dT%H:%M:%S"),
                new_timestamp.strftime("%Y-%m-%dT%H:%M:%S"),
                format_cpu_util(snapshot_util))

            # exits the program if the filesystem is too full
            check_fs_util(working_dir, max_fs_util)

            cmds_dir = os.path.join(
                working_dir, new_timestamp.strftime("%Y-%m-%dT%H-%M-%S"))
            exec_action_cmds(actions, cmds_dir)

            if only_one_event:
                break

            logging.info("continue watching...")

            # take new snapshot after action commands finished executing
            # (so the next interval doesn't include the time the commands
            # were running)
            new_timestamp = datetime.datetime.now()
            new_snapshot = take_snapshot()

        old_timestamp = new_timestamp
        old_snapshot = new_snapshot

    logging.info("Stop watching")


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse CLI arguments.

    Args:
        argv: Arguments to parse (without the program name).  Defaults to
            None, in which case argparse reads them from sys.argv.

    Returns:
        The parsed arguments.  "stat_thresholds" is converted to a dictionary
        { CpuStat: percentage }, so when a stat is given more than once the
        last value wins.

    Raises:
        SystemExit: raised by argparse (after printing the error) if the
            arguments are invalid.
    """
    def percentage(value):
        """Validate a percentage value (must be >= 1 and <= 100)"""
        # A non-integer makes int() raise ValueError, which argparse reports
        # as "invalid percentage value".
        value = int(value)

        if value < 1 or value > 100:
            # ArgumentTypeError (unlike ValueError) makes argparse show this
            # message to the user instead of a generic one.
            raise argparse.ArgumentTypeError(
                "Percentage must be in range [1, 100]")

        return value

    def interval(value):
        """Validate snapshot interval"""
        value = int(value)

        if value < 1:
            raise argparse.ArgumentTypeError(
                "Snapshot interval must be >= 1 seconds")

        return value

    def stat_threshold(value):
        """Validate stat threshold of the form '<stat>:<percentage>'"""
        stat_str, separator, percent = value.partition(":")

        if not separator:
            raise argparse.ArgumentTypeError(
                f"Invalid threshold '{value}': "
                "must be of the form '<stat>:<percentage>'")

        try:
            stat = CpuStat[stat_str]
        except KeyError:
            raise argparse.ArgumentTypeError(
                f"Invalid threshold '{value}': "
                f"'{stat_str}' is not a valid CPU stat") from None

        try:
            percent = percentage(percent)
        except (ValueError, argparse.ArgumentTypeError) as exp:
            raise argparse.ArgumentTypeError(
                f"Invalid threshold '{value}': {exp}") from exp

        return (stat, percent)

    parser = argparse.ArgumentParser(
        description=(
            "Execute user specified commands if configured CPU utilization"
            " thresholds are reached.  See oled-syswatch(8) for a detailed "
            "description."),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-b", dest="run_continuously", action='store_true',
        help="Run indefinitely until manually terminated.")
    parser.add_argument(
        "-C", dest="nr_cpus", type=int, default=0,
        help=("# of CPUs to apply match criteria.  A value <= 0 means apply "
              "system wide."))
    parser.add_argument(
        "-s", dest="stat_thresholds", type=stat_threshold, action="append",
        required=True, metavar="STAT:PERCENTAGE",
        help=("Triggers action commands if the specified CPU STAT utilization "
              "reaches or exceeds PERCENTAGE.  Valid values of STAT are: "
              f"{', '.join(x.name for x in CpuStat)}.  "
              "PERCENTAGE must be an integer in the range [1, 100].  This "
              "option is required and can be specified multiple times."))
    parser.add_argument(
        "-t", dest="target_dir", default="/var/oled/syswatch",
        help=("Output directory.  The program will cd to this directory before"
              " performing any actions."))
    parser.add_argument(
        "-M", dest="max_fs_util", type=percentage, default=85,
        help=("Max Filesystem Utilization.  If the filesystem is at or above "
              "the set %%, then exit and don't take any action."))
    parser.add_argument(
        "-I", dest="interval", type=interval, default=5,
        help="Interval, in seconds, between CPU utilization snapshots.")

    parser.add_argument(
        "-i", "--intervals", dest="intervals", type=int, default=1,
        help=("Require thresholds to be reached for N consecutive sampling "
              "intervals before firing."))
    parser.add_argument(
        "-c", dest="commands", metavar="COMMAND", required=True,
        action="append",
        help=("Command to execute when CPU stat thresholds are reached.  This"
              " is mandatory.  If a more complex command is needed this can "
              "point to a script.  This option can be specified multiple "
              "times, in which case all the commands will be executed in "
              "parallel."))

    args = parser.parse_args(argv)

    if args.intervals < 1:
        parser.error("--intervals must be >= 1")

    # Convert util thresholds to a dictionary, this will remove duplicate
    # entries
    args.stat_thresholds = dict(args.stat_thresholds)

    # Thresholds of "fire high" stats are lower bounds, so together they can
    # never be satisfied if they add up to more than 100%.  Thresholds of
    # "fire low" stats (idle) are upper bounds and don't add to that budget;
    # e.g. "usr:80" with "idle:30" is a perfectly reachable combination.
    low_ids = [stat.value for stat in stat_fire_low]
    lower_bounds_total = sum(
        percent for stat, percent in args.stat_thresholds.items()
        if stat.value not in low_ids)

    if lower_bounds_total > 100:
        parser.error("The sum of CPU STAT percentages cannot be > 100%")

    return args


def setup_logging(working_dir: str) -> None:
    """Configure the root logger of the application.

    Messages of level INFO and above are written to "syswatch.log" inside
    @working_dir (truncating any previous content) and, when stdout is a
    tty, to stdout as well.
    """
    log_file = os.path.join(working_dir, "syswatch.log")
    formatter = logging.Formatter(
        fmt="%(asctime)s %(levelname)s - %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S%z")

    handlers = [logging.FileHandler(log_file, mode="w")]

    # mirror the log on stdout only for interactive runs
    if sys.stdout.isatty():
        handlers.append(logging.StreamHandler(sys.stdout))

    root_logger = logging.getLogger()
    for handler in handlers:
        handler.setFormatter(formatter)
        root_logger.addHandler(handler)

    root_logger.setLevel(logging.INFO)

    logging.info("Log file: %s", log_file)


def main() -> None:
    """Entry point: parse arguments, set up the run and start watching.

    Requires root.  A working directory named after the current time and PID
    is created under the target directory and made the CWD; the log file is
    created in it and the configuration of the run is logged before the
    watch starts.
    """
    args = parse_args()

    if os.getuid() != 0:
        logging.error("This script must be run as root.  Aborting.")
        sys.exit(1)

    # create working dir and change to it
    start_time = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    run_name = "syswatch_{}_{}".format(start_time, os.getpid())
    working_dir = os.path.join(args.target_dir, run_name)
    os.makedirs(working_dir)
    os.chdir(working_dir)

    setup_logging(working_dir)

    # log the command line in a form that can be pasted into a shell
    logging.info(" ".join(shlex.quote(arg) for arg in sys.argv))

    threshold_lines = [
        f"\t\t{stat.name}: {percent}%"
        for stat, percent in args.stat_thresholds.items()
    ]
    watched_cpus = "ALL" if args.nr_cpus <= 0 else args.nr_cpus
    config_lines = [
        f"\tHost Name: {os.uname()[1]}",
        f"\tRun continuously: {args.run_continuously}",
        f"\t# CPUs: {watched_cpus}",
        "\tThresholds:",
        "\n".join(threshold_lines),
        f"\tWorking dir: {working_dir}",
        f"\tMax FS utilization: {args.max_fs_util}%",
        f"\tSnapshot interval: {args.interval} sec",
        f"\tConsecutive Intervals Required: {args.intervals}",
        f"\tActions commands: {args.commands}",
    ]
    logging.info("Config:\n %s", "\n".join(config_lines))

    check_fs_util(args.target_dir, args.max_fs_util)

    watch_cpu_utilization(
        interval=args.interval,
        intervals=args.intervals,
        stat_thresholds=args.stat_thresholds,
        actions=args.commands,
        nr_cpus=args.nr_cpus,
        working_dir=working_dir,
        max_fs_util=args.max_fs_util,
        only_one_event=not args.run_continuously)


def exit_signal_handler(*_args, **_kwargs) -> None:
    """Log the interruption and terminate the program with exit status 1.

    Meant to be installed as a signal handler; the arguments passed to it
    are ignored.
    """
    logging.error("Interrupted")
    raise SystemExit(1)


if __name__ == "__main__":
    # Exit cleanly (logging the interruption) on the usual termination
    # signals.
    for termination_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(termination_signal, exit_signal_handler)

    main()
    logging.info("Finished")
