#!/usr/bin/env python3

import argparse
import base64
from collections import namedtuple
import datetime
import json
import time
from uuid import uuid4

import alibabacloud_credentials as credentials
import alibabacloud_credentials.client
import alibabacloud_credentials.models
import alibabacloud_ecs20140526 as ecs
import alibabacloud_ecs20140526.client
import alibabacloud_ecs20140526.models
import alibabacloud_tea_openapi as openapi
import alibabacloud_tea_openapi.client
import alibabacloud_tea_openapi.models

# Name prefix given to every temporary snapshot and instance created by
# this script; the cleanup functions check for it before deleting
IPXE_LOG_PREFIX = 'ipxe-log-temp-'

# Tag applied (as both key and value) to temporary snapshots and
# instances, and used to find leftovers from earlier runs
IPXE_LOG_TAG = 'ipxe-log-temp-snapshot'

# Marker stripped from the start of the retrieved log, if present
IPXE_LOG_MAGIC = 'iPXE LOG'

# Size requested for the helper's data disk created from the snapshot
# (presumably GiB -- TODO confirm against the ECS API documentation)
IPXE_LOG_DISK_SIZE = 20

# Disk category used for both the helper's system disk and data disk
IPXE_LOG_DISK_CATEGORY = 'cloud_essd'

# Bundle of the region name and the ECS client constructed for it
Clients = namedtuple('Clients', ['region', 'ecs'])

def all_clients(region):
    """Build the set of API clients for one region

    Credentials come from a default-constructed credentials client.
    Returns a ``Clients`` tuple holding the region name and an ECS
    client configured for that region.
    """
    config = openapi.models.Config(
        credential=credentials.client.Client(),
        region_id=region,
    )
    return Clients(region=region, ecs=ecs.client.Client(config))

def get_log_disk(clients, instance):
    """Return the ID of the instance's log disk

    Describes the disks of the given instance and returns the ID of
    the first disk listed in the response.
    """
    query = ecs.models.DescribeDisksRequest(
        region_id=clients.region,
        instance_id=instance,
    )
    disks = clients.ecs.describe_disks(query).body.disks.disk
    return disks[0].disk_id

def delete_temp_snapshot(clients, snapshot):
    """Delete one temporary snapshot, identified by its snapshot ID

    The deletion request is issued with ``force=True``.
    """
    clients.ecs.delete_snapshot(
        ecs.models.DeleteSnapshotRequest(snapshot_id=snapshot, force=True)
    )

def delete_temp_snapshots(clients):
    """Remove all old temporary snapshots

    Looks up snapshots in the client's region that carry the temporary
    tag and deletes each one in turn.

    Raises:
        RuntimeError: if a tagged snapshot's name does not start with
            the temporary name prefix.  Snapshots listed after the
            offending one are left untouched.
    """
    tag = ecs.models.DescribeSnapshotsRequestTag(
        key=IPXE_LOG_TAG,
        value=IPXE_LOG_TAG,
    )
    req = ecs.models.DescribeSnapshotsRequest(
        region_id=clients.region,
        tag=[tag],
    )
    # NOTE(review): only a single describe call is made, so presumably
    # only the first page of results is processed -- confirm
    rsp = clients.ecs.describe_snapshots(req)
    for snapshot in rsp.body.snapshots.snapshot or []:
        # Safety check before a destructive operation.  This is an
        # explicit raise rather than an assert so that it still runs
        # when Python is started with -O.
        name = snapshot.snapshot_name or ''
        if not name.startswith(IPXE_LOG_PREFIX):
            raise RuntimeError(
                "Refusing to delete snapshot %s: name %r does not start "
                "with %r" % (snapshot.snapshot_id, name, IPXE_LOG_PREFIX)
            )
        delete_temp_snapshot(clients, snapshot.snapshot_id)

def create_temp_snapshot(clients, disk):
    """Snapshot a disk and wait for the snapshot to complete

    The snapshot gets a unique prefixed name, the temporary tag, and a
    retention period of one day.  The snapshot status is polled once
    per second for as long as it remains 'progressing'.

    Returns the snapshot ID.  Raises RuntimeError (carrying the final
    status) if the snapshot ends in any state other than
    'accomplished'.
    """
    label = ecs.models.CreateSnapshotRequestTag(
        key=IPXE_LOG_TAG,
        value=IPXE_LOG_TAG,
    )
    creation = ecs.models.CreateSnapshotRequest(
        disk_id=disk,
        snapshot_name=f'{IPXE_LOG_PREFIX}{uuid4()}',
        retention_days=1,
        tag=[label],
    )
    snapshot = clients.ecs.create_snapshot(creation).body.snapshot_id

    # Poll until the snapshot leaves the in-progress state
    status = 'progressing'
    while status == 'progressing':
        time.sleep(1)
        query = ecs.models.DescribeSnapshotsRequest(
            region_id=clients.region,
            snapshot_ids=json.dumps([snapshot]),
        )
        found = clients.ecs.describe_snapshots(query)
        status = found.body.snapshots.snapshot[0].status

    if status == 'accomplished':
        return snapshot
    raise RuntimeError(status)

def delete_temp_instance(clients, instance, retry=False):
    """Delete a temporary log dumper instance

    The deletion is requested with both ``force`` and ``force_stop``
    set.  If the request fails with a client exception and ``retry``
    is true, it is reissued once per second, indefinitely, until it
    succeeds; otherwise the exception propagates to the caller.
    """
    deleted = False
    while not deleted:
        request = ecs.models.DeleteInstanceRequest(
            instance_id=instance,
            force=True,
            force_stop=True,
        )
        try:
            clients.ecs.delete_instance(request)
        except openapi.exceptions.ClientException:
            # Very recently created instances often cannot be
            # terminated until some undocumented part of the control
            # plane decides that enough time has elapsed
            if not retry:
                raise
            time.sleep(1)
        else:
            deleted = True

def delete_temp_instances(clients):
    """Remove all old temporary log dumper instances

    Looks up instances in the client's region that carry the temporary
    tag and deletes each one in turn (without retrying).

    Raises:
        RuntimeError: if a tagged instance's name does not start with
            the temporary name prefix.  Instances listed after the
            offending one are left untouched.
    """
    tag = ecs.models.DescribeInstancesRequestTag(
        key=IPXE_LOG_TAG,
        value=IPXE_LOG_TAG,
    )
    req = ecs.models.DescribeInstancesRequest(
        region_id=clients.region,
        tag=[tag],
    )
    # NOTE(review): only a single describe call is made, so presumably
    # only the first page of results is processed -- confirm
    rsp = clients.ecs.describe_instances(req)
    for instance in rsp.body.instances.instance or []:
        # Safety check before a destructive operation.  This is an
        # explicit raise rather than an assert so that it still runs
        # when Python is started with -O.
        name = instance.instance_name or ''
        if not name.startswith(IPXE_LOG_PREFIX):
            raise RuntimeError(
                "Refusing to delete instance %s: name %r does not start "
                "with %r" % (instance.instance_id, name, IPXE_LOG_PREFIX)
            )
        delete_temp_instance(clients, instance.instance_id)

def create_temp_instance(clients, reference, snapshot, family, machine):
    """Create temporary log dumper instance

    The helper instance reuses the security groups and vSwitch of the
    ``reference`` instance, boots from the given image ``family`` on
    the given ``machine`` type, and has one data disk created from
    ``snapshot``.  It is given a unique prefixed name and the temporary
    tag, and is scheduled for automatic release one hour from now.

    Returns the ID of the new instance.
    """
    # Copy network placement from the reference instance
    req = ecs.models.DescribeInstancesRequest(
        region_id=clients.region,
        instance_ids=json.dumps([reference]),
    )
    rsp = clients.ecs.describe_instances(req)
    instance = rsp.body.instances.instance[0]
    secgroups = instance.security_group_ids.security_group_id
    vswitch = instance.vpc_attributes.v_switch_id
    name = '%s%s' % (IPXE_LOG_PREFIX, uuid4())
    sysdisk = ecs.models.RunInstancesRequestSystemDisk(
        category=IPXE_LOG_DISK_CATEGORY,
    )
    # Use datetime.timezone.utc rather than the datetime.UTC alias,
    # which exists only on Python 3.11 and later
    now = datetime.datetime.now(datetime.timezone.utc)
    lifetime = datetime.timedelta(hours=1)
    release = (now + lifetime).strftime('%Y-%m-%dT%H:%M:%SZ')
    # Data disk populated from the log snapshot, removed along with
    # the instance
    datadisk = ecs.models.RunInstancesRequestDataDisk(
        delete_with_instance=True,
        snapshot_id=snapshot,
        disk_name='ipxelog',
        size=IPXE_LOG_DISK_SIZE,
        category=IPXE_LOG_DISK_CATEGORY,
    )
    tag = ecs.models.RunInstancesRequestTag(
        key=IPXE_LOG_TAG,
        value=IPXE_LOG_TAG,
    )
    req = ecs.models.RunInstancesRequest(
        region_id=clients.region,
        image_family=family,
        instance_type=machine,
        instance_name=name,
        auto_release_time=release,
        system_disk=sysdisk,
        data_disk=[datadisk],
        security_group_ids=secgroups,
        v_switch_id=vswitch,
        tag=[tag],
    )
    rsp = clients.ecs.run_instances(req)
    return rsp.body.instance_id_sets.instance_id_set[0]

def run_command(clients, instance, command):
    """Run a shell script on an instance and wait for it to finish

    Submits ``command`` as a 'RunShellScript' invocation, then polls
    the invocation results once per second until the record status is
    neither 'Pending' nor 'Running'.

    Returns the first invocation result object, whatever its final
    status.
    """
    submission = ecs.models.RunCommandRequest(
        region_id=clients.region,
        instance_id=[instance],
        type='RunShellScript',
        command_content=command,
    )
    invocation = clients.ecs.run_command(submission).body.invoke_id

    unfinished = ('Pending', 'Running')
    while True:
        time.sleep(1)
        query = ecs.models.DescribeInvocationResultsRequest(
            region_id=clients.region,
            invoke_id=invocation,
        )
        reply = clients.ecs.describe_invocation_results(query)
        result = reply.body.invocation.invocation_results.invocation_result[0]
        if result.invoke_record_status in unfinished:
            continue
        return result

def get_log_output(clients, instance):
    """Get iPXE log output

    Runs a command on the helper instance that reads the log partition
    with NUL bytes removed.  The base64-encoded command output is
    decoded and accumulated; while the result reports dropped output,
    the command is reissued starting after the bytes already
    collected.

    Returns the log as a string, with the leading log magic marker
    removed if present.  Bytes that are not valid UTF-8 are replaced
    with U+FFFD rather than raising an exception.
    """
    output = b''
    while True:
        # Skip the bytes already retrieved in earlier rounds
        command = " | ".join([
            "tr -d '\\000' < /dev/disk/by-diskseq/2-part3",
            f"tail -c +{ len(output) + 1 }"
        ])
        result = run_command(clients, instance, command)
        chunk = base64.b64decode(result.output or '')
        output += chunk
        # Stop once nothing was dropped.  Also stop if this round made
        # no progress, since the next round would issue the identical
        # command and could loop forever.
        if not result.dropped or not chunk:
            break
    # Decode once at the end so that a multi-byte character split
    # across two rounds is reassembled first; tolerate stray bytes in
    # the raw partition instead of failing with UnicodeDecodeError
    log = output.decode(errors='replace')
    if log.startswith(IPXE_LOG_MAGIC):
        log = log[len(IPXE_LOG_MAGIC):]
    return log

def force_power_off(clients, instance):
    """Power off an instance immediately via the kernel SysRq trigger

    Submits a shell script that enables SysRq and then writes 'o' to
    the SysRq trigger.  The invocation is only submitted; its result
    is not awaited.
    """
    script = (
        "echo 1 > /proc/sys/kernel/sysrq"
        " ; "
        "echo o > /proc/sysrq-trigger"
    )
    clients.ecs.run_command(
        ecs.models.RunCommandRequest(
            region_id=clients.region,
            instance_id=[instance],
            type='RunShellScript',
            command_content=script,
        )
    )

# Parse command-line arguments
parser = argparse.ArgumentParser(
    description="Get Alibaba Cloud disk console output"
)
parser.add_argument('--region', '-r', required=True,
                    help="AliCloud region")
parser.add_argument('--family', '-f',
                    default="acs:alibaba_cloud_linux_4_lts_x64",
                    help="Helper OS image family")
parser.add_argument('--machine', '-m', default="ecs.e-c4m1.large",
                    help="Helper machine type")
parser.add_argument('instance', help="Instance ID")
args = parser.parse_args()

# Construct clients
clients = all_clients(args.region)

# Clean up old temporary objects
delete_temp_instances(clients)
delete_temp_snapshots(clients)

# Create log disk snapshot
logdisk = get_log_disk(clients, args.instance)
logsnap = create_temp_snapshot(clients, logdisk)

# From here on, temporary cloud resources exist.  Use try/finally so
# that they are removed even if a later step fails or is interrupted,
# rather than lingering until the next run or until they expire.
try:
    # Create log dumper instance
    dumper = create_temp_instance(clients, args.instance, logsnap,
                                  args.family, args.machine)
    try:
        # Wait for log output
        output = get_log_output(clients, dumper)

        # Print log output
        print(output)
    finally:
        # Clean up the log dumper instance; still attempt the
        # deletion even if the power-off request itself fails
        try:
            force_power_off(clients, dumper)
        finally:
            delete_temp_instance(clients, dumper, retry=True)
finally:
    # Clean up the temporary snapshot
    delete_temp_snapshot(clients, logsnap)
