#!/usr/bin/python3
#
# Copyright (c) 2024, Oracle and/or its affiliates.
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
#
# This code is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 2 only, as
# published by the Free Software Foundation.
#
# This code is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
# version 2 for more details (a copy is included in the LICENSE file that
# accompanied this code).
#
# You should have received a copy of the GNU General Public License version
# 2 along with this work; if not, see <https://www.gnu.org/licenses/>.
#
# Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
# or visit www.oracle.com if you need additional information or have any
# questions.

###############################################################################
#
# NAME: kstack
#
# AUTHORS: 2020 Cesar Roque
#          2023 Jose Lombera
#
# DESCRIPTION: Collect the kernel stack trace for selected processes.
#
# HISTORY:
#
# Fri Jul 24 10:45:00 PDT 2020 - Initial version.
#
############################################################
"""oled-kstack"""

import argparse
import fcntl
import glob
import gzip
import logging
import math
import os
import shutil
import socket
import sys
import syslog
import time
from datetime import datetime, timedelta
from typing import (
    Any, Dict, Generator, Iterable, List, Sequence, TextIO, Tuple)


def fail(msg: str) -> None:
    """Report `msg` through the logging module at ERROR level and terminate.

    Never returns: the process exits with status 1 (SystemExit is raised).
    """
    logging.error(msg)
    raise SystemExit(1)


def rotate_file(
        path: str,
        max_size_bytes: int,
        max_file_count: int,
        force: bool = False) -> None:
    """Rotate a file on disk.

    Archive file `path` if it has exceeded `max_size_bytes` in size or `force`
    is True.  The archive is a gzip copy named `<path>-<epoch seconds>.gz` and
    the original file is removed afterwards.  At most `max_file_count` archive
    files are kept; older ones are removed.

    Skip if the file doesn't exist or `max_size_bytes` <= 0 or
    `max_file_count` <= 0.
    """
    if max_size_bytes <= 0 or max_file_count <= 0:
        return

    if not os.path.isfile(path):
        return

    if not force and os.stat(path).st_size <= max_size_bytes:
        return

    # rotate and compress file
    archive_path = f"{path}-{int(time.time())}.gz"
    with open(path, "rb") as in_fd:
        with gzip.open(archive_path, mode="wb") as out_fd:
            shutil.copyfileobj(in_fd, out_fd)

    os.remove(path)

    # Remove old archive files.  The path is escaped so that glob
    # metacharacters in the directory or file name ("[", "]", "*", "?") are
    # matched literally; otherwise the pattern would miss the real archives
    # and they would never be pruned.
    #
    # Archive names end in an epoch timestamp, so a reverse lexicographic sort
    # lists the newest archives first.
    archives = sorted(glob.glob(f"{glob.escape(path)}-*.gz"), reverse=True)
    for stale in archives[max_file_count:]:
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass  # already gone; nothing left to do


def setup_logging(verbose: bool) -> None:
    """Set up application logging.

    Configures the root logger through logging.basicConfig() with a
    timestamped format; the level is DEBUG when `verbose` is True and INFO
    otherwise.  Also sets the identity used for syslog messages.
    """
    logging.basicConfig(
        # Milliseconds are zero-padded to three digits: without the padding
        # 5 ms would be rendered as ".5", which reads as half a second.
        format="%(asctime)s.%(msecs)03d %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.DEBUG if verbose else logging.INFO)

    # set identity for syslog messages produced by this process
    syslog.openlog("oled-kstack")


def _parse_proc_stat(line: str) -> Tuple[str, str, int]:
    """Parse the content of /proc/<pid>/stat into (comm, state, starttime).

    The comm field is wrapped in parentheses and may itself contain spaces
    and parentheses, so it is delimited by the first "(" and the *last* ")"
    rather than by splitting the whole line on whitespace.  starttime is
    returned as read from the file (clock ticks since boot).
    """
    comm_begin = line.index("(")
    comm_end = line.rindex(")")
    fields = line[comm_end + 1:].split()
    # fields[0] is stat field 3 (state); starttime is stat field 22.
    return line[comm_begin + 1:comm_end], fields[0], int(fields[19])


def collect_procs_info(
        all_procs: bool,
        pids: Sequence[int],
        states: Sequence[str]
) -> Generator[Tuple[int, str, str, int, str], None, None]:
    """Collect info about processes matching criteria.

    Yields one tuple per matching process:
      (PID, state, cmdline, epoch_start_time, stack)
    where
      PID: the process ID
      state: the process state (e.g. "R", "Z", "D", ...)
      cmdline: /proc/<pid>/cmdline with arguments separated by spaces (or the
               process's comm from /proc/<pid>/stat if the former is empty or
               unreadable; e.g. for kernel threads)
      epoch_start_time: the process start time in epoch timestamp
      stack: /proc/<pid>/stack (empty if an error occurred to read it)

    MATCHING CRITERIA

    If all_procs is True, all processes are considered; otherwise only
    processes whose PID is in pids or whose state is in states are considered
    (see /proc/[pid]/stat in procfs(5) for details of the possible states).

    A PID that cannot be inspected (e.g. the process exited meanwhile) is
    skipped.  The program exits through fail() if the system boot time cannot
    be read from /proc/stat.
    """
    # Retrieve CPU's ticks per second and system boot time to compute
    # processes' start time in epoch timestamp.
    hertz = os.sysconf("SC_CLK_TCK")
    try:
        with open("/proc/stat") as fdesc:
            for line in fdesc:
                if line.startswith("btime"):
                    boot_time = int(line.split()[1])
                    break
            else:
                fail("Boot time (btime) not found in /proc/stat.")
    except Exception as exn:  # pylint: disable=broad-except
        fail(f"Unable to read boot time from /proc/stat: {exn}")

    if pids and not all_procs and not states:
        # When pids is the only selection criteria, avoid walking all the PIDs
        # in the system.
        collect_pids: Iterable[int] = pids
    else:
        collect_pids = (
            int(pid) for pid in os.listdir("/proc") if pid.isdigit()
        )

    wanted_pids = set(pids)

    for pid in collect_pids:
        try:
            # errors="replace": comm and cmdline are arbitrary bytes; a
            # decoding error must not make the whole process disappear from
            # the report.
            with open(f"/proc/{pid}/stat", errors="replace") as fdesc:
                comm, state, start_ticks = _parse_proc_stat(fdesc.read())

            if not (all_procs or pid in wanted_pids or state in states):
                continue  # this process does not match selection criteria

            epoch_start_time = (start_ticks // hertz) + boot_time

            # Try reading /proc/<pid>/cmdline.  Fall back to process's comm
            # from /proc/<pid>/stat if the former is empty or there was an
            # error reading it.
            try:
                with open(f"/proc/{pid}/cmdline", errors="replace") as fdesc:
                    # Arguments are NUL-separated and NUL-terminated; strip
                    # after the replacement so no trailing blank is left.
                    cmdline = fdesc.read().replace("\0", " ").strip()
            except OSError:
                cmdline = ""

            if not cmdline:
                cmdline = comm

            try:
                with open(f"/proc/{pid}/stack") as fdesc:
                    stack = fdesc.read().strip()
            except OSError:
                stack = ""

        except Exception as exn:  # pylint: disable=broad-except
            # Best effort: skip this PID on error (typically the process
            # exited while it was being inspected).
            logging.debug("Skipping PID %s: %s", pid, exn)
            continue

        yield (pid, state, cmdline, epoch_start_time, stack)


def write_proc_infos(
        output_fd: TextIO,
        infos: Iterable[Tuple[int, str, str, int, str]]) -> None:
    """Write a report of the given processes to `output_fd`.

    `infos` yields (PID, state, cmdline, epoch_start_time, stack) tuples, as
    produced by collect_procs_info().

    The report starts with a "zzz" header line carrying the sample date and
    the host name.  Processes sharing the same stack are then listed together
    (one line each), followed by that stack.
    """
    # Take the sample timestamp before `infos` is consumed.
    sample_time = time.localtime()

    # Map each distinct stack to the details of the processes that have it;
    # insertion order is the order in which stacks are first seen.
    by_stack: Dict[str, List[Any]] = {}
    for info in infos:
        *details, trace = info
        if trace not in by_stack:
            by_stack[trace] = []
        by_stack[trace].append(details)

    header_date = time.strftime("%a %b %d %H:%M:%S %Z %Y", sample_time)
    output_fd.write(f"zzz <{header_date} - {socket.gethostname()}\n\n")

    row_fmt = (
        "PID: {pid:>5}   "
        "State: {state:>2s}   "
        "Start Time: {start:%Y/%m/%d %H:%M:%S}   "
        "Cmd: {cmd}\n"
    )

    for trace, owners in by_stack.items():
        output_fd.writelines(
            row_fmt.format(
                pid=pid,
                state=state,
                start=datetime.fromtimestamp(started_at),
                cmd=cmdline)
            for pid, state, cmdline, started_at in owners)
        output_fd.write(f"Stack:\n{trace}\n\n")


def check_space(path: str, usage_percent_threshold: int) -> None:
    """Check that the filesystem holding `path` is below the usage threshold.

    The used percentage is computed from os.statvfs() as
    (f_blocks - f_bfree) / f_blocks, rounded up.  The program exits through
    fail() when:
      - `path` doesn't exist or its filesystem can't be queried;
      - the filesystem reports no blocks, so usage can't be computed;
      - the used percentage is at or above `usage_percent_threshold`.
    """
    logging.debug("Checking free space on filesystem: %s", path)

    # Query directly instead of checking for existence first: the path could
    # vanish between the check and the statvfs() call.
    try:
        stat = os.statvfs(path)
    except FileNotFoundError:
        fail(f"Path doesn't exist, unable to get filesystem usage: {path}")
        return
    except OSError as exn:
        fail(f"Unable to get filesystem usage for {path}: {exn}")
        return

    if stat.f_blocks <= 0:
        # Guard the division below; a filesystem reporting zero blocks would
        # otherwise crash with ZeroDivisionError.
        fail("Filesystem reports no capacity, unable to get filesystem "
             f"usage: {path}")
        return

    # compute FS percentage used space
    used_space = math.ceil(
        (stat.f_blocks - stat.f_bfree) * 100 / stat.f_blocks)

    if used_space >= usage_percent_threshold:
        fail(f"Maximum filesystem space reached. Used: {used_space}%, "
             f"Threshold: {usage_percent_threshold}%")
        return

    logging.debug(
        "Used: %d%%, Threshold: %d%%",
        used_space, usage_percent_threshold)


def demonize() -> None:
    """Demonize the process.

    Uses a double fork: the calling process waits for an intermediate child
    and exits with status 0; the intermediate child starts a new session,
    forks again, logs the PID of the background process and exits with
    status 0.  Only the second-level child returns from this function.
    """
    logging.info("Spawning process to run in background...")
    pid = os.fork()

    if pid != 0:
        # Original process: reap the intermediate child, then leave.
        os.waitpid(pid, 0)
        sys.exit(0)

    # Intermediate child: start a new session so the background process is
    # detached from the caller's session and process group.  This cannot fail
    # with EPERM because a freshly forked child is never a process group
    # leader.
    os.setsid()

    pid2 = os.fork()

    if pid2 != 0:
        logging.info("Background PID: %d", pid2)
        sys.exit(0)

    # Second-level child: return to the caller and keep running.


# Default target directory for the log files written in background mode.
DEFAULT_DIR = "/var/oled/kstack"


def _parse_pid_list(value: str) -> Tuple[int, ...]:
    """Convert a comma-separated list of PIDs ("1,22,333") into a tuple."""
    try:
        return tuple(int(pid) for pid in value.split(","))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected comma-separated integer PIDs, got {value!r}") from None


def parse_args(args: Sequence[str]) -> argparse.Namespace:
    """Parse CLI arguments.

    `args` is the argument list without the program name.

    Returns the parsed options.  Exits through parser.error() (usage message
    on stderr, exit status 2) if no selection option was given, or if
    background mode is requested with a sampling interval under 10 seconds.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "-v", "--verbose", dest="verbose", action='store_true',
        help="Display debugging data.")

    process_group = parser.add_argument_group(
        "SELECTION OPTIONS",
        "(At least one of this options must be specified.)")
    process_group.add_argument(
        "-a", dest="all", action='store_true', help="All processes.")
    # The state options all append to the same `states` list, which stays
    # None when none of them is given.
    process_group.add_argument(
        "-D", dest="states", action='append_const', const="D",
        help="Processes waiting on I/O or a lock.")
    process_group.add_argument(
        "-R", dest="states", action='append_const', const="R",
        help="Running processes.")
    process_group.add_argument(
        "-S", dest="states", action='append_const', const="S",
        help="Sleeping processes.")
    process_group.add_argument(
        "-Z", dest="states", action='append_const', const="Z",
        help="Zombie processes.")
    process_group.add_argument(
        "-p", dest="pids", metavar="PID[,...]",
        type=_parse_pid_list, default=(),
        help=("Dump stack traces for PID.  More than one PID can be specified"
              " if separated by commas."))

    background_group = parser.add_argument_group("BACKGROUND OPTIONS")
    background_group.add_argument(
        "-b", dest="background", action='store_true',
        help=("Run in background mode.  Data will be written to "
              f"{DEFAULT_DIR}/kstack_HOSTNAME.out or the directory selected on"
              " -d option."))
    background_group.add_argument(
        "-t", dest="minutes", type=int, default=30,
        help="Specify the number of MINUTES to run when in background mode.")
    background_group.add_argument(
        "-i", dest="sleep", metavar="SECONDS", type=int, default=60,
        help="Specify the number of SECONDS between samples.")

    log_group = parser.add_argument_group("LOG FILE OPTIONS")
    log_group.add_argument(
        "-d", dest="directory_name", metavar="DIRECTORY", type=str,
        default=DEFAULT_DIR, help="The target directory to write log files.")
    log_group.add_argument(
        "-m", dest="max_size_mb", type=int, default=1,
        help=("The maximum size, in megabytes, for the log file.  Once the log"
              " file exceeds this size it will be rotated and compressed."))
    log_group.add_argument(
        "-n", dest="files_number", metavar="NUMBER_OF_FILES", type=int,
        default=5, help="The number of rotated files to retain.")
    log_group.add_argument(
        "-x", dest="max_space", type=int, default=85,
        help=("This is the maximum percent used on the target file system.  "
              "If the target file system is at or above this limit, then the "
              "program will refuse to run."))

    options = parser.parse_args(args)

    if not options.states and not options.pids and not options.all:
        # List only options that actually exist (there is no -I option).
        parser.error(
            "At least one of the -a/-D/-R/-S/-Z/-p must be specified.")

    if options.background and options.sleep < 10:
        parser.error("Option -i must be >= 10 seconds.")

    return options


def main(args: Sequence[str]) -> None:
    """Main function.

    `args` is the CLI argument list without the program name.

    Requires root.  Runs under an exclusive, non-blocking lock on
    /run/lock/oled-kstack.lock so that only one instance runs at a time.

    In foreground mode the matching processes are dumped once to stdout.  In
    background mode the process demonizes itself and appends one sample to
    <directory>/kstack_<hostname>.out every `-i` seconds until the `-t` time
    limit is reached, rotating the output file as configured.
    """

    options = parse_args(args)

    # Configure logging before the first possible fail() call, so that every
    # error message (including the root check below) uses the tool's format
    # instead of logging's implicit default configuration.
    setup_logging(options.verbose)

    if os.getuid() != 0:
        fail("This script must be run as root")

    # operate under FS lock
    lock_path = "/run/lock/oled-kstack.lock"
    with open(lock_path, "w") as lock_fd:
        # the lock will be removed automatically when the file is closed
        try:
            fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except IOError as exn:
            fail(f"Error while locking '{lock_path}', make sure no other "
                 f"instance of kstack is running: {exn}")

        if not options.background:
            # foreground run; dump stacks once and exit
            write_proc_infos(
                sys.stdout,
                collect_procs_info(options.all, options.pids, options.states))
            return

        #
        # background logic follows
        #

        os.makedirs(options.directory_name, exist_ok=True)
        check_space(options.directory_name, options.max_space)

        # Only the background process returns from demonize(); the processes
        # that started it exit inside the call.
        demonize()

        # Setting the time/date to finish the run
        limit = datetime.now() + timedelta(minutes=options.minutes)
        logging.info("This script will run up to: %s", limit)
        # The effective limit is 15 seconds later than the announced one.
        limit += timedelta(seconds=15)

        logging.info(
            "The requested data will be written to the directory: %s",
            options.directory_name)

        output_path = os.path.realpath(os.path.join(
            options.directory_name,
            f"kstack_{socket.gethostname()}.out"))

        max_size_bytes = options.max_size_mb * 2**20

        # Force output file rotation before data collection starts to not mix
        # output from previous runs.
        rotate_file(
            output_path, max_size_bytes=max_size_bytes,
            max_file_count=options.files_number, force=True)

        while datetime.now() < limit:
            # Size-based rotation; a no-op until the file outgrows the limit.
            rotate_file(
                output_path, max_size_bytes=max_size_bytes,
                max_file_count=options.files_number)

            procs_infos = collect_procs_info(
                options.all, options.pids, options.states)

            with open(output_path, "a") as output_fd:
                write_proc_infos(output_fd, procs_infos)

            time.sleep(options.sleep)


# Script entry point: run main() with the CLI arguments (program name
# excluded) only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main(sys.argv[1:])
