#!/usr/bin/env python3

import argparse
import textwrap
import time
from uuid import uuid4

from google.cloud import compute

# Name prefix given to every temporary snapshot and instance created by
# this script; the delete helpers check names against it before deleting
IPXE_LOG_PREFIX = 'ipxe-log-temp-'
# Marker that is stripped from the start of the captured log when present
IPXE_LOG_MAGIC = 'iPXE LOG'
# Marker echoed by the dumper's startup script after the log contents;
# seeing it at the end of the serial output means the dump is complete
IPXE_LOG_END = '----- END OF iPXE LOG -----'

def get_log_disk(instances, project, zone, name):
    """Get log disk source URL

    Fetches the named instance via ``instances.get()`` and returns the
    ``source`` attribute of the first attached disk that is flagged as a
    boot disk.

    Args:
        instances: client object providing ``get(project, zone, instance)``
        project: project containing the instance
        zone: zone containing the instance
        name: instance name

    Returns:
        The ``source`` value of the instance's boot disk.

    Raises:
        LookupError: if none of the instance's disks is flagged as boot.
    """
    instance = instances.get(project=project, zone=zone, instance=name)
    for disk in instance.disks:
        if disk.boot:
            return disk.source
    # An explicit error is far easier to diagnose than the bare
    # StopIteration that an exhausted next() would produce
    raise LookupError("Instance %s has no boot disk" % name)

def delete_temp_snapshot(snapshots, project, name):
    """Delete temporary snapshot

    Issues a delete request for the named snapshot.  The value returned
    by ``snapshots.delete()`` is discarded, so this function does not
    wait on it.

    Args:
        snapshots: client object providing ``delete(project, snapshot)``
        project: project containing the snapshot
        name: snapshot name; must start with IPXE_LOG_PREFIX

    Raises:
        ValueError: if the name does not start with IPXE_LOG_PREFIX.
    """
    # This check is the only thing stopping a non-temporary snapshot
    # from being deleted, so it must not be an assert (which would be
    # stripped when running under "python -O")
    if not name.startswith(IPXE_LOG_PREFIX):
        raise ValueError("Refusing to delete non-temporary snapshot %s" %
                         name)
    snapshots.delete(project=project, snapshot=name)

def delete_temp_snapshots(snapshots, project):
    """Delete all old temporary snapshots

    Lists the project's snapshots using a name filter built from
    IPXE_LOG_PREFIX and passes each listed snapshot's name to
    delete_temp_snapshot().
    """
    name_filter = f"name eq {IPXE_LOG_PREFIX}.+"
    request = compute.ListSnapshotsRequest(filter=name_filter,
                                           project=project)
    # Generator keeps the listing lazy: each snapshot is deleted as it
    # is produced by the listing
    stale_names = (stale.name for stale in snapshots.list(request=request))
    for stale_name in stale_names:
        delete_temp_snapshot(snapshots, project, stale_name)

def create_temp_snapshot(snapshots, project, source):
    """Create temporary snapshot

    Inserts a snapshot of the disk given by ``source`` under a freshly
    generated name (IPXE_LOG_PREFIX followed by a random UUID), waits
    for the insert operation's result, and returns the snapshot name.
    """
    name = f"{IPXE_LOG_PREFIX}{uuid4()}"
    resource = compute.Snapshot(source_disk=source, name=name)
    operation = snapshots.insert(project=project, snapshot_resource=resource)
    # Block until the insert operation reports its result
    operation.result()
    return name

def delete_temp_instance(instances, project, zone, name):
    """Delete log dumper temporary instance

    Issues a delete request for the named instance.  The value returned
    by ``instances.delete()`` is discarded, so this function does not
    wait on it.

    Args:
        instances: client object providing ``delete(project, zone, instance)``
        project: project containing the instance
        zone: zone containing the instance
        name: instance name; must start with IPXE_LOG_PREFIX

    Raises:
        ValueError: if the name does not start with IPXE_LOG_PREFIX.
    """
    # This check is the only thing stopping a non-temporary instance
    # from being deleted, so it must not be an assert (which would be
    # stripped when running under "python -O")
    if not name.startswith(IPXE_LOG_PREFIX):
        raise ValueError("Refusing to delete non-temporary instance %s" %
                         name)
    instances.delete(project=project, zone=zone, instance=name)

def delete_temp_instances(instances, project, zone):
    """Delete all old log dumper temporary instances

    Lists the zone's instances using a name filter built from
    IPXE_LOG_PREFIX and passes each listed instance's name to
    delete_temp_instance().
    """
    name_filter = f"name eq {IPXE_LOG_PREFIX}.+"
    request = compute.ListInstancesRequest(filter=name_filter,
                                           project=project, zone=zone)
    # Generator keeps the listing lazy: each instance is deleted as it
    # is produced by the listing
    stale_names = (stale.name for stale in instances.list(request=request))
    for stale_name in stale_names:
        delete_temp_instance(instances, project, zone, stale_name)

def create_temp_instance(instances, project, zone, family, image, machine,
                         snapshot):
    """Create log dumper temporary instance

    Inserts an instance named IPXE_LOG_PREFIX plus a random UUID, waits
    for the insert operation's result, and returns the instance name.

    The instance has two auto-deleted disks:

    - a boot disk initialised from the image URL
      "projects/<family>/global/images/family/<image>" (note that the
      ``family`` argument fills the project slot of that URL and the
      ``image`` argument fills the image family slot)
    - a non-boot disk initialised from the given snapshot, attached
      with device name "ipxelog"

    Its startup script copies partition 3 of the "ipxelog" disk, with
    NUL bytes removed, to /dev/ttyS3 and then writes IPXE_LOG_END to
    the same device (presumably the serial port that get_log_output()
    reads as port 4).
    """
    name = f"{IPXE_LOG_PREFIX}{uuid4()}"

    # Boot disk created from the helper OS image
    boot_disk = compute.AttachedDisk(
        boot=True,
        auto_delete=True,
        initialize_params=compute.AttachedDiskInitializeParams(
            source_image=f"projects/{family}/global/images/family/{image}",
        ),
    )

    # Secondary disk created from the log snapshot; the device name
    # determines the /dev/disk/by-id path used by the startup script
    log_disk = compute.AttachedDisk(
        boot=False,
        auto_delete=True,
        initialize_params=compute.AttachedDiskInitializeParams(
            source_snapshot=f"global/snapshots/{snapshot}",
        ),
        device_name="ipxelog",
    )

    # Startup script that dumps the log and then the end marker
    startup_script = textwrap.dedent(f"""
    #!/bin/sh
    tr -d '\\000' < /dev/disk/by-id/google-ipxelog-part3 > /dev/ttyS3
    echo "{IPXE_LOG_END}" > /dev/ttyS3
    """).strip()
    startup_item = compute.Items(key="startup-script", value=startup_script)

    resource = compute.Instance(
        name=name,
        machine_type=f"zones/{zone}/machineTypes/{machine}",
        network_interfaces=[compute.NetworkInterface()],
        metadata=compute.Metadata(items=[startup_item]),
        disks=[boot_disk, log_disk],
    )
    operation = instances.insert(project=project, zone=zone,
                                 instance_resource=resource)
    # Block until the insert operation reports its result
    operation.result()
    return name

def get_log_output(instances, project, zone, name):
    """Get iPXE log output

    Repeatedly fetches the output of serial port 4 of the named
    instance, sleeping for one second between attempts, until the
    whitespace-stripped output ends with IPXE_LOG_END.  The end marker
    is then removed, along with a leading IPXE_LOG_MAGIC if there is
    one, and the remainder is returned.

    There is no timeout: the loop continues for as long as the end
    marker has not appeared.
    """
    request = compute.GetSerialPortOutputInstanceRequest(
        project=project, zone=zone, instance=name, port=4,
    )
    tail = len(IPXE_LOG_END)
    while True:
        response = instances.get_serial_port_output(request=request)
        text = response.contents.strip()
        if not text.endswith(IPXE_LOG_END):
            # Dump not yet complete; poll again shortly
            time.sleep(1)
            continue
        # Skip the leading magic marker only when it is actually present
        head = len(IPXE_LOG_MAGIC) if text.startswith(IPXE_LOG_MAGIC) else 0
        return text[head:-tail]

# Parse command-line arguments
#
# NOTE: create_temp_instance() builds the helper image URL as
# "projects/<family>/global/images/family/<image>", so --family fills
# the image project slot and --image fills the image family slot.
#
parser = argparse.ArgumentParser(
    description="Get iPXE log output from a Google Cloud instance")
parser.add_argument('--project', '-j', default="ipxe-images",
                    help="Google Cloud project")
parser.add_argument('--zone', '-z', required=True,
                    help="Google Cloud zone")
parser.add_argument('--family', '-f', default="debian-cloud",
                    help="Helper OS image family")
parser.add_argument('--image', '-i', default="debian-12",
                    help="Helper OS image")
parser.add_argument('--machine', '-m', default="e2-micro",
                    help="Helper machine type")
parser.add_argument('instance', help="Instance name")
args = parser.parse_args()

# Construct client objects
#
instances = compute.InstancesClient()
snapshots = compute.SnapshotsClient()

# Clean up old temporary objects
#
delete_temp_instances(instances, project=args.project, zone=args.zone)
delete_temp_snapshots(snapshots, project=args.project)

# Create log disk snapshot
#
logdisk = get_log_disk(instances, project=args.project, zone=args.zone,
                       name=args.instance)
logsnap = create_temp_snapshot(snapshots, project=args.project, source=logdisk)

# From here on, temporary objects exist that must be removed even if a
# later step fails or the (unbounded) wait for log output is interrupted
#
try:

    # Create log dumper instance
    #
    dumper = create_temp_instance(instances, project=args.project,
                                  zone=args.zone, family=args.family,
                                  image=args.image, machine=args.machine,
                                  snapshot=logsnap)
    try:

        # Wait for log output
        #
        output = get_log_output(instances, project=args.project,
                                zone=args.zone, name=dumper)

        # Print log output
        #
        print(output)

    finally:

        # Clean up log dumper instance
        #
        delete_temp_instance(instances, project=args.project, zone=args.zone,
                             name=dumper)

finally:

    # Clean up log disk snapshot
    #
    delete_temp_snapshot(snapshots, project=args.project, name=logsnap)
