#!/usr/bin/python3

# Copyright © 2024, Oracle and/or its affiliates.  All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import argparse
import fcntl
import os
import re
import subprocess
import sys
import time

# Hosts without a prefix file are not managed by oci-fss-utils: nothing to collect.
if not os.path.exists('/etc/oci-fss-utils.d/prefix.txt'):
    sys.exit(0)

# First line of the prefix file is the subnet mount-target addresses come from,
# either "<ipv6-prefix>::/64" or "<a>.<b>.<c>.0/24".
with open('/etc/oci-fss-utils.d/prefix.txt', encoding="utf-8") as prefix_fp:
    PREFIX = prefix_fp.readline().strip()

if re.match(r"^[a-zA-Z0-9:]+::/64$", PREFIX):
    PROTOCOL = "ipv6"
elif re.match(r"^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.0/24$", PREFIX):
    # Dots are escaped: an unescaped "." would accept any separator character.
    PROTOCOL = "ipv4"
else:
    raise ValueError("Invalid prefix in /etc/oci-fss-utils.d/prefix.txt.")

# MOUNT_PATTERN matches the source field of a mounts-table line whose server
# address lies in PREFIX; group(1) is the host part, which
# find_mounts_in_mounts_file() parses as hex (ipv6) or decimal (ipv4).
# PREFIX is re.escape()d so its "." characters match only a literal dot
# (otherwise prefix 10.0.1 would match 10.0.115.5 as host 5).
if PROTOCOL == "ipv6":
    PREFIX = PREFIX.replace("::/64", "")
    # Only hex digits, so int(..., 16) on the captured group cannot fail.
    MOUNT_PATTERN = r"^\[%s::([0-9a-fA-F]{1,4})\]" % re.escape(PREFIX)
else:
    # PROTOCOL is "ipv4" here; anything else was rejected above.
    PREFIX = PREFIX.replace(".0/24", "")
    MOUNT_PATTERN = r"^%s\.([0-9]{1,3})" % re.escape(PREFIX)

# Older mounts used 192.168.<slot>.2 as the server address.
MOUNT_PATTERN_LEGACY = r"^192\.168\.([0-9]+)\.2:/"


class MountLock:
    """Exclusive advisory flock(2) lock on /run/oci-fss-mount.lck.

    ``acquire()`` blocks until the exclusive lock is obtained. When the object
    is finalised the lock is dropped and the lock file is closed. The class
    can also be used as a context manager: it acquires on entry and releases
    on exit, leaving the file open so the instance can be reused.
    """

    def __init__(self):
        # Pre-set so __del__/close() stay safe if open() below raises.
        self.fd = None
        self.fd = open("/run/oci-fss-mount.lck", "w+", encoding="utf-8")

    def __del__(self):
        self.close()

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.release()
        return False

    def acquire(self):
        """Block until the exclusive lock is held."""
        fcntl.flock(self.fd, fcntl.LOCK_EX)

    def release(self):
        """Drop the lock if the lock file is still open; keep the file open."""
        fd = getattr(self, "fd", None)
        if fd is not None and not fd.closed:
            fcntl.flock(fd, fcntl.LOCK_UN)

    def close(self):
        """Drop the lock and close the lock file; safe to call repeatedly."""
        fd = getattr(self, "fd", None)
        if fd is not None and not fd.closed:
            try:
                fcntl.flock(fd, fcntl.LOCK_UN)
            finally:
                # The original only unlocked and leaked the descriptor.
                fd.close()


def run_cmd(cmd):
    """Echo *cmd* and run it without a shell.

    The command line is split on runs of whitespace. No shell quoting or
    expansion is performed, so individual arguments cannot contain spaces.

    Returns:
        subprocess.CompletedProcess: the finished process; its exit status is
        not checked.
    """
    # Flush so the echoed command precedes the child's output when stdout is
    # a pipe or file: the child writes straight to the inherited descriptor.
    print(cmd, flush=True)
    # split() with no separator drops empty fields, so doubled or trailing
    # spaces do not become empty-string arguments.
    return subprocess.run(cmd.split())


def find_mounts_in_mounts_file(fp, re_mount, re_legacy, mounts, mounts_legacy,
                               protocol=None):
    """Scan a mounts table and record the slot numbers of oci-fss mounts.

    Only the first space-separated field of each line (the mount source) is
    examined. A source matching *re_mount* is recorded in *mounts*; otherwise
    a source matching *re_legacy* is recorded in *mounts_legacy*.

    Args:
        fp: iterable of lines in mounts-table format, e.g. an open
            /proc/<pid>/mounts file.
        re_mount: compiled pattern whose group(1) is the host part of a
            current-scheme mount source.
        re_legacy: compiled pattern whose group(1) is a decimal slot number.
        mounts: set receiving current-scheme slots (mutated in place).
        mounts_legacy: set receiving legacy slots (mutated in place).
        protocol: "ipv6" (group parsed as hex) or "ipv4" (parsed as decimal).
            Defaults to the module-level PROTOCOL.
    """
    if protocol is None:
        protocol = PROTOCOL
    # Any other protocol value disables current-scheme matching, as before.
    base = {"ipv6": 16, "ipv4": 10}.get(protocol)

    for line in fp:
        source = line.strip().split(" ")[0]

        m = re_mount.match(source)
        if m and base is not None:
            try:
                mounts.add(int(m.group(1), base))
            except ValueError:
                # Captured text is not a number in this base, so it cannot be
                # one of our slots. Skip it rather than abort the whole scan.
                pass
            continue

        m = re_legacy.match(source)
        if m:
            mounts_legacy.add(int(m.group(1)))


def dump_mounts():
    """Collect the slot numbers of every oci-fss mount visible on the host.

    The mount table of every process listed in /proc is scanned (processes
    may live in different mount namespaces), followed by /proc/mounts.

    Returns:
        tuple[set[int], set[int]]: ``(mounts, mounts_legacy)``, the slots
        matched by MOUNT_PATTERN and MOUNT_PATTERN_LEGACY respectively.

    Raises:
        OSError: if /proc/mounts itself cannot be read.
    """
    mounts = set()
    mounts_legacy = set()

    re_mount = re.compile(MOUNT_PATTERN)
    re_mount_legacy = re.compile(MOUNT_PATTERN_LEGACY)

    # errors="replace": mount points may contain non-UTF-8 bytes, and a
    # UnicodeDecodeError (not an OSError) would otherwise abort the whole GC.
    # Only the mount source field matters here.
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue
        try:
            with open(f"/proc/{entry}/mounts", encoding="utf-8",
                      errors="replace") as fp:
                find_mounts_in_mounts_file(
                    fp, re_mount, re_mount_legacy, mounts, mounts_legacy)
        except OSError:
            # Process exited since listdir(), or its table is unreadable. Skip it.
            pass

    # On Linux /proc/mounts is a symlink to /proc/self/mounts (this process's
    # own namespace). Unlike the per-pid reads above, a failure here is fatal.
    with open("/proc/mounts", encoding="utf-8", errors="replace") as fp:
        find_mounts_in_mounts_file(
            fp, re_mount, re_mount_legacy, mounts, mounts_legacy)

    return mounts, mounts_legacy


def _read_first_line(path):
    """Return the first line of *path* with surrounding whitespace stripped ("" if empty)."""
    with open(path, "r", encoding="utf-8") as fp:
        return fp.readline().strip()


def do_gc():
    """Remove bookkeeping files for slots that no longer have a mount.

    Under MountLock, every slot 1..254 that appears in neither set returned by
    dump_mounts() has its /run/oci-fss-utils.d/slot-<i>.txt removed. The slot
    file's first line names a mount target <mt>. If mt-<mt>.txt exists it is
    removed too, along with the addr-<addr>.json its first line points at.
    Empty bookkeeping files are tolerated instead of crashing the collector.
    """
    lck = MountLock()
    lck.acquire()
    # The lock is released when ``lck`` is finalised, i.e. on return under
    # CPython reference counting.

    mounts, mounts_legacy = dump_mounts()
    gc_count = 0
    live_count = 0

    # Mount targets follow this scheme: oci-fss-{i:04}
    for i in range(1, 255):
        if i in mounts:
            print(f"slot {i} has an associated mount...")
            live_count += 1
            continue
        if i in mounts_legacy:
            print(f"slot {i} has an associated legacy mount...")
            live_count += 1
            continue

        slotfile = f"/run/oci-fss-utils.d/slot-{i}.txt"
        if not os.path.isfile(slotfile):
            continue

        mt = _read_first_line(slotfile)
        if mt:
            mt_file = f"/run/oci-fss-utils.d/mt-{mt}.txt"
            if os.path.isfile(mt_file):
                addr = _read_first_line(mt_file)
                if addr:
                    addr_file = f"/run/oci-fss-utils.d/addr-{addr}.json"
                    if os.path.isfile(addr_file):
                        os.unlink(addr_file)
                os.unlink(mt_file)
        os.unlink(slotfile)

        print(f"removed mount on slot {i:04}...")
        gc_count += 1

    print(f"live count: {live_count}")
    print(f"gc count: {gc_count}")


if __name__ == "__main__":
    # Command line: "-d/--daemon" switches from a single pass to a periodic loop.
    cli = argparse.ArgumentParser(
        prog='oci-fss-gc',
        usage="oci-fss-gc [options]",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument('-d', '--daemon', required=False, action='store_true')
    options = cli.parse_args()

    if options.daemon:
        # Daemon mode never leaves this loop: wait 300 s, collect, repeat.
        while True:
            time.sleep(300)
            do_gc()

    # One-shot mode: a single collection pass.
    do_gc()
