#!/usr/bin/env python3

import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import date
import io
import subprocess
import tarfile
from uuid import uuid4

from google.cloud import compute
from google.cloud import exceptions
from google.cloud import storage

# Name prefix shared by the temporary upload buckets and by the blobs
# uploaded into them.  It is used both to find stale buckets and to
# restrict deletion to objects carrying this prefix.
IPXE_STORAGE_PREFIX = 'ipxe-upload-temp-'

# Guest OS features attached to imported images.  GVNIC and IDPF are
# always included; UEFI_COMPATIBLE is added only when a UEFI boot file
# is detected in the disk image.
FEATURE_GVNIC = compute.GuestOsFeature(type_="GVNIC")
FEATURE_IDPF = compute.GuestOsFeature(type_="IDPF")
FEATURE_UEFI = compute.GuestOsFeature(type_="UEFI_COMPATIBLE")

# IAM policy applied to an image when it is to be made public: binds
# the "roles/compute.imageUser" role to "allAuthenticatedUsers".
POLICY_PUBLIC = compute.Policy(bindings=[{
    "role": "roles/compute.imageUser",
    "members": ["allAuthenticatedUsers"],
}])

def delete_temp_bucket(bucket):
    """Remove temporary bucket

    Deletes every blob in ``bucket`` whose name starts with
    IPXE_STORAGE_PREFIX and then deletes the bucket itself, provided
    that no blobs remain in it afterwards.  A bucket that still holds
    other blobs is left in place.

    Raises ValueError if the bucket name (or the name of a blob about
    to be deleted) does not carry the temporary prefix.  These are
    explicit checks rather than ``assert`` statements so that the
    safeguard on this destructive operation is not stripped when Python
    runs with optimisation enabled (``-O``).
    """
    if not bucket.name.startswith(IPXE_STORAGE_PREFIX):
        raise ValueError("Refusing to delete non-temporary bucket %r" %
                         bucket.name)
    for blob in bucket.list_blobs(prefix=IPXE_STORAGE_PREFIX):
        # Defensive re-check of the listing filter before deleting
        if not blob.name.startswith(IPXE_STORAGE_PREFIX):
            raise ValueError("Refusing to delete non-temporary blob %r" %
                             blob.name)
        blob.delete()
    # Only emptiness matters here, so request at most one blob instead
    # of materialising the full listing
    remaining = bucket.list_blobs(max_results=1)
    if next(iter(remaining), None) is None:
        bucket.delete()

def create_temp_bucket(location):
    """Create temporary bucket (and remove any stale temporary buckets)

    Every existing bucket whose name starts with IPXE_STORAGE_PREFIX is
    first handed to delete_temp_bucket().  A new bucket named with the
    prefix followed by a random UUID is then created in ``location``
    and returned.
    """
    storage_client = storage.Client()
    stale_buckets = storage_client.list_buckets(prefix=IPXE_STORAGE_PREFIX)
    for stale in stale_buckets:
        delete_temp_bucket(stale)
    bucket_name = '%s%s' % (IPXE_STORAGE_PREFIX, uuid4())
    return storage_client.create_bucket(bucket_name, location=location)

def create_tarball(image):
    """Create raw disk image tarball

    Packs the file at path ``image`` into an in-memory gzip-compressed
    GNU-format tar archive, stored under the member name "disk.raw".
    Returns the io.BytesIO holding the archive, rewound to offset zero
    so that it can be read straight away.
    """
    buffer = io.BytesIO()
    archive = tarfile.open(mode='w:gz', fileobj=buffer,
                           format=tarfile.GNU_FORMAT)
    # Leaving the "with" block finalises the archive into the buffer
    with archive:
        archive.add(image, arcname='disk.raw')
    buffer.seek(0)
    return buffer

def upload_blob(bucket, image):
    """Upload raw disk image blob

    Wraps the disk image at path ``image`` in a tarball via
    create_tarball() and uploads it to ``bucket`` under a blob name
    made of IPXE_STORAGE_PREFIX, a random UUID and ".tar.gz".  Returns
    the uploaded blob object.
    """
    object_name = '%s%s.tar.gz' % (IPXE_STORAGE_PREFIX, uuid4())
    target = bucket.blob(object_name)
    target.upload_from_file(create_tarball(image))
    return target

def detect_uefi(image):
    """Identify UEFI CPU architecture(s)

    Runs ``mdir -b -i <image> ::/EFI/BOOT`` and searches its standard
    output for the known UEFI boot file names.  Returns the list of
    matching architecture names, always in the order "x86_64" then
    "arm64"; the list is empty when neither file name appears.

    A non-zero exit status from mdir is not treated as an error: the
    output it produced (if any) is searched regardless.
    """
    boot_files = (
        (b'BOOTX64.EFI', 'x86_64'),
        (b'BOOTAA64.EFI', 'arm64'),
    )
    listing = subprocess.run(
        ['mdir', '-b', '-i', image, '::/EFI/BOOT'],
        capture_output=True, check=False,
    ).stdout
    architectures = []
    for boot_file, arch in boot_files:
        if boot_file in listing:
            architectures.append(arch)
    return architectures

def image_architecture(uefi):
    """Get image architecture

    ``uefi`` is the list of detected UEFI architectures.  Exactly one
    entry selects that architecture, several entries give None (no
    single architecture applies), and an empty list falls back to
    'x86_64'.
    """
    if len(uefi) == 1:
        return uefi[0]
    if uefi:
        return None
    return 'x86_64'

def image_features(uefi):
    """Get image feature list

    Returns a fresh list that always holds the GVNIC and IDPF features,
    followed by the UEFI feature when ``uefi`` (the list of detected
    UEFI architectures) is non-empty.
    """
    uefi_features = [FEATURE_UEFI] if uefi else []
    return [FEATURE_GVNIC, FEATURE_IDPF, *uefi_features]

def image_name(base, uefi):
    """Calculate image name or family name

    Appends a suffix to ``base`` according to ``uefi``, the list of
    detected UEFI architectures: "-uefi-<arch>" (with underscores in
    the architecture replaced by hyphens) for exactly one architecture,
    "-uefi-multi" for several, and no suffix for none.
    """
    if len(uefi) == 1:
        suffix = '-uefi-%s' % uefi[0].replace('_', '-')
    elif uefi:
        suffix = '-uefi-multi'
    else:
        suffix = ''
    return '%s%s' % (base, suffix)

def create_image(project, basename, basefamily, overwrite, public, bucket,
                 image):
    """Create image

    Imports the disk image at path ``image`` into ``project``:

    * the image and family names are derived from ``basename`` and
      ``basefamily`` plus a suffix reflecting the UEFI architectures
      detected in the disk image;
    * if ``overwrite`` is set, any existing image of the same name is
      deleted first (a missing image is not an error);
    * the disk image is uploaded to ``bucket`` as a tarball and used as
      the raw disk source of the new image;
    * if ``public`` is set, POLICY_PUBLIC is applied to the new image.

    Returns the image resource as fetched back after creation.
    """
    images = compute.ImagesClient()
    architectures = detect_uefi(image)
    name = image_name(basename, architectures)

    # Remove any previous image of the same name before re-creating it
    if overwrite:
        try:
            images.delete(project=project, image=name).result()
        except exceptions.NotFound:
            pass

    # Upload the tarball and register it as a new image
    tarball_blob = upload_blob(bucket, image)
    resource = compute.Image(
        name=name,
        family=image_name(basefamily, architectures),
        architecture=image_architecture(architectures),
        guest_os_features=image_features(architectures),
        raw_disk=compute.RawDisk(source=tarball_blob.public_url),
    )
    images.insert(project=project, image_resource=resource).result()

    if public:
        policy_request = compute.GlobalSetPolicyRequest(policy=POLICY_PUBLIC)
        images.set_iam_policy(
            project=project,
            resource=name,
            global_set_policy_request_resource=policy_request,
        )

    # Re-read the image so that the caller sees its current state
    return images.get(project=project, image=name)

# Parse command-line arguments
#
parser = argparse.ArgumentParser(description="Import Google Cloud image")
parser.add_argument('--name', '-n',
                    help="Base image name")
parser.add_argument('--family', '-f',
                    help="Base family name")
parser.add_argument('--public', '-p', action='store_true',
                    help="Make image public")
parser.add_argument('--overwrite', action='store_true',
                    help="Overwrite any existing image with same name")
parser.add_argument('--project', '-j', default="ipxe-images",
                    help="Google Cloud project")
parser.add_argument('--location', '-l',
                    help="Google Cloud Storage initial location")
parser.add_argument('image', nargs='+', help="iPXE disk image")
args = parser.parse_args()

# Use default family name if none specified
if not args.family:
    args.family = 'ipxe'

# Use default name if none specified (family name plus today's date)
if not args.name:
    args.name = '%s-%s' % (args.family, date.today().strftime('%Y%m%d'))

# Create temporary upload bucket
bucket = create_temp_bucket(args.location)

# Use one thread per image to maximise parallelism
#
# The bucket is removed in a "finally" clause so that it (and any
# tarballs already uploaded to it) is not left behind when one of the
# imports fails.  Leaving the executor's "with" block waits for all
# workers to finish, so no upload is still running when the bucket is
# deleted.
try:
    with ThreadPoolExecutor(max_workers=len(args.image)) as executor:
        futures = {executor.submit(create_image,
                                   project=args.project,
                                   basename=args.name,
                                   basefamily=args.family,
                                   overwrite=args.overwrite,
                                   public=args.public,
                                   bucket=bucket,
                                   image=image): image
                   for image in args.image}
        # Collect results keyed by disk image path; result() re-raises
        # any exception raised by the corresponding worker
        results = {futures[future]: future.result()
                   for future in as_completed(futures)}
finally:
    # Delete temporary upload bucket
    delete_temp_bucket(bucket)

# Show created images (in command-line order)
for image in args.image:
    result = results[image]
    print("%s (%s) %s" % (result.name, result.family, result.status))
