#!/usr/bin/python3

# Copyright (c) 2025, Oracle and/or its affiliates.
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
#
# This code is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 2 only, as
# published by the Free Software Foundation.
#
# This code is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
# version 2 for more details (a copy is included in the LICENSE file that
# accompanied this code).
#
# You should have received a copy of the GNU General Public License version
# 2 along with this work; if not, see <https://www.gnu.org/licenses/>.
#
# Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
# or visit www.oracle.com if you need additional information or have any
# questions.

import sys
import os
import drgn
from drgn.helpers.linux.pid import find_task, for_each_task
from drgn.helpers.linux.fs import d_path
from drgn.helpers.linux.mm import for_each_vma
from drgn.helpers.linux.xarray import xa_load, xa_is_value
import argparse
import re
import struct
import time

def signed_to_unsigned_64(value):
    """Map a negative int onto its 64-bit two's-complement unsigned value.

    Non-negative inputs are returned unchanged.  Used to turn ``~mask``
    (negative for Python ints) into a 64-bit bit mask.
    """
    return value + (1 << 64) if value < 0 else value

# Page-table walker class for the target architecture; set by initiailize().
class_page_table = None


def initiailize(prog):
    """Check the target architecture and set up the page-table walker class.

    Exits with status 1 when init_uts_ns cannot be found or the kernel is not
    x86_64, the only architecture implemented.  (The name's spelling is kept
    because callers use it.)
    """
    # Explicit check instead of assert: asserts vanish under `python -O`.
    if 'init_uts_ns' not in prog:
        print("Cannot find init_uts_ns; unable to determine the architecture")
        sys.exit(1)

    machine = prog['init_uts_ns'].name.machine.string_()
    if machine != b'x86_64':
        print("Architecture %s is not supported" % machine.decode(errors="replace"))
        sys.exit(1)

    global class_page_table
    class_page_table = page_table_x86_64
    class_page_table.class_init(prog)

class vm_area_struct:
    """Constants mirroring bits of the kernel's vm_area_struct.vm_flags."""
    # Bit 22 (0x00400000) marks a hugetlb mapping.
    # NOTE(review): confirm the bit position is identical on every supported kernel version.
    VM_HUGETLB = 1 << 22

class page_table_x86_64:
    """Walk x86_64 user page tables through drgn to find swapped-out memory.

    The class attributes mirror arch/x86 page-table definitions.  Attributes
    that depend on the target kernel start as None and are filled in by
    class_init(), which must be called before creating instances.

    An instance is bound to one task's mm.  walk_vma_for_swap() walks one VMA
    and reports each run of swapped-out pages through a callback.

    Only 4-level paging is implemented.  With 5-level paging enabled, the walk
    raises NotImplementedError at the PGD/P4D level.
    """
    # page flags. NOTE: verify that all supported kernel versions use the same values
    _PAGE_PRESENT       = 1 << 0
    _PAGE_RW            = 1 << 1
    _PAGE_USER          = 1 << 2
    _PAGE_PWT           = 1 << 3 # write through
    _PAGE_PCD           = 1 << 4 # page cache disabled
    _PAGE_ACCESSED      = 1 << 5
    _PAGE_DIRTY         = 1 << 6
    _PAGE_PSE           = 1 << 7 # big page
    _PAGE_GLOBAL        = 1 << 8
    _PAGE_SOFTW1        = 1 << 9
    _PAGE_SOFTW2        = 1 << 10
    _PAGE_SOFTW3        = 1 << 11
    _PAGE_PAT_LARGE     = 1 << 12 # 2M or 1G page
    _PAGE_SOFTW4        = 1 << 58 # also _PAGE_BIT_DEVMAP
    _PAGE_PKEY_BIT0     = 1 << 59
    _PAGE_PKEY_BIT1     = 1 << 60
    _PAGE_PKEY_BIT2     = 1 << 61
    _PAGE_PKEY_BIT3     = 1 << 62
    _PAGE_NX            = 1 << 63

    _PAGE_PROTNONE      = _PAGE_GLOBAL
    _PAGE_SPECIAL       = _PAGE_SOFTW1
    _PAGE_CPA_TEST      = _PAGE_SOFTW1 # not a typo
    _PAGE_SOFT_DIRTY    = _PAGE_SOFTW3 # dirty tracking
    _PAGE_DEVMAP        = _PAGE_SOFTW4

    # An entry with either bit set is treated as present, i.e. not swapped out.
    PRESENT_OR_PROTNONE = _PAGE_PRESENT | _PAGE_PROTNONE

    _PAGE_SWP_SOFT_DIRTY    = _PAGE_RW

    # Bits ignored by the *_none() hole checks.
    _PAGE_KNL_ERRATUM_MASK = _PAGE_DIRTY | _PAGE_ACCESSED
    MASK_PAGE_KNL_ERRATUM_MASK = signed_to_unsigned_64(~_PAGE_KNL_ERRATUM_MASK)

    _PAGE_ENC = None        # set by class_init() from sme_me_mask
    _KERNPG_TABLE = None    # set by class_init()

    # Swap-entry layout: the swap type occupies the top SWP_TYPE_BITS bits.
    SWP_TYPE_BITS   = 5
    _PAGE_BIT_PROTNONE = 8
    SWP_OFFSET_FIRST_BIT = _PAGE_BIT_PROTNONE + 1
    SWP_OFFSET_SHIFT      =  SWP_OFFSET_FIRST_BIT + SWP_TYPE_BITS
    SWAP_TYPE_SHIFT = 64 - SWP_TYPE_BITS

    # Swap types at or above MAX_SWAPFILES are special (device/migration/hwpoison) entries.
    # NOTE(review): these counts differ between kernel versions/configs; confirm for the target.
    MAX_SWAPFILES_SHIFT = 5
    SWP_DEVICE_NUM = 2
    SWP_MIGRATION_NUM = 2
    SWP_HWPOISON_NUM = 1
    MAX_SWAPFILES = (1 << MAX_SWAPFILES_SHIFT) - SWP_DEVICE_NUM - SWP_MIGRATION_NUM - SWP_HWPOISON_NUM

    PAGE_SHIFT = None
    PAGE_SIZE = None
    __PHYSICAL_MASK = None

    PGDIR_SHIFT = None
    PGDIR_SIZE = None
    IS_LEVEL5 = None
    PTRS_PER_PGD = 512

    P4D_SHIFT = 39
    P4D_SIZE = 1 << P4D_SHIFT
    P4D_MASK = signed_to_unsigned_64(~(P4D_SIZE -1))
    PTRS_PER_P4D = None

    PUD_SHIFT = 30
    PUD_SIZE = 1 << PUD_SHIFT
    PUD_MASK = signed_to_unsigned_64(~(PUD_SIZE - 1))
    PTRS_PER_PUD = 512
    PHYSICAL_PUD_PAGE_MASK = None

    PMD_SHIFT = 21
    PMD_SIZE = 1 << PMD_SHIFT
    PMD_MASK = signed_to_unsigned_64(~(PMD_SIZE - 1))
    PTRS_PER_PMD = 512

    PAGE_OFFSET = None
    PTE_PFN_MASK = None
    PTE_FLAGS_MASK = None
    PTRS_PER_PTE = 512

    # Walk every VMA even if /proc/<pid>/smaps shows no swap for it (not referenced in this class).
    force_walk_vma = False

    @classmethod
    def class_init(cls, drgn_prog):
        """Fill in the kernel-dependent class attributes from *drgn_prog*.

        Reads __pgtable_l5_enabled, sme_me_mask, PAGE_SHIFT, PAGE_SIZE,
        ptrs_per_p4d, physical_mask and page_offset_base from the target.
        It also binds the entry and buffer readers to drgn_prog's memory
        accessors.
        """
        if drgn_prog['__pgtable_l5_enabled']:
            cls.PGDIR_SHIFT = 48
            cls.IS_LEVEL5 = True
        else:
            cls.PGDIR_SHIFT = 39
            cls.IS_LEVEL5 = False

        cls._PAGE_ENC = drgn_prog['sme_me_mask'].value_()
        cls._KERNPG_TABLE = cls._PAGE_PRESENT | cls._PAGE_RW | cls._PAGE_ACCESSED | cls._PAGE_DIRTY | cls._PAGE_ENC
        cls.PAGE_SHIFT = drgn_prog['PAGE_SHIFT'].value_()
        cls.PAGE_SIZE = drgn_prog['PAGE_SIZE'].value_()
        cls.PAGE_MASK = signed_to_unsigned_64(~(cls.PAGE_SIZE -1))

        # Every entry type must be 8 bytes: the readers below are read_u64 and
        # walk_pte_range_for_swap() unpacks PTEs as little-endian u64 values.
        cls.PGD_T_SIZE = drgn.sizeof(drgn_prog.type("pgd_t"))
        assert(cls.PGD_T_SIZE == 8)
        cls.read_pgd_t = drgn_prog.read_u64
        cls.P4D_T_SIZE = drgn.sizeof(drgn_prog.type("p4d_t"))
        assert(cls.P4D_T_SIZE == 8)
        cls.read_p4d_t = drgn_prog.read_u64
        cls.PUD_T_SIZE = drgn.sizeof(drgn_prog.type("pud_t"))
        assert(cls.PUD_T_SIZE == 8)
        cls.read_pud_t = drgn_prog.read_u64
        cls.PMD_T_SIZE = drgn.sizeof(drgn_prog.type("pmd_t"))
        assert(cls.PMD_T_SIZE == 8)
        cls.read_pmd_t = drgn_prog.read_u64
        cls.PTE_T_SIZE = drgn.sizeof(drgn_prog.type("pte_t"))
        assert(cls.PTE_T_SIZE == 8)
        cls.read_pte_t = drgn_prog.read_u64
        # Bulk reader for whole PTE tables (avoids depending on a global `prog`).
        cls.read_bytes = drgn_prog.read

        cls.PGDIR_SIZE = 1 << cls.PGDIR_SHIFT
        cls.PGDIR_MASK = signed_to_unsigned_64(~(cls.PGDIR_SIZE - 1))

        cls.PTRS_PER_P4D = drgn_prog['ptrs_per_p4d'].value_()

        cls.__PHYSICAL_MASK = drgn_prog['physical_mask'].value_()
        cls.PHYSICAL_PUD_PAGE_MASK = cls.PUD_MASK & cls.__PHYSICAL_MASK
        cls.PHYSICAL_PMD_PAGE_MASK = cls.PMD_MASK & cls.__PHYSICAL_MASK
        cls.PHYSICAL_PAGE_MASK = cls.PAGE_MASK & cls.__PHYSICAL_MASK
        cls.PTE_PFN_MASK = cls.PHYSICAL_PAGE_MASK
        cls.PTE_FLAGS_MASK = signed_to_unsigned_64(~cls.PTE_PFN_MASK)
        cls.PAGE_OFFSET = drgn_prog['page_offset_base'].value_()

        # Huge PMD/PUD entries use the same swap test as PTEs.
        cls.is_pmd_swap = cls.is_pte_swap
        cls.is_pud_swap = cls.is_pte_swap

    def __init__(self, task):
        """Bind the walker to *task*'s mm.

        A task whose mm is NULL gets an inert walker: walk_vma_for_swap()
        returns immediately for it.
        """
        self.mm = task.mm
        if self.mm.value_() == 0:
            return

        self.pgd_addr = self.mm.pgd.value_()
        self.vma = None
        self.swap_cb = None
        self.swap_cb_param = None

    def va(self, address):
        """Return the kernel virtual address of physical *address* (address + PAGE_OFFSET)."""
        return address + self.PAGE_OFFSET

    def general_next_boundary_or_end(self, addr, area_size, area_mask, end):
        """Return the next *area_size*-aligned boundary after *addr*, or *end* if that comes first."""
        boundary = (addr + area_size) & area_mask
        if boundary < end:
            return boundary
        return end

    # ---- PGD ----

    def _require_4_level(self):
        """Raise NotImplementedError when the target uses 5-level paging."""
        if self.IS_LEVEL5:
            raise NotImplementedError("5-level page tables are not supported yet")

    def pgd_none(self, pgd_value):
        """Return whether a PGD entry is a hole.

        With 4-level paging the P4D level is folded into the PGD
        (p4d_offset_address() returns the PGD slot itself), so the real
        check is done by p4d_none() on the same entry.
        """
        self._require_4_level()
        return False

    def pgd_bad(self, pgd_addr):
        """Return whether a PGD entry has unexpected flags; never true with 4 levels."""
        self._require_4_level()
        return False

    def pgd_index(self, address):
        """Return the PGD slot index covering user *address*."""
        return (address >> self.PGDIR_SHIFT) & (self.PTRS_PER_PGD - 1)

    def pgd_offset_address(self, address):
        """Return the kernel address of the PGD entry covering *address*."""
        return self.pgd_addr + self.pgd_index(address) * self.PGD_T_SIZE

    def pgd_next_boundary_or_end(self, addr, end):
        """Return the next PGD boundary after *addr*, or *end* if that comes first."""
        return self.general_next_boundary_or_end(addr, self.PGDIR_SIZE, self.PGDIR_MASK, end)

    def walk_pgd_range_for_swap(self, addr, end):
        """Walk PGD entries covering [addr, end) and descend into populated ones.

        addr -- start of the user address range, in bytes
        end  -- end of the range (exclusive), in bytes
        """
        pgd_addr = self.pgd_offset_address(addr)
        while addr != end:
            next_addr = self.pgd_next_boundary_or_end(addr, end)
            pgd_value = self.read_pgd_t(pgd_addr)
            if not (self.pgd_none(pgd_value) or self.pgd_bad(pgd_value)):
                self.walk_p4d_range_for_swap(pgd_addr, addr, next_addr)
            addr = next_addr
            pgd_addr += self.PGD_T_SIZE

    # ---- P4D ----

    def p4d_pfn_mask(self, p4d_value):
        """Return the mask selecting the physical-address bits of a P4D entry."""
        return self.PTE_PFN_MASK

    def p4d_flags_mask(self, p4d_value):
        """Return the mask selecting the flag bits of a P4D entry."""
        return signed_to_unsigned_64(~self.p4d_pfn_mask(p4d_value))

    def p4d_flags(self, p4d_value):
        """Return only the flag bits of a P4D entry."""
        return p4d_value & self.p4d_flags_mask(p4d_value)

    def p4d_none(self, p4d_value):
        """Return whether a P4D entry is a hole (zero apart from accessed/dirty)."""
        return p4d_value & self.MASK_PAGE_KNL_ERRATUM_MASK == 0

    def p4d_bad(self, p4d_value):
        """Return whether a P4D entry carries flags beyond table/user/NX bits."""
        ignore_flags = self._KERNPG_TABLE | self._PAGE_USER | self._PAGE_NX
        return self.p4d_flags(p4d_value) & signed_to_unsigned_64(~ignore_flags) != 0

    def p4d_offset_address(self, pgd_addr, addr):
        """Return the kernel address of the P4D entry for *addr*.

        With 4-level paging the P4D is folded into the PGD, so this is the
        PGD entry address itself.
        """
        self._require_4_level()
        return pgd_addr

    def p4d_next_boundary_or_end(self, addr, end):
        """Return the next P4D boundary after *addr*, or *end* if that comes first."""
        return self.general_next_boundary_or_end(addr, self.P4D_SIZE, self.P4D_MASK, end)

    def walk_p4d_range_for_swap(self, pgd_addr, addr, end):
        """Walk P4D entries covering [addr, end) and descend into populated ones.

        pgd_addr -- kernel address of the PGD entry holding the P4D level
        addr     -- start of the user address range, in bytes
        end      -- end of the range (exclusive), in bytes
        """
        p4d_addr = self.p4d_offset_address(pgd_addr, addr)
        while addr != end:
            next_addr = self.p4d_next_boundary_or_end(addr, end)
            p4d_value = self.read_p4d_t(p4d_addr)
            if not (self.p4d_none(p4d_value) or self.p4d_bad(p4d_value)):
                self.walk_pud_range_for_swap(p4d_value, addr, next_addr)
            addr = next_addr
            p4d_addr += self.P4D_T_SIZE

    def p4d_page_vaddr(self, p4d_value):
        """Return the kernel virtual address of the PUD table a P4D entry points to."""
        return self.va(p4d_value & self.p4d_pfn_mask(p4d_value))

    # ---- PUD ----

    def pud_flags(self, pud_value):
        """Return only the flag bits of a PUD entry."""
        return pud_value & signed_to_unsigned_64(~(self.pud_pfn_mask(pud_value)))

    def pud_index(self, address):
        """Return the PUD slot index covering user *address*."""
        return (address >> self.PUD_SHIFT) & (self.PTRS_PER_PUD - 1)

    def pud_none(self, pud_value):
        """Return whether a PUD entry is a hole (zero apart from accessed/dirty)."""
        return pud_value & self.MASK_PAGE_KNL_ERRATUM_MASK == 0

    def pud_offset_address(self, p4d_value, address):
        """Return the kernel address of the PUD entry for *address* under *p4d_value*."""
        return self.p4d_page_vaddr(p4d_value) + self.pud_index(address) * self.PUD_T_SIZE

    def pud_next_boundary_or_end(self, addr, end):
        """Return the next PUD boundary after *addr*, or *end* if that comes first."""
        return self.general_next_boundary_or_end(addr, self.PUD_SIZE, self.PUD_MASK, end)

    def pud_pfn_mask(self, pud_value):
        """Return the physical-address mask for a PUD entry (1G-aligned if huge)."""
        if pud_value & self._PAGE_PSE:
            return self.PHYSICAL_PUD_PAGE_MASK
        return self.PTE_PFN_MASK

    def pud_bad(self, pud_value):
        """Return whether a PUD entry carries flags beyond table/user bits."""
        return self.pud_flags(pud_value) & signed_to_unsigned_64(~(self._KERNPG_TABLE | self._PAGE_USER)) != 0

    def walk_pud_range_for_swap(self, p4d_value, addr, end):
        """Walk PUD entries covering [addr, end) and report swapped-out memory.

        p4d_value -- P4D entry whose PUD table is walked
        addr      -- start of the user address range, in bytes
        end       -- end of the range (exclusive), in bytes
        """
        pud_addr = self.pud_offset_address(p4d_value, addr)
        while addr != end:
            next_addr = self.pud_next_boundary_or_end(addr, end)
            pud_value = self.read_pud_t(pud_addr)

            if not (self.pud_none(pud_value) or self.pud_bad(pud_value)):
                if pud_value & self._PAGE_PSE:
                    # The entry maps [addr, next_addr) directly.
                    # NOTE(review): pud_bad() counts _PAGE_PSE as an unexpected
                    # flag, so huge entries are filtered out above and this
                    # branch is not reached today.
                    if self.is_pud_swap(pud_value, addr):
                        self.swap_cb(addr, next_addr - addr, self.swap_cb_param)
                else:
                    self.walk_pmd_range_for_swap(pud_value, addr, next_addr)
            addr = next_addr
            pud_addr += self.PUD_T_SIZE

    def pud_page_vaddr(self, pud_value):
        """Return the kernel virtual address of the PMD table a PUD entry points to."""
        return self.va(pud_value & self.pud_pfn_mask(pud_value))

    # ---- PMD ----

    def pmd_none(self, pmd_value):
        """Return whether a PMD entry is a hole (zero apart from accessed/dirty)."""
        return pmd_value & self.MASK_PAGE_KNL_ERRATUM_MASK == 0

    def pmd_pfn_mask(self, pmd_value):
        """Return the physical-address mask for a PMD entry (2M-aligned if huge)."""
        if pmd_value & self._PAGE_PSE:
            return self.PHYSICAL_PMD_PAGE_MASK
        return self.PTE_PFN_MASK

    def pmd_flags_mask(self, pmd_value):
        """Return the mask selecting the flag bits of a PMD entry."""
        return signed_to_unsigned_64(~self.pmd_pfn_mask(pmd_value))

    def pmd_flags(self, pmd_value):
        """Return only the flag bits of a PMD entry."""
        return pmd_value & self.pmd_flags_mask(pmd_value)

    def pmd_bad(self, pmd_value):
        """Return whether a PMD entry's flags (ignoring _PAGE_USER) differ from _KERNPG_TABLE.

        Any extra bit, including _PAGE_PSE, makes the entry "bad".
        """
        return self.pmd_flags(pmd_value) & signed_to_unsigned_64(~self._PAGE_USER) != self._KERNPG_TABLE

    def pmd_index(self, address):
        """Return the PMD slot index covering user *address*."""
        return (address >> self.PMD_SHIFT) & (self.PTRS_PER_PMD - 1)

    def pmd_offset_address(self, pud_value, address):
        """Return the kernel address of the PMD entry for *address* under *pud_value*."""
        return self.pud_page_vaddr(pud_value) + self.pmd_index(address) * self.PMD_T_SIZE

    # similar kernel function: pmd_addr_end
    def pmd_next_boundary_or_end(self, addr, end):
        """Return the next PMD boundary after *addr*, or *end* if that comes first."""
        return self.general_next_boundary_or_end(addr, self.PMD_SIZE, self.PMD_MASK, end)

    def walk_pmd_range_for_swap(self, pud_addr, addr, end):
        """Walk PMD entries covering [addr, end) and report swapped-out memory.

        pud_addr -- despite the name, the PUD entry *value* whose PMD table is walked
        addr     -- start of the user address range, in bytes
        end      -- end of the range (exclusive), in bytes
        """
        pmd_addr = self.pmd_offset_address(pud_addr, addr)
        while addr != end:
            next_addr = self.pmd_next_boundary_or_end(addr, end)
            pmd_value = self.read_pmd_t(pmd_addr)

            if not (self.pmd_none(pmd_value) or self.pmd_bad(pmd_value)):
                if pmd_value & self._PAGE_PSE:
                    # The entry maps [addr, next_addr) directly.
                    # NOTE(review): pmd_bad() rejects _PAGE_PSE entries, so
                    # this branch is not reached today.
                    if self.is_pmd_swap(pmd_value, addr):
                        self.swap_cb(addr, next_addr - addr, self.swap_cb_param)
                else:
                    self.walk_pte_range_for_swap(pmd_value, addr, next_addr)
            addr = next_addr
            pmd_addr += self.PMD_T_SIZE

    def pmd_page_vaddr(self, pmd_value):
        """Return the kernel virtual address of the PTE table a PMD entry points to."""
        return self.va(pmd_value & self.pmd_pfn_mask(pmd_value))

    # ---- PTE ----

    def pte_none(self, pte_value):
        """Return whether a PTE is a hole (zero apart from accessed/dirty)."""
        return pte_value & self.MASK_PAGE_KNL_ERRATUM_MASK == 0

    def pte_index(self, addr):
        """Return the PTE slot index covering user *addr*."""
        return (addr >> self.PAGE_SHIFT) & (self.PTRS_PER_PTE - 1)

    def pte_offset_address(self, pmd_value, address):
        """Return the kernel address of the PTE for *address* under *pmd_value*."""
        return self.pmd_page_vaddr(pmd_value) + self.pte_index(address) * self.PTE_T_SIZE

    def linear_page_index_nohuge(self, addr):
        """Return the file page index of *addr* in the current VMA (non-huge VMAs only)."""
        idx = (addr - self.vma.vm_start) >> self.PAGE_SHIFT
        return idx + self.vma.vm_pgoff

    def is_pte_swap(self, pte_value, addr):
        """Return whether the entry *pte_value* for user address *addr* is swapped out.

        - present/protnone entries are not swapped;
        - other non-none entries are swap entries unless their type is one of
          the special types at or above MAX_SWAPFILES (non_swap_entry());
        - none entries are checked against the VMA's file mapping
          (see is_pte_swap_for_shared()).
        """
        if pte_value & self.PRESENT_OR_PROTNONE:
            return False

        if pte_value & self.MASK_PAGE_KNL_ERRATUM_MASK:
            return (pte_value >> self.SWAP_TYPE_SHIFT) < self.MAX_SWAPFILES

        return self.is_pte_swap_for_shared(pte_value, addr)

    def is_pte_swap_for_shared(self, pte_value, addr):
        """Return whether a none PTE at *addr* is backed by a swapped-out page in the mapping.

        Looks up the page index of *addr* in the VMA's f_mapping->i_pages and
        reports True when the slot holds an xarray value entry.

        NOTE(review): any value entry is counted.  For regular (non-shmem)
        files the page cache can also hold value entries (e.g. workingset
        shadow entries), which would be reported as swap here.  The kernel's
        smaps does this lookup only for shmem mappings; confirm.
        """
        if self.vma.vm_file.value_() == 0 or self.vma.vm_file.f_mapping.value_() == 0:
            return False

        page_idx = self.linear_page_index_nohuge(addr)
        page = xa_load(self.vma.vm_file.f_mapping.i_pages.address_of_(), page_idx)
        if page.value_() == 0:
            return False

        return xa_is_value(page)

    def walk_pte_range_for_swap(self, pmd_value, addr, end):
        """Report swapped-out pages in [addr, end), all covered by one PTE table.

        Consecutive swapped-out pages are coalesced, so swap_cb is called once
        per run as swap_cb(first_addr, run_bytes, swap_cb_param).
        """
        page_size = self.PAGE_SIZE
        present_mask = self.PRESENT_OR_PROTNONE
        not_none_mask = self.MASK_PAGE_KNL_ERRATUM_MASK
        type_shift = self.SWAP_TYPE_SHIFT
        max_swapfiles = self.MAX_SWAPFILES

        nr_pages = (end - addr) // page_size
        pte_addr = self.pte_offset_address(pmd_value, addr)
        # One bulk read plus one unpack is much faster than a drgn read per PTE.
        buf = self.read_bytes(pte_addr, nr_pages * self.PTE_T_SIZE)
        pte_values = struct.unpack('<%dQ' % nr_pages, buf)

        # None PTEs can only be swapped out when the VMA has a file mapping;
        # test that once rather than re-reading vm_file for every none PTE.
        vm_file = self.vma.vm_file
        has_mapping = vm_file.value_() != 0 and vm_file.f_mapping.value_() != 0

        nr_batch = 0
        first_addr = 0
        for pte_value in pte_values:
            # Same decisions as is_pte_swap(), inlined for speed.
            if pte_value & present_mask:
                is_swap = False
            elif pte_value & not_none_mask:
                is_swap = (pte_value >> type_shift) < max_swapfiles
            else:
                is_swap = has_mapping and self.is_pte_swap_for_shared(pte_value, addr)

            if is_swap:
                if nr_batch == 0:
                    first_addr = addr
                nr_batch += 1
            elif nr_batch:
                self.swap_cb(first_addr, page_size * nr_batch, self.swap_cb_param)
                nr_batch = 0
            addr += page_size
        if nr_batch:
            self.swap_cb(first_addr, page_size * nr_batch, self.swap_cb_param)

    def walk_vma_for_swap(self, vma, cb, cb_param):
        """Report every swapped-out range of *vma* through *cb*.

        vma      -- vm_area_struct object belonging to this walker's mm
        cb       -- called as cb(addr, size_in_bytes, cb_param) for each run of
                    swapped-out memory, in increasing address order
        cb_param -- passed through unchanged to cb

        Does nothing for tasks without an mm or for VM_HUGETLB VMAs.
        """
        if self.mm.value_() == 0:
            return

        self.vma = vma
        self.swap_cb = cb
        self.swap_cb_param = cb_param

        # hugetlb pages can't be swapped out
        if self.vma.vm_flags & vm_area_struct.VM_HUGETLB:
            return

        self.walk_pgd_range_for_swap(vma.vm_start.value_(), vma.vm_end.value_())

def print_process_swap_header():
    """Print the column header shared by the per-process swap listings."""
    columns = ("PID", "SIZE(K)", "   COMM")
    print("%12s %18s %s" % columns)

def print_process_swap(task, swapsize):
    """Print one "PID SIZE(K) COMM" row for *task*.

    swapsize -- size in bytes of swapped-out memory.  In diff mode it is
                negative for memory that was swapped back in.
    """
    # Decode leniently: undecodable bytes become U+FFFD instead of falling
    # back to printing a b'...' repr.
    comm = task.comm.string_().decode(errors="replace")
    print("%12d %18d    %s" % (task.pid, swapsize // 1024, comm))


class PROCESS_SWAP_INFO:
    """Swapped-out address ranges of one process, plus a per-round registry.

    Used by diff mode: every process is scanned twice (ROUND1, then ROUND2)
    and the two sets of ranges are compared to find memory swapped in or out
    in between.  The class attributes form a registry shared by all instances.
    """
    ROUND1 = 0
    ROUND2 = 1
    # Per-round dictionaries mapping pid -> PROCESS_SWAP_INFO instance.
    rounds = [{},{}]
    # Index into `rounds` that add_process() fills; set it to ROUND2 after the
    # first scan and before starting the second.
    cur_round = ROUND1
    # Every pid seen in either round (used as an insertion-ordered set).
    pid_d = {}

    @classmethod
    def add_process(cls, task):
        """Create and register the current round's entry for *task*; return it."""
        pid = task.pid.value_()
        cls.pid_d[pid] = pid
        d = cls.rounds[cls.cur_round]
        obj_p_info = PROCESS_SWAP_INFO(task)
        d[pid] = obj_p_info
        return obj_p_info

    def __init__(self, task):
        # start_time together with pid identifies the process across rounds.
        self.start_time = task.start_time.value_()
        self.task = task
        # ranges -- contiguous, address-ordered list starting at address 0.
        # Each element is (start_addr: int, length_in_bytes: int, swapped_out: bool).
        self.ranges = []
        # next_byte -- the byte just past the last recorded swapped-out range.
        self.next_byte = 0

    def set_swapped_range(self, addr:int, bytes:int):
        """Append the swapped-out range [addr, addr + bytes).

        A not-swapped hole is inserted first when the range does not start
        where the previous swapped-out range ended.  *addr* is expected to be
        >= next_byte.
        """
        if self.next_byte != addr:
            self.ranges.append((self.next_byte, addr - self.next_byte, False))
        self.ranges.append((addr, bytes, True))
        self.next_byte = addr + bytes

    @staticmethod
    def _swapped_bytes_from(ranges, index, used):
        """Sum swapped-out bytes in ranges[index:], skipping *used* bytes of ranges[index]."""
        total = 0
        for _, length, swapped_out in ranges[index:]:
            if swapped_out:
                total += length - used
            used = 0
        return total

    @classmethod
    def count_newly_swapped_in_and_out(cls, old, new):
        """Compare two scans of one process; return (swapped_in, swapped_out) in bytes.

        old -- the ROUND1 instance of PROCESS_SWAP_INFO
        new -- the ROUND2 instance of PROCESS_SWAP_INFO

        Both range lists start at address 0 and are contiguous, so they are
        merged in one pass.  Each step consumes the overlap of the current
        old and new ranges.  Neither list is modified.
        """
        new_ranges = new.ranges
        old_ranges = old.ranges
        diff_swapped_in = 0
        diff_swapped_out = 0

        ni = oi = 0               # current index in new_ranges / old_ranges
        new_used = old_used = 0   # bytes already consumed from those ranges
        while ni < len(new_ranges) and oi < len(old_ranges):
            new_start, new_len, new_swapped_out = new_ranges[ni]
            old_start, old_len, old_swapped_out = old_ranges[oi]
            assert new_start + new_used == old_start + old_used
            step = min(new_len - new_used, old_len - old_used)

            if new_swapped_out != old_swapped_out:
                if new_swapped_out:
                    diff_swapped_out += step
                else:
                    diff_swapped_in += step

            new_used += step
            old_used += step
            if new_used == new_len:
                ni += 1
                new_used = 0
            if old_used == old_len:
                oi += 1
                old_used = 0

        # Leftover new ranges: their swapped-out parts are newly swapped out.
        diff_swapped_out += cls._swapped_bytes_from(new_ranges, ni, new_used)
        # Leftover old ranges (e.g. address space shrank): treat their
        # swapped-out parts as swapped in by round 2.
        diff_swapped_in += cls._swapped_bytes_from(old_ranges, oi, old_used)

        return (diff_swapped_in, diff_swapped_out)

    @classmethod
    def report_process_swap_diff(cls, pid):
        """Print how much of *pid*'s memory was newly swapped in (negative) and out.

        Nothing is printed unless the same process (same pid and start_time)
        was seen in both rounds.
        """
        # we only have round 2
        if pid not in cls.rounds[cls.ROUND1]:
            return

        # only have round 1
        if pid not in cls.rounds[cls.ROUND2]:
            return

        old = cls.rounds[cls.ROUND1][pid]
        new = cls.rounds[cls.ROUND2][pid]

        # Rare: the original process exited and a new one reused the pid.
        if old.start_time != new.start_time:
            return

        (newly_swapped_in, newly_swapped_out) = cls.count_newly_swapped_in_and_out(old, new)
        if newly_swapped_in:
            print_process_swap(old.task, -newly_swapped_in)
        if newly_swapped_out:
            print_process_swap(old.task, newly_swapped_out)

    @classmethod
    def report_processes_swap_diff(cls):
        """Print the header, then the swap diff of every pid seen in either round."""
        print_process_swap_header()
        for pid in cls.pid_d:
            cls.report_process_swap_diff(pid)
    # end of class PROCESS_SWAP_INFO

def get_swapped_ranges_on_live(task):
    """Return {vma_start: vma_end} for the task's VMAs whose smaps "Swap:" is non-zero.

    Live systems only: parses /proc/<pid>/smaps.  If the process disappears
    before or while the file is read, an empty dict is returned instead of
    raising.
    """
    proc_file_path = "/proc/%d/smaps" % task.pid
    pattern_addr = re.compile(r'^([0-9A-Fa-f]+)-([0-9A-Fa-f]+)') # process space address range
    pattern_swap = re.compile(r'^Swap:\s+(\d+)')
    ret_d = {}
    try:
        with open(proc_file_path, 'r') as file:
            lines = file.read().splitlines()
    except (FileNotFoundError, ProcessLookupError):
        # The task exited after it was enumerated.
        return ret_d

    # Parse line by line so each "Swap:" line is attributed to the VMA header
    # that precedes it (single pass; no re-scanning of the remaining text).
    current = None
    for line in lines:
        match = pattern_addr.match(line)
        if match:
            current = (int(match.group(1), 16), int(match.group(2), 16))
            continue
        match = pattern_swap.match(line)
        if match and current is not None:
            if int(match.group(1)) > 0:
                ret_d[current[0]] = current[1]
            current = None
    return ret_d

def list_swapped_processes():
    """Print every process with swapped-out memory, largest swap first.

    The size comes from the mm's MM_SWAPENTS rss counter (a page count), so
    no page tables are walked.
    """
    # Index of MM_SWAPENTS in mm->rss_stat.count[].
    # NOTE(review): the enum value and the rss_stat layout are kernel-version
    # dependent (newer kernels use percpu counters); confirm for the target.
    MM_SWAPENTS = 2
    swapped = []
    seen_mm = set()

    for task in for_each_task():
        mm = task.mm.value_()
        # Skip tasks without an mm and tasks whose mm was already counted:
        # for_each_task() yields every thread, and threads share one mm
        # (hence one counter), which would otherwise be listed repeatedly.
        if mm == 0 or mm in seen_mm:
            continue
        seen_mm.add(mm)

        # number of pages swapped out
        swap_ents = task.mm.rss_stat.count[MM_SWAPENTS].counter.value_()
        # NOTE(review): with SPLIT_RSS_COUNTING the raw counter can be
        # transiently negative and the kernel clamps it to 0; do the same.
        if swap_ents <= 0:
            continue
        swapped.append((task, swap_ents))

    swapped.sort(key=lambda entry: entry[1], reverse=True)
    print_process_swap_header()
    for task, swap_ents in swapped:
        print_process_swap(task, swap_ents * class_page_table.PAGE_SIZE)

def better_unit(bytes):
    """Format a byte count as whole KB, or as MB/GB with two decimals."""
    kilobytes = bytes // 1024
    if kilobytes >= 1024:
        megabytes = kilobytes / 1024.0
        if megabytes >= 1024:
            return "%.2f GB" % (megabytes / 1024.0)
        return "%.2f MB" % megabytes
    return "%d KB" % kilobytes

def show_task_swap_info(task, save = False):
    """Find the swapped-out memory of *task*, then print or record it.

    task -- task_struct object
    save -- diff mode: record the swapped-out ranges in a PROCESS_SWAP_INFO
            entry for the current round instead of printing them.

    Without *save*, prints one "ADDRESS SIZE (K)" row per contiguous
    swapped-out range, followed by the total.
    """
    is_live = bool(prog.flags & drgn.ProgramFlags.IS_LIVE)

    # Page-table walk callback: [addr, addr + swap_size) is swapped out.
    # Adjacent runs are merged in d.  A merged range is flushed (saved or
    # printed) as soon as a non-adjacent run arrives.
    def add_swap_size(addr, swap_size, d):
        d['swap_size'] += swap_size
        if d['first_swap_addr'] is None:
            d['first_swap_addr'] = addr
            d['last_swap_addr'] = addr + swap_size
            return

        if d['last_swap_addr'] != addr:
            first = d['first_swap_addr']
            length = d['last_swap_addr'] - first
            if save:
                d['obj_p_info'].set_swapped_range(first, length)
            else:
                print("%20s %20d" % ("0x%x" % first, length // 1024))
            d['first_swap_addr'] = addr

        d['last_swap_addr'] = addr + swap_size

    if not save:
        print("%20s %20s" % ("ADDRESS", "SIZE (K)"))
    if task.mm.value_() == 0:
        if not save:
            print("0KB of this process memory got swapped out")
        return

    pt = class_page_table(task)
    # On a live system /proc/<pid>/smaps tells which VMAs contain any swap;
    # the others are skipped to save time.
    swapped_vmas = get_swapped_ranges_on_live(task) if is_live else {}

    obj_p_info = PROCESS_SWAP_INFO.add_process(task)
    # last_swap_addr: the byte just past the current merged range
    process_d = {"swap_size":0, "first_swap_addr":None, "last_swap_addr":None, 'obj_p_info':obj_p_info}
    for vma in for_each_vma(task.mm):
        if is_live and vma.vm_start.value_() not in swapped_vmas:
            continue
        pt.walk_vma_for_swap(vma, add_swap_size, process_d)
    # A final non-adjacent dummy run (addr -1) flushes the last merged range.
    add_swap_size(-1, 0, process_d)

    if not save:
        print(f"{better_unit(process_d['swap_size'])} of this process memory got swapped out")

help_msg = '''
oled-swapinfo [-c path_to_vmcore] [-s path_to_vmlinux] [-p pid] [-d N]
command lines:
  swapinfo                -- List processes which has swapped-out memory
  swapinfo -p pid         -- Dump detailed swapinfo for the specified process
  swapinfo -d N           -- Dump the swap-in/out that happened in the last N seconds
'''


def help():
    """Print the usage text and terminate with exit status 1."""
    print(help_msg)
    raise SystemExit(1)

def drgn__main__():
    """Entry point when this file runs as a drgn script (``prog`` is a global).

    Modes:
      -p PID  print the swapped-out ranges of one process
      -d N    scan all processes twice, N seconds apart, and print the diff
      (none)  list processes that have swapped-out memory
    """
    initiailize(prog)
    p = argparse.ArgumentParser(prog = "oled-swapinfo", description = "Get system swap information", usage = help_msg)
    p.add_argument("-p", "--pid",
                dest = "pid",
                type = int,
                help = "specify the pid of process to dump swap info for",
                default = 0)

    p.add_argument("-d", "--diff",
                dest = "N",
                type = int,
                help = "scan processes swap info twice with a N-second interval and report the difference",
                default = 0)

    args = p.parse_args()
    if args.pid:
        task = find_task(prog, args.pid)
        if task.value_() == 0:
            print(f"Task {args.pid} not found")
            sys.exit(1)
        show_task_swap_info(task)
        sys.exit(0)

    if args.N:
        def scan_all_processes():
            # Walk each mm once: for_each_task() yields every thread, and
            # threads share their process's mm.  Tasks without an mm have
            # nothing to record.
            seen_mm = set()
            for task in for_each_task():
                mm = task.mm.value_()
                if mm == 0 or mm in seen_mm:
                    continue
                seen_mm.add(mm)
                show_task_swap_info(task, save = True)

        scan_all_processes()

        PROCESS_SWAP_INFO.cur_round = PROCESS_SWAP_INFO.ROUND2
        time.sleep(args.N)

        scan_all_processes()

        PROCESS_SWAP_INFO.report_processes_swap_diff()
        sys.exit(0)

    list_swapped_processes()
    sys.exit(0)

def __main__():
    """Command-line wrapper: check privileges, then exec drgn on the installed script.

    The -p/-d options are forwarded to the drgn-side script.  Any option this
    parser does not know (e.g. -c vmcore, -s vmlinux) is passed to drgn itself.
    """
    if os.geteuid() != 0:
        print("Please run oled-swapinfo as root!")
        sys.exit(1)

    drgn_binary = "/usr/bin/drgn"
    drgn_script = "/usr/libexec/oled-tools/swapinfo"
    parser = argparse.ArgumentParser(prog = "oled-swapinfo", description = "Get system swap information", usage = help_msg)
    parser.add_argument("-p", "--pid", dest="pid", type=int,
                        help="specify the pid of process to dump swap info for",
                        default=0)
    parser.add_argument("-d", "--diff", dest="N", type=int,
                        help="scan processes swap info twice with a N-second interval and report the difference",
                        default=0)

    known, drgn_args = parser.parse_known_args()

    forwarded = []
    for flag, value in (("-p", known.pid), ("-d", known.N)):
        if value:
            forwarded += [flag, "%d" % value]

    os.execv(drgn_binary, ["oled-swapinfo", "-C"] + drgn_args + [drgn_script] + forwarded)

if __name__ == "__main__":
    # A global `prog` means drgn is executing this file (presumably after the
    # wrapper's execv); otherwise act as the command-line wrapper.
    entry_point = drgn__main__ if 'prog' in globals() else __main__
    entry_point()
