#!/usr/bin/python3

# Copyright © 2024, Oracle and/or its affiliates.  All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import filecmp
import json
import logging
import os
import re
import subprocess as subp
import sys
import time
import requests
import fcntl

# Configure the root logger once at start-up. At INFO level the per-cycle
# logging.debug() messages emitted by the monitor are suppressed.
logging.basicConfig(level=logging.INFO)
logging.info("Starting monitor service")
# Only let urllib3 records of WARNING severity or above through.
logging.getLogger("urllib3").setLevel(logging.WARNING)

# Contents of the OpenSSL and IMDS bundles that went into the most recently
# written CA file. update_latest_cert_bundle() compares fresh reads against
# these to detect changes and overwrites them after a successful update.
prev_openssl_bundle = ""
prev_imds_bundle = ""

class MountLock:
    """Exclusive ``flock`` lock held on a lock file.

    The lock can be driven manually with ``acquire()`` / ``release()`` /
    ``close()``, or used as a context manager, which acquires on entry and
    releases and closes on exit even when the body raises::

        with MountLock():
            ...  # critical section

    Attributes are limited to ``fd``, the open lock file object.
    """

    def __init__(self, path="/run/oci-fss-mount.lck"):
        """Open the lock file at *path*, creating or truncating it.

        Args:
            path: Location of the lock file. Defaults to the path this
                script has always used.

        Raises:
            OSError: If the lock file cannot be opened.
        """
        self.fd = open(path, "w+", encoding="utf-8")

    def acquire(self):
        """Block until the exclusive lock on the lock file is held."""
        fcntl.flock(self.fd, fcntl.LOCK_EX)

    def release(self):
        """Unlock the lock file; the file itself stays open."""
        fcntl.flock(self.fd, fcntl.LOCK_UN)

    def close(self):
        """Close the lock file. Calling this more than once is harmless."""
        self.fd.close()

    def __enter__(self):
        try:
            self.acquire()
        except BaseException:
            # Do not leak the descriptor when the lock could not be taken.
            self.close()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            self.release()
        finally:
            self.close()
        # Never swallow an exception raised inside the ``with`` body.
        return False

def get_imds_cert_bundle(url_cert_bundle, timeout=30):
    """Download the CA bundle text from the instance metadata service (IMDS).

    This is a best-effort fetch: every failure is logged and reported to the
    caller as an empty string rather than raised.

    Args:
        url_cert_bundle: URL the certificate bundle is served from.
        timeout: Seconds handed to ``requests.get`` as its ``timeout``. Without
            one, ``requests`` waits without limit on an unresponsive endpoint.

    Returns:
        The response body when the server answers with HTTP 200, otherwise
        an empty string.
    """
    try:
        # if this is an oci instance, get the ca bundle from IMDS.
        r = requests.get(
            url_cert_bundle,
            headers={"Authorization": "Bearer Oracle"},
            timeout=timeout,
        )
        if r.status_code == 200:
            return r.text
        logging.info(f'Unexpected HTTP status {r.status_code} from {url_cert_bundle}.')
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        # Expected when IMDS is unreachable, so keep this at INFO without a
        # traceback.
        logging.info(f'Unable to download {url_cert_bundle}.')
    except Exception as err:
        logging.exception(err)

    return ""

def _openssl_cert_bundle_path():
    """Return the path of ``cert.pem`` inside the directory OpenSSL reports."""
    output = subp.check_output(['openssl', 'version', '-d'], universal_newlines=True)
    # The output is in format 'OPENSSLDIR: "/etc/pki/tls"'. Split on the first
    # colon only and trim the surrounding whitespace and quotes, so a
    # directory that itself contains a space or colon survives intact.
    openssl_dir = output.split(':', 1)[1].strip().strip('"')
    return openssl_dir + '/cert.pem'


def update_latest_cert_bundle(url_cert_bundle, ca_file, temp_ca_file):
    """Rebuild ``ca_file`` when the OpenSSL or IMDS certificate bundle changed.

    The new bundle is the OpenSSL bundle, a separator, then the IMDS bundle.
    It is written to ``temp_ca_file`` and renamed over ``ca_file`` while the
    mount lock is held, after which the ``oci-fss-0*`` units are restarted.
    The module globals ``prev_openssl_bundle`` and ``prev_imds_bundle`` record
    what was last written and are updated after a successful rebuild.

    Args:
        url_cert_bundle: URL passed to ``get_imds_cert_bundle``.
        ca_file: Path of the combined CA bundle to maintain.
        temp_ca_file: Scratch path the new bundle is written to first.

    Raises:
        subprocess.CalledProcessError: If ``openssl version -d`` fails.
        OSError: If a bundle file cannot be read, written, linked or renamed.
    """
    global prev_openssl_bundle, prev_imds_bundle

    os_cert_bundle_path = _openssl_cert_bundle_path()

    # Symlink to OpenSSL's root cert bundle...because if IMDS isn't there,
    # requests will take 60+ seconds to fail...which is a long time to not have
    # a root ca bundle available.
    # lexists() also reports a dangling symlink, which exists() would miss and
    # which would make os.symlink() raise FileExistsError.
    if not os.path.lexists(ca_file):
        os.symlink(os_cert_bundle_path, ca_file)
        logging.info(f'Symlinked {ca_file} to {os_cert_bundle_path}.')

    # NOTE(review): a failed IMDS fetch yields "", which looks like a change
    # when a bundle was fetched previously - confirm that is intended.
    imds_bundle = get_imds_cert_bundle(url_cert_bundle)

    with open(os_cert_bundle_path, "rb") as fp:
        openssl_bundle = fp.read().decode("utf-8")

    openssl_changed = prev_openssl_bundle != openssl_bundle
    imds_changed = prev_imds_bundle != imds_bundle

    if openssl_changed:
        logging.info(f'Changed detected in {os_cert_bundle_path}.')
    else:
        logging.debug(f'No changes detected in {os_cert_bundle_path}.')

    if imds_changed:
        logging.info(f'Changed detected in {url_cert_bundle}.')
    else:
        logging.debug(f'No changes detected in {url_cert_bundle}.')

    if not (openssl_changed or imds_changed):
        logging.debug('No root cert bundle changes.')
        return

    with open(temp_ca_file, 'w', encoding="utf-8") as tf:
        tf.write(openssl_bundle)
        tf.write("\n\n###-------------------------------###\n\n")
        tf.write(imds_bundle)

    lck = MountLock()
    try:
        lck.acquire()
        # Rename the finished file into place so ca_file is never seen half
        # written.
        os.rename(temp_ca_file, ca_file)
        restart = subp.run(["systemctl", "restart", "oci-fss-0*"])
    finally:
        # Always unlock and close the lock file, even when the rename or the
        # restart raised, so the lock and descriptor are not left behind.
        lck.release()
        lck.fd.close()

    if restart.returncode != 0:
        logging.warning(f'systemctl restart oci-fss-0* exited with status {restart.returncode}.')

    logging.info(f'Updated {ca_file}.')
    prev_openssl_bundle = openssl_bundle
    prev_imds_bundle = imds_bundle


def cert_monitor(sleep_time, ca_file, url_cert_bundle):
    """Refresh the CA bundle forever, pausing ``sleep_time`` seconds per cycle.

    Each cycle calls ``update_latest_cert_bundle`` with ``ca_file + "_temp"``
    as the scratch file and then removes that scratch file if it is still
    present. The function only returns by raising: an exception from the
    update propagates to the caller after the scratch file has been cleaned up.

    Args:
        sleep_time: Seconds to sleep between cycles.
        ca_file: Path of the combined CA bundle to maintain.
        url_cert_bundle: URL of the IMDS certificate bundle.
    """
    # The scratch path never changes, so compute it once.
    temp_ca_file = ca_file + "_temp"

    while True:
        logging.debug('Generating cert bundle...')

        try:
            update_latest_cert_bundle(url_cert_bundle, ca_file, temp_ca_file)
        finally:
            # A successful update renames the scratch file away, so it is only
            # left behind when the update failed part-way. Clean up in a
            # ``finally`` so that case is actually covered.
            try:
                os.remove(temp_ca_file)
            except FileNotFoundError:
                pass

        logging.debug(f'Sleeping for {sleep_time} seconds...')
        time.sleep(sleep_time)

def load_config(config_path="/etc/oci-fss-utils.d/config.json"):
    """Read the monitor settings from the JSON file at *config_path*.

    ``caFile`` and ``urlCertBundle`` are required. ``certMonitorSleepTime`` is
    optional: when it is missing, not an integer, or below 2, it falls back to
    30 seconds.

    Args:
        config_path: Path of the JSON configuration file.

    Returns:
        A ``(sleep_time, ca_file, url_cert_bundle)`` tuple.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid JSON.
        KeyError: If ``caFile`` or ``urlCertBundle`` is missing.
    """
    with open(config_path, "r", encoding="utf-8") as config_file:
        json_data = json.load(config_file)

    ca_file = json_data["caFile"]
    url_cert_bundle = json_data["urlCertBundle"]

    # A missing key yields None, which takes the same fallback as any other
    # non-integer value instead of aborting the whole load.
    sleep_time = json_data.get("certMonitorSleepTime")
    if not isinstance(sleep_time, int):
        logging.warning("Got sleep time as non-integer value, resetting it back to 30.")
        sleep_time = 30
    elif sleep_time < 2:
        logging.warning("Got sleep time less than the lower bound of 2, resetting it back to 30.")
        sleep_time = 30

    return sleep_time, ca_file, url_cert_bundle


if __name__ == "__main__":
    try:
        sleep_time, ca_file, url_cert_bundle = load_config()
    except Exception as error_message:
        # There is no usable default for the CA file path or the bundle URL,
        # so stop here with a clear message instead of failing later with a
        # NameError on the undefined settings.
        logging.exception(error_message)
        logging.error("Unable to load the cert monitor configuration, exiting.")
        sys.exit(1)

    cert_monitor(sleep_time, ca_file, url_cert_bundle)

