#!/usr/bin/env python3

import argparse
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import ipaddress
from itertools import islice
import json
import time

import alibabacloud_credentials as credentials
import alibabacloud_credentials.client
import alibabacloud_credentials.models
import alibabacloud_ecs20140526 as ecs
import alibabacloud_ecs20140526.client
import alibabacloud_ecs20140526.models
import alibabacloud_ram20150501 as ram
import alibabacloud_ram20150501.client
import alibabacloud_ram20150501.models
import alibabacloud_tea_openapi as openapi
import alibabacloud_tea_openapi.client
import alibabacloud_tea_openapi.models
import alibabacloud_tea_util as util
import alibabacloud_tea_util.client
import alibabacloud_tea_util.models
import alibabacloud_vpc20160428 as vpc
import alibabacloud_vpc20160428.client
import alibabacloud_vpc20160428.models

# API endpoints for calls that are not bound to a specific region
ECS_ENDPOINT = 'ecs.aliyuncs.com'
RAM_ENDPOINT = 'ram.aliyuncs.com'

# Tags identifying the default resources managed by this script (each name
# is used as both the tag key and the tag value)
IPXE_VPC_TAG = 'ipxe-default-vpc'
IPXE_VSWITCH_TAG = 'ipxe-default-vswitch'
IPXE_SG_TAG = 'ipxe-default-sg'

# RAM role (required for importing images) and the policy document that
# allows the ECS service to assume it
IPXE_CENSORSHIP_BYPASS_ROLE_NAME = 'iPXECensorshipBypassRole'
IPXE_CENSORSHIP_BYPASS_ROLE_ASSUME_POLICY = dict(
    Statement=[dict(
        Action='sts:AssumeRole',
        Effect='Allow',
        Principal=dict(Service=['ecs.aliyuncs.com']),
    )],
    Version='1',
)

# Per-region bundle: region ID plus ECS and VPC API clients
Clients = namedtuple('Clients', 'region ecs vpc')

def all_regions():
    """Return the IDs of all regions reported by ECS, in sorted order"""
    client = ecs.client.Client(openapi.models.Config(
        credential=credentials.client.Client(),
        endpoint=ECS_ENDPOINT,
    ))
    response = client.describe_regions(ecs.models.DescribeRegionsRequest())
    region_ids = [entry.region_id for entry in response.body.regions.region]
    region_ids.sort()
    return region_ids

def all_clients(region):
    """Build the ECS and VPC clients for *region*, bundled as ``Clients``"""
    # Both clients share one configuration scoped to the given region
    config = openapi.models.Config(
        credential=credentials.client.Client(),
        region_id=region,
    )
    return Clients(region, ecs.client.Client(config), vpc.client.Client(config))

def ram_client():
    """Return a resource access management (RAM) client on RAM_ENDPOINT"""
    config = openapi.models.Config(
        credential=credentials.client.Client(),
        endpoint=RAM_ENDPOINT,
    )
    return ram.client.Client(config)

def setup_censorship_bypass_role(client):
    """Set up censorship bypass role (required for importing images)

    The role is created if RAM reports it does not exist.  Its assume-role
    policy and description are then (re)applied, and the system policy
    ``AliyunOSSFullAccess`` is attached to it.

    :param client: RAM client
    :return: role ARN
    """
    name = IPXE_CENSORSHIP_BYPASS_ROLE_NAME
    policy_document = json.dumps(IPXE_CENSORSHIP_BYPASS_ROLE_ASSUME_POLICY)

    # Look up the role, creating it only if it does not exist yet
    try:
        response = client.get_role(ram.models.GetRoleRequest(role_name=name))
    except openapi.exceptions.ClientException as err:
        if err.code != 'EntityNotExist.Role':
            raise
        response = client.create_role(ram.models.CreateRoleRequest(
            role_name=name,
            assume_role_policy_document=policy_document,
        ))
    role_arn = response.body.role.arn

    # Bring an existing role up to date with the current policy/description
    client.update_role(ram.models.UpdateRoleRequest(
        role_name=name,
        new_assume_role_policy_document=policy_document,
        new_description="iPXE role to help bypass OSS censorship restrictions",
    ))

    # Grant OSS access; a policy that is already attached is not an error
    try:
        client.attach_policy_to_role(ram.models.AttachPolicyToRoleRequest(
            role_name=name,
            policy_type='System',
            policy_name='AliyunOSSFullAccess',
        ))
    except openapi.exceptions.ClientException as err:
        if err.code != 'EntityAlreadyExists.Role.Policy':
            raise
    return role_arn

def setup_vpc(clients):
    """Set up VPC

    Reuses the VPC tagged with IPXE_VPC_TAG if there is one, otherwise
    creates a new tagged VPC.  Waits for the VPC to leave the 'Pending'
    state, (re)applies its name and description, then tries to enable
    IPv6, tolerating regions that do not support it.

    :param clients: per-region ``Clients`` tuple (uses ``region`` and ``vpc``)
    :return: VPC ID
    :raises RuntimeError: if more than one tagged VPC exists, or if the VPC
        settles in any status other than 'Available'
    """
    tag = vpc.models.DescribeVpcsRequestTag(
        key=IPXE_VPC_TAG,
        value=IPXE_VPC_TAG,
    )
    req = vpc.models.DescribeVpcsRequest(
        region_id=clients.region,
        tag=[tag],
    )
    rsp = clients.vpc.describe_vpcs(req)
    vpcs = rsp.body.vpcs.vpc or []
    # Explicit check rather than assert: "python -O" strips asserts, which
    # would silently pick an arbitrary one of several tagged VPCs
    if len(vpcs) > 1:
        raise RuntimeError("%s: multiple %s VPCs: %s" % (
            clients.region, IPXE_VPC_TAG,
            " ".join(x.vpc_id for x in vpcs)))
    if vpcs:
        vpc_id = vpcs[0].vpc_id
    else:
        tag = vpc.models.CreateVpcRequestTag(
            key=IPXE_VPC_TAG,
            value=IPXE_VPC_TAG,
        )
        req = vpc.models.CreateVpcRequest(
            region_id=clients.region,
            tag=[tag],
        )
        rsp = clients.vpc.create_vpc(req)
        vpc_id = rsp.body.vpc_id
    # Poll once per second until the VPC is no longer 'Pending'
    while True:
        time.sleep(1)
        req = vpc.models.DescribeVpcsRequest(
            region_id=clients.region,
            vpc_id=vpc_id,
        )
        rsp = clients.vpc.describe_vpcs(req)
        status = rsp.body.vpcs.vpc[0].status
        if status != 'Pending':
            break
    if status != 'Available':
        raise RuntimeError(status)
    req = vpc.models.ModifyVpcAttributeRequest(
        region_id=clients.region,
        vpc_id=vpc_id,
        vpc_name=("%s-%s" % (IPXE_VPC_TAG, clients.region)),
        description="Default VPC for iPXE development and testing",
    )
    clients.vpc.modify_vpc_attribute(req)
    req = vpc.models.ModifyVpcAttributeRequest(
        region_id=clients.region,
        vpc_id=vpc_id,
        enable_ipv_6=True,
    )
    try:
        clients.vpc.modify_vpc_attribute(req)
    except openapi.exceptions.ClientException as exc:
        # AliCloud provides no other way to detect regions without IPv6 support
        if exc.code != 'OperationUnsupported.Ipv6Feature':
            raise
    return vpc_id

def setup_vswitch(clients, vpc_id, zone_id, index):
    """Set up vSwitch

    Reuses the vSwitch tagged with IPXE_VSWITCH_TAG in ``zone_id`` if there
    is one, otherwise creates one in ``vpc_id``.  A new vSwitch gets the
    ``index``-th /24 subnet of the VPC's IPv4 CIDR block.  If the VPC has
    IPv6 enabled, ``index`` is also passed as the vSwitch's IPv6 CIDR block.
    The vSwitch is then polled until it leaves the 'Pending' state and given
    a name and description.

    :param clients: per-region ``Clients`` tuple (uses ``region`` and ``vpc``)
    :param vpc_id: ID of the VPC the vSwitch must belong to
    :param zone_id: zone in which to set up the vSwitch
    :param index: zone index; distinct indices select distinct /24 subnets
    :return: vSwitch ID, or None if AliCloud reports the zone as disabled
    :raises RuntimeError: if several tagged vSwitches exist in the zone, if
        the tagged vSwitch belongs to another VPC, if the VPC's IPv4 block
        has no ``index``-th /24 subnet, or if the vSwitch settles in any
        status other than 'Available'
    """
    tag = vpc.models.DescribeVSwitchesRequestTag(
        key=IPXE_VSWITCH_TAG,
        value=IPXE_VSWITCH_TAG,
    )
    req = vpc.models.DescribeVSwitchesRequest(
        region_id=clients.region,
        zone_id=zone_id,
        tag=[tag],
    )
    rsp = clients.vpc.describe_vswitches(req)
    vswitches = rsp.body.v_switches.v_switch or []
    # Explicit checks rather than asserts, which "python -O" would strip
    if len(vswitches) > 1:
        raise RuntimeError("%s: multiple %s vSwitches: %s" % (
            zone_id, IPXE_VSWITCH_TAG,
            " ".join(x.v_switch_id for x in vswitches)))
    if vswitches:
        vswitch_id = vswitches[0].v_switch_id
        if vswitches[0].vpc_id != vpc_id:
            raise RuntimeError("%s: vSwitch %s is in VPC %s, not %s" % (
                zone_id, vswitch_id, vswitches[0].vpc_id, vpc_id))
    else:
        req = vpc.models.DescribeVpcsRequest(
            region_id=clients.region,
            vpc_id=vpc_id,
        )
        rsp = clients.vpc.describe_vpcs(req)
        vpc_info = rsp.body.vpcs.vpc[0]
        ipv6_cidr_block = index if vpc_info.enabled_ipv_6 else None
        ipv4net = ipaddress.ip_network(vpc_info.cidr_block)
        # Pick this zone's /24; report a clear error instead of leaking a
        # bare StopIteration when the VPC block is too small for index
        ipv4subnet = next(islice(ipv4net.subnets(new_prefix=24), index, None),
                          None)
        if ipv4subnet is None:
            raise RuntimeError("%s: VPC %s block %s has no /24 subnet #%d" % (
                zone_id, vpc_id, ipv4net, index))
        cidr_block = str(ipv4subnet)
        tag = vpc.models.CreateVSwitchRequestTag(
            key=IPXE_VSWITCH_TAG,
            value=IPXE_VSWITCH_TAG,
        )
        req = vpc.models.CreateVSwitchRequest(
            region_id=clients.region,
            vpc_id=vpc_id,
            zone_id=zone_id,
            tag=[tag],
            cidr_block=cidr_block,
            ipv_6cidr_block=ipv6_cidr_block,
        )
        try:
            rsp = clients.vpc.create_vswitch(req)
            vswitch_id = rsp.body.v_switch_id
        except openapi.exceptions.ClientException as exc:
            # AliCloud provides no other way to detect disabled zones
            if exc.code != 'OperationDenied.ZoneIsDisabled':
                raise
            vswitch_id = None
    if vswitch_id:
        # Poll once per second until the vSwitch is no longer 'Pending'
        while True:
            time.sleep(1)
            req = vpc.models.DescribeVSwitchesRequest(
                region_id=clients.region,
                v_switch_id=vswitch_id,
            )
            rsp = clients.vpc.describe_vswitches(req)
            status = rsp.body.v_switches.v_switch[0].status
            if status != 'Pending':
                break
        if status != 'Available':
            raise RuntimeError(status)
        req = vpc.models.ModifyVSwitchAttributeRequest(
            region_id=clients.region,
            v_switch_id=vswitch_id,
            v_switch_name=('%s-%s' % (IPXE_VSWITCH_TAG, zone_id)),
            description="Default vSwitch for iPXE development and testing",
        )
        clients.vpc.modify_vswitch_attribute(req)
    return vswitch_id

def setup_vswitches(clients, vpc_id):
    """Set up one vSwitch per zone of the region

    Zones are numbered by their position in the DescribeZones response, and
    that number is handed to ``setup_vswitch``.  Zones for which no vSwitch
    ID comes back are omitted from the result.

    :param clients: per-region ``Clients`` tuple
    :param vpc_id: VPC in which the vSwitches live
    :return: sorted list of vSwitch IDs
    """
    response = clients.vpc.describe_zones(
        vpc.models.DescribeZonesRequest(region_id=clients.region))
    collected = []
    for position, zone in enumerate(response.body.zones.zone or []):
        switch_id = setup_vswitch(clients, vpc_id, zone.zone_id, position)
        if switch_id:
            collected.append(switch_id)
    return sorted(collected)

def setup_sg(clients, vpc_id):
    """Set up security group

    Reuses the security group tagged with IPXE_SG_TAG in ``vpc_id`` if there
    is one, otherwise creates it.  Then (re)applies its name and description
    and authorizes egress to all IPv4 and IPv6 destinations, for all
    protocols and ports.

    :param clients: per-region ``Clients`` tuple (uses ``region`` and ``ecs``)
    :param vpc_id: VPC to which the security group belongs
    :return: security group ID
    :raises RuntimeError: if several tagged security groups exist, or if the
        tagged group belongs to another VPC
    """
    tag = ecs.models.DescribeSecurityGroupsRequestTag(
        key=IPXE_SG_TAG,
        value=IPXE_SG_TAG,
    )
    req = ecs.models.DescribeSecurityGroupsRequest(
        region_id=clients.region,
        vpc_id=vpc_id,
        tag=[tag],
    )
    rsp = clients.ecs.describe_security_groups(req)
    sgs = rsp.body.security_groups.security_group or []
    # Explicit checks rather than asserts, which "python -O" would strip
    if len(sgs) > 1:
        raise RuntimeError("%s: multiple %s security groups: %s" % (
            clients.region, IPXE_SG_TAG,
            " ".join(x.security_group_id for x in sgs)))
    if sgs:
        sg_id = sgs[0].security_group_id
        if sgs[0].vpc_id != vpc_id:
            raise RuntimeError("%s: security group %s is in VPC %s, not %s" % (
                clients.region, sg_id, sgs[0].vpc_id, vpc_id))
    else:
        tag = ecs.models.CreateSecurityGroupRequestTag(
            key=IPXE_SG_TAG,
            value=IPXE_SG_TAG,
        )
        req = ecs.models.CreateSecurityGroupRequest(
            region_id=clients.region,
            vpc_id=vpc_id,
            tag=[tag],
        )
        rsp = clients.ecs.create_security_group(req)
        sg_id = rsp.body.security_group_id
    req = ecs.models.ModifySecurityGroupAttributeRequest(
        region_id=clients.region,
        security_group_id=sg_id,
        security_group_name=IPXE_SG_TAG,
        description="Default security group for iPXE development and testing",
    )
    clients.ecs.modify_security_group_attribute(req)
    # Allow all outbound traffic over both IPv4 and IPv6
    perm4 = ecs.models.AuthorizeSecurityGroupEgressRequestPermissions(
        policy='accept',
        dest_cidr_ip='0.0.0.0/0',
        ip_protocol='ALL',
        port_range='-1/-1',
    )
    perm6 = ecs.models.AuthorizeSecurityGroupEgressRequestPermissions(
        policy='accept',
        ipv_6dest_cidr_ip='::/0',
        ip_protocol='ALL',
        port_range='-1/-1',
    )
    req = ecs.models.AuthorizeSecurityGroupEgressRequest(
        region_id=clients.region,
        security_group_id=sg_id,
        permissions=[perm4, perm6],
    )
    clients.ecs.authorize_security_group_egress(req)
    return sg_id

def setup_region(clients):
    """Set up the iPXE default VPC, vSwitches and security group in a region

    :param clients: per-region ``Clients`` tuple
    :return: ``(security group ID, VPC ID, sorted vSwitch IDs)``
    """
    vpc_id = setup_vpc(clients)
    switch_ids = setup_vswitches(clients, vpc_id)
    return (setup_sg(clients, vpc_id), vpc_id, switch_ids)

# Parse command-line arguments
parser = argparse.ArgumentParser(description="Set up Alibaba Cloud defaults")
parser.add_argument('--region', '-r', action='append',
                    help="AliCloud region(s)")
parser.add_argument('--create-role', action=argparse.BooleanOptionalAction,
                    default=True, help="Create censorship bypass role")
args = parser.parse_args()

# Set up censorship bypass role
if args.create_role:
    arn = setup_censorship_bypass_role(ram_client())

# Use all regions if none specified
if not args.region:
    args.region = all_regions()

# Drop repeated regions (keeping first-seen order): setting up the same
# region from two threads at once would race between the "find tagged
# resource" and "create resource" steps and could create duplicates
args.region = list(dict.fromkeys(args.region))

# Construct per-region clients
clients = {region: all_clients(region) for region in args.region}

# Set up each region in parallel, one worker per region (at least one:
# ThreadPoolExecutor rejects max_workers=0 for an empty region list)
with ThreadPoolExecutor(max_workers=max(len(args.region), 1)) as executor:
    futures = {executor.submit(setup_region,
                               clients=clients[region]): region
               for region in args.region}
    results = {futures[x]: x.result() for x in as_completed(futures)}

# Show created resources: role ARN (if requested), then one line per region
# of "<region> <security group> <VPC> <vSwitch> ..."
if args.create_role:
    print(arn)
for region in args.region:
    (sg_id, vpc_id, vswitch_ids) = results[region]
    print("%s %s %s %s" % (region, sg_id, vpc_id, " ".join(vswitch_ids)))
