Source code for hca.util

This module contains utilities for the DCP CLI and API bindings. These
utility classes and functions are used under the hood by both the DCP
API bindings library and by its CLI. There is no need to use these
utilities directly unless you are extending DCP client functionality.

``SwaggerClient`` is a base class for a general purpose Swagger API
client connection manager. User classes such as ``hca.dss.DSSClient``
extend it as follows:

  class APIClient(SwaggerClient):
      def __init__(self, *args, **kwargs):
          super(APIClient, self).__init__(*args, **kwargs)
          self.commands += [self.special_cli_command]

      def special_cli_command(self, required_argument, optional_argument=""):
          return {}

Each user class should have a configuration subtree keyed by its name
(such as ``APIClient`` above) under the DCP-wide config manager
(available via ``hca.get_config()``; the static defaults for the
config manager are stored in ``hca/default_config.json``). Within that
subtree, the key ``swagger_url`` should point to the HTTPS URL
containing the Swagger API definition that the client is providing an
interface for. On first use, this API definition will be downloaded
and saved into the user config directory (for example,
/Users/Alice/.config/hca) with a name determined by the base64
encoding of ``swagger_url``. On subsequent uses, this file will be
loaded instead to get the Swagger API definition.

The Swagger API definition is then used to dynamically construct and
attach API client methods as class methods, and to decorate them with
API metadata such as docstrings and I/O signatures. Method names are
determined using a mapping heuristic:

  GET /foo/bar -> APIClient.get_foo_bar()
  POST /widgets/{id} -> APIClient.post_widget()

Client methods (provided by _ClientMethodFactory) build the HTTP
request payload by matching their keyword argument inputs to the
Swagger API definition, and use the ``requests`` library to call the
API. JSON body input and output is assumed by default:

  json_results = APIClient().post_widget(id="foo", qs_param="x", body_param=123)

The ``stream()`` method can be used instead to stream the raw body as
bytes, or to otherwise provide access to the ``requests.Response``
object:

  with APIClient().get_foo_bar.stream() as response:
      while True:
          chunk = response.raw.read(1024)
          ...
          if not chunk:
              break

Results from API routes that support GitHub/RFC 5988 style pagination
can be paged like this:

  for result in APIClient().get_foo_bar.iterate():
      ...

Routes that require authentication trigger the use of the auth
middleware provided by requests_oauthlib (work in progress).

CLI parsers for argparse can be generated and injected into a parent
``argparse.ArgumentParser`` object by passing a subparsers object to
SwaggerClient.build_argparse_subparsers. The resulting CLI entry point
can be called like this:

  $ dss post_widget --id ID

In addition to bindings to API methods in the Swagger definition,
SwaggerClient designates certain methods as *commands*, which means
they are part of the public bindings API and are also provided as CLI
subcommands. The SwaggerClient class provides two such commands, login
and logout, which manage the cached authentication credentials for the
client. Subclasses can add more commands by adding them to the
``SwaggerClient.commands`` array, as shown with
``special_cli_command`` in the example above.

import os
import multiprocessing
import types
import collections
import typing
import json
import errno
import base64
import argparse
import time
import jwt
import requests
import jmespath

from inspect import signature, Parameter
from requests.adapters import HTTPAdapter, DEFAULT_POOLSIZE
from requests_oauthlib import OAuth2Session
from urllib3.util import retry, timeout
from urllib.parse import urljoin
from jsonpointer import resolve_pointer
from threading import Lock
from argparse import RawTextHelpFormatter
from dcplib.networking import Session

from .. import get_config, logger
from .exceptions import SwaggerAPIException, SwaggerClientInternalError
from ._docs import _pagination_docstring, _streaming_docstring, _md2rst, _parse_docstring
from .fs_helper import FSHelper as fs

"""Based on
# Twice the CPU count reported by multiprocessing. SwaggerClient._set_retry_policy uses this
# value as a floor for the HTTP adapter's connection pool size.
DEFAULT_THREAD_COUNT = multiprocessing.cpu_count() * 2

# Retry policy derived from urllib3's ``retry.Retry``; SwaggerClient instantiates it as
# ``RetryPolicy(read=10, status_forcelist=...)``.
# NOTE(review): the class body is not present in this copy of the file — restore it from
# the upstream source before use.
class RetryPolicy(retry.Retry):

class _ClientMethodFactory(object):
    def __init__(self, client, parameters, path_parameters, http_method, method_name, method_data, body_props):
        self._context_manager_response = None

    def _request(self, req_args, url=None, stream=False, headers=None):
        """Issue the HTTP request for this API method and return the response object.

        :param req_args: Mapping of argument name to value. Entries are routed to the URL
            path, the query string, the JSON body or the headers according to the Swagger
            parameter metadata; ``None`` values are left out of the path, query and body.
        :param url: When given, used as-is instead of being built from the path template.
        :param stream: Forwarded to ``session.request``.
        :param headers: Optional extra headers. A non-empty dict passed here is updated in
            place with the header-type parameters found in ``req_args``.
        :raises SwaggerAPIException: If the response status code is 400 or greater.
        """
        supplied_path_params = [p for p in req_args if p in self.path_parameters and req_args[p] is not None]
        if url is None:
            # The path template is selected by the exact set of path parameters supplied.
            # NOTE(review): the left operand of "+" on the next line appears to be missing
            # (presumably the API base URL held by the client) — confirm against upstream.
            url = + self.client.http_paths[self.method_name][frozenset(supplied_path_params)]
            url = url.format(**req_args)
        logger.debug("%s %s %s", self.http_method, url, req_args)
        query = {k: v for k, v in req_args.items()
                 if self.parameters.get(k, {}).get("in") == "query" and v is not None}
        body = {k: v for k, v in req_args.items() if k in self.body_props and v is not None}
        # NOTE(review): an "else:" appears to be missing between the next two assignments; as
        # written the authenticated session is always replaced by the plain one — confirm.
        if "security" in self.method_data:
            session = self.client.get_authenticated_session()
            session = self.client.get_session()

        # A JSON body is sent only when the operation declares body properties.
        json_input = body if self.body_props else None
        headers = headers or {}
        # Unlike query and body values, header values that are None are not filtered out here.
        headers.update({k: v for k, v in req_args.items() if self.parameters.get(k, {}).get('in') == 'header'})
        res = session.request(self.http_method, url, params=query, json=json_input, stream=stream,
                              headers=headers, timeout=self.client.timeout_policy)
        if res.status_code >= 400:
            raise SwaggerAPIException(response=res)
        return res

    def _consume_response(self, response):
        if self.http_method.upper() == "HEAD":
            return response
        elif response.headers["content-type"].startswith("application/json"):
            return response.json()
            return response.content

    def __call__(self, client, **kwargs):
        return self._consume_response(self._request(kwargs))

    def _cli_call(self, cli_args):
        return self._consume_response(self._request(vars(cli_args)))

    def stream(self, **kwargs):
        self._context_manager_response = self._request(kwargs, stream=True)
        return self

    def __enter__(self, **kwargs):
        assert self._context_manager_response is not None
        return self._context_manager_response

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._context_manager_response = None

class _PaginatingClientMethodFactory(_ClientMethodFactory):
    def _get_raw_pages(self, **kwargs):
        page = None
        while page is None or page.links.get("next", {}).get("url"):
            page = self._request(kwargs, url=page.links["next"]["url"] if page else None)
            yield page

    def iterate(self, **kwargs):
        """
        Yield specific items from each response depending on its contents.

        For example, GET /bundles/{id} and GET /collections/{id} yield the
        items contained within; POST /search yields search result items.

        The items are looked up in each page's JSON body under the dotted key path given by
        the ``X-OpenAPI-Paginated-Content-Key`` response header (``results`` when absent).
        """
        for page in self._get_raw_pages(**kwargs):
            content_key = page.headers.get("X-OpenAPI-Paginated-Content-Key", "results")
            results = page.json()
            # Walk the dotted path one level at a time, e.g. "bundle.files".
            for key in content_key.split("."):
                results = results[key]
            for result in results:
                yield result

    def paginate(self, **kwargs):
        """Yield the decoded JSON body of every page, one page per iteration."""
        yield from (response.json() for response in self._get_raw_pages(**kwargs))

    def _cli_call(self, cli_args):
        """CLI entry point: fetch and merge all pages unless pagination was switched off.

        Only ``cli_args.paginate is True`` triggers automatic paging; any other value falls
        back to the single-request behaviour of the base class.
        """
        if cli_args.paginate is True:
            return self._auto_page(**vars(cli_args))
        return super()._cli_call(cli_args)

    def _auto_page(self, **kwargs):
        '''This method allows for autopaging in commands and bindings

        It requests every page and returns a single body: the first page's JSON document,
        with items from later pages added under the page's content key (taken from the
        ``X-OpenAPI-Paginated-Content-Key`` header, ``results`` when absent).
        '''
        response_data = None
        for page in self._get_raw_pages(**kwargs):
            page_data = page.json()
            content_key = page.headers.get("X-OpenAPI-Paginated-Content-Key", "results")
            # NOTE(review): an "else:" appears to be missing after the first-page assignment
            # below, and the callee name on the two lines after it is gone (presumably a
            # lookup of content_key in each document) — restore from the upstream source.
            if response_data is None:
                response_data = page_data
                data_aggregrator =, response_data)
                patch =, page_data)
                if patch:
                    data_aggregrator += patch
                response_data[content_key] = data_aggregrator
        return response_data

class SwaggerClient(object):
    # URL scheme, and the retry policy applied to HTTP adapters by _set_retry_policy.
    scheme = "https"
    retry_policy = RetryPolicy(read=10,
                               status_forcelist=frozenset({500, 502, 503, 504}))
    # Lifetime in seconds of the JWTs produced by _get_jwt_from_service_account_credentials.
    token_expiration = 3600
    # Lazily created sessions (see get_session / get_authenticated_session).
    _authenticated_session = None
    _session = None
    # A cached API definition is downloaded again once it is at least this many days old.
    _spec_valid_for_days = 7
    # Class-level lock guarding the lazy load of the API definition in swagger_spec.
    _swagger_spec_lock = Lock()
    # Swagger primitive type name -> Python annotation used in generated signatures.
    _type_map = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "array": typing.List,
        "object": typing.Mapping,
    }
    _audience = ""  # TODO derive from swagger
    # The read timeout should be longer than DSS' API Gateway timeout to avoid races with the client and the gateway
    # hanging up at the same time. It's better to consistently get a 504 from the server than a read timeout from the
    # client or sometimes one and sometimes the other.
    timeout_policy = timeout.Timeout(connect=20, read=40)
    max_redirects = 1024

    def __init__(self, config=None, swagger_url=None, **session_kwargs):
        """Load the API definition and generate one client method per path and HTTP method.

        :param config: Configuration object; defaults to ``get_config()``.
        :param swagger_url: URL of the API definition; defaults to the ``swagger_url`` entry
            of the config subtree named after this class.
        :param session_kwargs: Stored and later passed to ``Session`` by ``get_session``.
        """
        self.config = config or get_config()
        self.swagger_url = swagger_url or self.config[self.__class__.__name__].swagger_url
        self._session_kwargs = session_kwargs
        self._swagger_spec = None

        # The class docstring is replaced with the API description from the definition.
        self.__class__.__doc__ = _md2rst(self.swagger_spec["info"]["description"])
        self.methods = {}
        self.commands = [self.login, self.logout]
        # method name -> {frozenset of path parameter names -> path template}
        self.http_paths = collections.defaultdict(dict)
        if "openapi" in self.swagger_spec:
            server = self.swagger_spec["servers"][0]
            variables = {k: v["default"] for k, v in server.get("variables", {}).items()}
            # NOTE(review): the assignment target on the next two lines is missing, an
            # "else:" appears to be missing between them, and the second statement is cut
            # off after ``scheme=self.scheme,`` — restore from the upstream source.
   = server["url"].format(**variables)
   = "{scheme}://{host}{base}".format(scheme=self.scheme,
        for http_path, path_data in self.swagger_spec["paths"].items():
            for http_method, method_data in path_data.items():
                self._build_client_method(http_method, http_path, method_data)

    def load_swagger_json(swagger_json, ptr_str="$ref"):
        Load the Swagger JSON and resolve {"$ref": "#/..."} internal JSON Pointer references.
        refs = []

        def store_refs(d):
            if len(d) == 1 and ptr_str in d:
            return d

        swagger_content = json.load(swagger_json, object_hook=store_refs)
        for ref in refs:
            _, target = ref.popitem()
            assert target[0] == "#"
            ref.update(resolve_pointer(swagger_content, target[1:]))
        return swagger_content

    @property
    def swagger_spec(self):
        """The parsed API definition, loaded on first access and then kept on the instance.

        If the config has a ``swagger_filename`` entry, that file is read (a name that does
        not start with "/" is resolved against this module's directory) and nothing is
        downloaded. Otherwise the definition is cached in the user config directory under a
        name derived from ``swagger_url``, and is downloaded when the cache file is missing
        or at least ``_spec_valid_for_days`` days old.
        """
        with self._swagger_spec_lock:
            if not self._swagger_spec:
                if "swagger_filename" in self.config:
                    swagger_filename = self.config.swagger_filename
                    if not swagger_filename.startswith("/"):
                        swagger_filename = os.path.join(os.path.dirname(__file__), swagger_filename)
                else:
                    swagger_filename = self._get_swagger_filename(self.swagger_url)
                if (("swagger_filename" not in self.config) and
                    ((not os.path.exists(swagger_filename)) or
                     (fs.get_days_since_last_modified(swagger_filename) >= self._spec_valid_for_days))):
                    try:
                        os.makedirs(self.config.user_config_dir)
                    except OSError as e:
                        # An already existing directory is fine; anything else is an error.
                        if not (e.errno == errno.EEXIST and os.path.isdir(self.config.user_config_dir)):
                            raise
                    res = self.get_session().get(self.swagger_url)
                    res_json = res.json()
                    # Refuse to cache a document that is not an API definition.
                    assert "swagger" in res_json or "openapi" in res_json
                    fs.atomic_write(swagger_filename, res.content)
                with open(swagger_filename) as fh:
                    self._swagger_spec = self.load_swagger_json(fh)
        return self._swagger_spec

    def _get_swagger_filename(self, swagger_url):
        swagger_filename = base64.urlsafe_b64encode(swagger_url.encode()).decode() + ".json"
        swagger_filename = os.path.join(self.config.user_config_dir, swagger_filename)
        return swagger_filename

    def clear_cache(self):
        # NOTE(review): the next line is the docstring text of this method; the triple
        # quotes around it appear to have been lost.
        Clear the cached API definitions for a component. This can help resolve errors communicating with the API.
        # NOTE(review): the "try:" block guarded by this handler and the handler's own body
        # are missing from this copy — restore both from the upstream source.
        except EnvironmentError as e:

    def application_secrets(self):
        if "application_secrets" not in self.config:
            app_secrets_url = "https://{}/internal/application_secrets".format(self._swagger_spec["host"])
            self.config.application_secrets = requests.get(app_secrets_url).json()
        return self.config.application_secrets

    def get_session(self):
        """Return the unauthenticated session, creating and configuring it on first use.

        A new session is built from the keyword arguments given to the constructor, gets
        this client's redirect limit, and sends the class name as its User-Agent.
        """
        if self._session is not None:
            return self._session
        self._session = session = Session(**self._session_kwargs)
        session.max_redirects = self.max_redirects
        session.headers.update({"User-Agent": self.__class__.__name__})
        return self._session

    def logout(self):
        """
        Clear {prog} authentication credentials previously configured with ``{prog} login``.
        """
        # The docstring above is run through str.format(prog=...) to build the CLI help
        # text, so it must not contain any other brace fields.
        for key in ("application_secrets", "oauth2_token"):
            try:
                del self.config[key]
            except KeyError:
                # Nothing is stored under this key; there is nothing to clear.
                pass
    def login(self, access_token="", remote=False):
        # NOTE(review): the prose lines below are the docstring text of this command; the
        # triple quotes around them appear to have been lost. The text is used for CLI help
        # (it is formatted with ``prog`` in build_argparse_subparsers), so keep it verbatim.
        Configure and save {prog} authentication credentials.

        This command may open a browser window to ask for your
        consent to use web service authentication credentials.

        Use --remote if using the CLI in a remote environment
        # An explicitly supplied access token is wrapped as-is; no OAuth flow is run for it.
        # NOTE(review): an "else:" appears to be missing after the next assignment — as
        # written the OAuth flow below would also run when a token was supplied.
        if access_token:
            credentials = argparse.Namespace(token=access_token, refresh_token=None, id_token=None)
            scopes = ["openid", "email", "offline_access"]
            if remote:
                # Remote flow: the user opens the printed URL elsewhere and pastes the code.
                import google_auth_oauthlib.flow
                application_secrets = self.application_secrets
                redirect_uri = urljoin(application_secrets['installed']['auth_uri'], "/echo")
                # NOTE(review): the call on the next line is cut off after ``scopes=scopes,``
                # (redirect_uri computed above is presumably its last argument) — confirm.
                flow = google_auth_oauthlib.flow.Flow.from_client_config(self.application_secrets, scopes=scopes,

                authorization_url, _ = flow.authorization_url()
                print("please authenticate at the url: {}".format(authorization_url))
                code = input("pass 'code' value from within query_params: ")
                # NOTE(review): ``code`` is never used in this copy; the statement that
                # exchanges it for credentials appears to be missing — confirm.
                credentials = flow.credentials

                # NOTE(review): an "else:" appears to be missing here; the lines below look
                # like the non-remote branch, which runs a local server to receive the result.
                from google_auth_oauthlib.flow import InstalledAppFlow
                flow = InstalledAppFlow.from_client_config(self.application_secrets, scopes=scopes)
                msg = "Authentication successful. Please close this tab and run HCA CLI commands in the terminal."
                credentials = flow.run_local_server(success_message=msg, audience=self._audience)

        # TODO: (akislyuk) test token autorefresh on expiration
        # NOTE(review): the dict(...) call on the next line is cut off after its first
        # argument — restore the remaining fields from the upstream source.
        self.config.oauth2_token = dict(access_token=credentials.token,
        print("Storing access credentials")

    def _get_oauth_token_from_service_account_credentials(self):
        """Return ``(token, expiry)`` from the service account credentials file.

        The file is the one named by the ``GOOGLE_APPLICATION_CREDENTIALS`` environment
        variable, which must be set.

        :raises Exception: If the referenced file does not exist.
        """
        # NOTE(review): the single scope is an empty string in this copy; its original value
        # appears to have been stripped — confirm against the upstream source.
        scopes = [""]
        assert 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ
        from google.auth.transport.requests import Request as GoogleAuthRequest
        # NOTE(review): the next line fuses an import with the message of what looks like a
        # logging call whose beginning was lost — restore from the upstream source.
        from google.oauth2.service_account import Credentials as ServiceAccountCredentials"Found GOOGLE_APPLICATION_CREDENTIALS environment variable. "
                    "Using service account credentials for authentication.")
        service_account_credentials_filename = os.environ['GOOGLE_APPLICATION_CREDENTIALS']

        if not os.path.isfile(service_account_credentials_filename):
            msg = 'File "{}" referenced by the GOOGLE_APPLICATION_CREDENTIALS environment variable does not exist'
            raise Exception(msg.format(service_account_credentials_filename))

        # NOTE(review): the call on the next line is cut off after its opening parenthesis,
        # and ``r`` is never used afterwards (presumably a token refresh step is missing).
        credentials = ServiceAccountCredentials.from_service_account_file(
        r = GoogleAuthRequest()
        return credentials.token, credentials.expiry

    def _get_jwt_from_service_account_credentials(self):
        """Build a JWT signed with the service account's private key.

        The service account file is the one named by the ``GOOGLE_APPLICATION_CREDENTIALS``
        environment variable, which must be set.

        :returns: ``(signed_jwt, exp)`` where ``exp`` is the expiry time in seconds since
            the epoch (issue time plus ``token_expiration``).
        :raises Exception: If the referenced file does not exist.
        """
        assert 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ
        service_account_credentials_filename = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
        if not os.path.isfile(service_account_credentials_filename):
            msg = 'File "{}" referenced by the GOOGLE_APPLICATION_CREDENTIALS environment variable does not exist'
            raise Exception(msg.format(service_account_credentials_filename))
        with open(service_account_credentials_filename) as fh:
            service_credentials = json.load(fh)

        iat = time.time()
        exp = iat + self.token_expiration
        # The service account e-mail address is used as issuer, subject and e-mail claim.
        # NOTE(review): the last two keys below are empty strings in this copy (their names
        # appear to have been stripped, and as written the second overwrites the first), and
        # the closing brace of the dict is missing — restore from the upstream source.
        payload = {'iss': service_credentials["client_email"],
                   'sub': service_credentials["client_email"],
                   'aud': self._audience,
                   'iat': iat,
                   'exp': exp,
                   'email': service_credentials["client_email"],
                   'scope': ['email', 'openid', 'offline_access'],
                   '': 'hca',
                   '': service_credentials["client_email"]
        additional_headers = {'kid': service_credentials["private_key_id"]}
        # NOTE(review): the jwt.encode call on the next line is cut off after ``headers=...,``.
        signed_jwt = jwt.encode(payload, service_credentials["private_key"], headers=additional_headers,
        return signed_jwt, exp

    def expired_token(self):
        """Tell whether the current authenticated session holds an expired or nearly expired token.

        Returns False when there is no authenticated session or when the token carries no
        expiry; otherwise True when the expiry is ten seconds away or less.
        """
        session = self._authenticated_session
        if not session:
            return False
        expires_at = session.token['expires_at']
        if not expires_at:
            return False
        grace_period = 10
        return expires_at <= time.time() + grace_period

    def get_authenticated_session(self):
        """Return the OAuth2 session, (re)creating it when absent or when its token has expired.

        :raises Exception: If no ``oauth2_token`` is configured (raised in what appears to
            be the branch for when ``GOOGLE_APPLICATION_CREDENTIALS`` is not set).
        """
        if self._authenticated_session is None or self.expired_token():
            oauth2_client_data = self.application_secrets["installed"]
            if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
                # Service account path: authenticate with a self-signed JWT.
                token, expires_at = self._get_jwt_from_service_account_credentials()
                # NOTE(review): the OAuth2Session(...) call on the next line is cut off, and
                # an "else:" appears to be missing before the ``oauth2_token`` check.
                self._authenticated_session = OAuth2Session(client_id=oauth2_client_data["client_id"],
                if "oauth2_token" not in self.config:
                    msg = ('Please configure {prog} authentication credentials using "{prog} login" '
                           'or set the GOOGLE_APPLICATION_CREDENTIALS environment variable')
                    raise Exception(msg.format(prog=self.__module__.replace(".", " ")))
                # NOTE(review): the call on the next line is cut off after its opening
                # parenthesis — restore the arguments from the upstream source.
                self._authenticated_session = OAuth2Session(
            self._authenticated_session.headers.update({"User-Agent": self.__class__.__name__})
        return self._authenticated_session

    def _set_retry_policy(self, session):
        """Mount one HTTP adapter carrying this client's retry policy for http and https URLs.

        The adapter's pool size is the larger of DEFAULT_THREAD_COUNT and the requests default.
        """
        pool_size = max(DEFAULT_THREAD_COUNT, DEFAULT_POOLSIZE)
        adapter = HTTPAdapter(max_retries=self.retry_policy, pool_maxsize=pool_size)
        for prefix in ('http://', 'https://'):
            session.mount(prefix, adapter)

    def _save_auth_token_refresh_result(self, result):
        self.config.oauth2_token = result

    def _process_method_args(self, parameters, body_json_schema):
        """Collect the arguments of one API operation from its body schema and parameters.

        :param parameters: Swagger parameter definitions keyed by name.
        :param body_json_schema: JSON schema of the request body; its ``properties`` and the
            ``properties`` of each ``allOf`` member are read.
        :returns: ``(body_props, method_args)``. ``method_args`` is an ordered mapping from
            argument name to a dict holding the ``inspect.Parameter`` and related metadata.
        """
        body_props = {}
        method_args = collections.OrderedDict()

        def _parse_properties(properties, schema):
            for prop_name, prop_data in properties.items():
                enum_values = prop_data.get("enum")
                # Properties that declare an enum are always annotated as strings.
                type_ = prop_data.get("type") if enum_values is None else 'string'
                anno = self._type_map[type_]
                if prop_name not in body_json_schema.get("required", []):
                    anno = typing.Optional[anno]
                # NOTE(review): the Parameter(...) call on the next line is cut off after
                # ``default=...,`` and the start of the update(...) keyword list that follows
                # may be incomplete as well — restore from the upstream source.
                param = Parameter(prop_name, Parameter.POSITIONAL_OR_KEYWORD, default=prop_data.get("default"),
                method_args.setdefault(prop_name, {}).update(param=param,
                                                             required=prop_name in body_json_schema.get("required", []))
                # NOTE(review): the lookup below uses the literal key 'prop_name', not the
                # variable prop_name, so it always yields the default {} — confirm intent.
                body_props[prop_name] = _merge_dict(schema, body_props.get('prop_name', {}))

        if body_json_schema.get('properties', {}):
            _parse_properties(body_json_schema["properties"], body_json_schema)
        for schema in body_json_schema.get('allOf', []):
            _parse_properties(schema.get('properties', {}), schema)

        # Non-body parameters are annotated as str, or Optional[str] when not required.
        for parameter in parameters.values():
            annotation = str if parameter.get("required") else typing.Optional[str]
            # NOTE(review): the Parameter(...) call on the next line is cut off after
            # ``default=...,``; ``annotation`` computed above is presumably its last argument.
            param = Parameter(parameter["name"], Parameter.POSITIONAL_OR_KEYWORD, default=parameter.get("default"),
            method_args[parameter["name"]] = dict(param=param, doc=parameter.get("description"),
                                                  choices=parameter.get("enum"), required=parameter.get("required"))
        return body_props, method_args

    def _build_method_name(http_method, http_path):
        method_name = http_path.replace('/.well-known', '').replace('-', '_')
        method_name_parts = [http_method] + [p for p in method_name.split("/")[1:] if not p.startswith("{")]
        method_name = "_".join(method_name_parts)
        if method_name.endswith("s") and (http_method.upper() in {"POST", "PUT"} or http_path.endswith("/{uuid}")):
            method_name = method_name[:-1]
        return method_name

    def _build_client_method(self, http_method, http_path, method_data):
        """Create the client method for one path and HTTP method and attach it to the class.

        The method is registered in ``self.http_paths`` and ``self.methods``, given a
        generated signature and docstring, and set as an attribute on ``self.__class__``.

        :param http_method: HTTP verb of the operation.
        :param http_path: Path template of the operation.
        :param method_data: Swagger operation object.
        """
        method_name = self._build_method_name(http_method, http_path)
        parameters = {p["name"]: p for p in method_data.get("parameters", [])}
        body_json_schema = {"properties": {}}
        # NOTE(review): an "else:" appears to be missing before the ``for`` loop below. The
        # loop also pops from ``parameters`` while iterating over it, which raises unless
        # the loop ends right after the pop (presumably a "break" was lost) — confirm.
        if "requestBody" in method_data and "application/json" in method_data["requestBody"]["content"]:
            body_json_schema = method_data["requestBody"]["content"]["application/json"]["schema"]
            for p in parameters:
                if parameters[p]["in"] == "body":
                    body_json_schema = parameters.pop(p)["schema"]

        # Several paths can map to the same method name; they are told apart by the set of
        # path parameters each one takes.
        path_parameters = [p_name for p_name, p_data in parameters.items() if p_data["in"] == "path"]
        self.http_paths[method_name][frozenset(path_parameters)] = http_path

        body_props, method_args = self._process_method_args(parameters=parameters, body_json_schema=body_json_schema)

        # NOTE(review): the argument of str(...) on the next two lines is missing (presumably
        # the response status codes that mark pagination and streaming support) — restore
        # from the upstream source.
        method_supports_pagination = True if str( in method_data["responses"] else False
        highlight_streaming_support = True if str( in method_data["responses"] else False

        factory = _PaginatingClientMethodFactory if method_supports_pagination else _ClientMethodFactory
        client_method = factory(self, parameters, path_parameters, http_method, method_name, method_data, body_props)
        client_method.__name__ = method_name
        client_method.__qualname__ = self.__class__.__name__ + "." + method_name

        # The advertised signature starts with the two implicit arguments, followed by every
        # API argument whose name does not start with an underscore.
        params = [Parameter("factory", Parameter.POSITIONAL_OR_KEYWORD),
                  Parameter("client", Parameter.POSITIONAL_OR_KEYWORD)]
        params += [v["param"] for k, v in method_args.items() if not k.startswith("_")]
        client_method.__signature__ = signature(client_method).replace(parameters=params)
        docstring = method_data.get("summary", '') + "\n\n"

        if method_supports_pagination:
            docstring += _pagination_docstring.format(client_name=self.__class__.__name__, method_name=method_name)

        if highlight_streaming_support:
            docstring += _streaming_docstring.format(client_name=self.__class__.__name__, method_name=method_name)

        for param in method_args:
            if not param.startswith("_"):
                param_doc = _md2rst(method_args[param]["doc"] or "")
                docstring += ":param {}: {}\n".format(param, param_doc.replace("\n", " "))
                docstring += ":type {}: {}\n".format(param, method_args[param]["param"].annotation)
        docstring += "\n\n" + _md2rst(method_data.get("description", ''))
        client_method.__doc__ = docstring

        # The factory instance is bound with the SwaggerClient class as its first argument
        # and stored on the concrete client class under the generated name.
        setattr(self.__class__, method_name, types.MethodType(client_method, SwaggerClient))
        self.methods[method_name] = dict(method_data, entry_point=getattr(self, method_name)._cli_call,
                                         signature=client_method.__signature__, args=method_args)

    def _command_arg_forwarder_factory(self, command, command_sig):
        def arg_forwarder(parsed_args):
            command_args = {k: v for k, v in vars(parsed_args).items() if k in command_sig.parameters}
            return command(**command_args)
        return arg_forwarder

    def _get_command_arg_settings(self, param_data):
        if param_data.default is Parameter.empty:
            return dict(required=True)
        elif param_data.default is True:
            return dict(action='store_false', default=True)
        elif param_data.default is False:
            return dict(action='store_true', default=False)
        elif isinstance(param_data.default, (list, tuple)):
            return dict(nargs="+", default=param_data.default)
            return dict(type=type(param_data.default), default=param_data.default)

    def _get_param_argparse_type(self, anno):
        if anno in {typing.List, typing.Mapping, typing.Union[typing.Mapping, None]}:
            return json.loads
        elif isinstance(getattr(anno, "__args__", None), tuple) and anno == typing.Optional[anno.__args__[0]]:
            return anno.__args__[0]
        return anno

    def build_argparse_subparsers(self, subparsers, help_menu=False):
        """Register one CLI subcommand per generated API method and per command.

        :param subparsers: Subparsers object of the parent ``argparse.ArgumentParser``.
        :param help_menu: When true, required options are placed in a separate
            "Required Arguments" group.
        :raises SwaggerClientInternalError: If a command has no docstring.
        """
        for method_name, method_data in self.methods.items():
            subcommand_name = method_name.replace("_", "-")
            # NOTE(review): the add_parser(...) call on the next line is cut off after its
            # first argument — restore from the upstream source.
            subparser = subparsers.add_parser(subcommand_name,
            if help_menu:
                required_group_parser = subparser.add_argument_group('Required Arguments')
            for param_name, param in method_data["signature"].parameters.items():
                # NOTE(review): the body of the next "if" is missing (presumably it skips the
                # two implicit parameters) — confirm against the upstream source.
                if param_name in {"client", "factory"}:
                logger.debug("Registering %s %s %s", method_name, param_name, param.annotation)
                nargs = "+" if param.annotation == typing.List else None
                # NOTE(review): once ``subparser`` has been rebound to the required group
                # below, later optional parameters are added to that group as well — confirm
                # that this is intended.
                if help_menu:
                    subparser = required_group_parser if method_data["args"][param_name]["required"] else subparser
                # NOTE(review): the add_argument(...) call on the next line is cut off after
                # the option name.
                subparser.add_argument("--" + param_name.replace("_", "-").replace("/", "-"),
            # NOTE(review): the argument of str(...) on the next line is missing (presumably
            # the response status code that marks a paginated method).
            if str( in method_data["responses"]:
                subparser.add_argument("--no-paginate", action="store_false", dest="paginate",
                                       help='Do not automatically page the responses', default=True)

        for command in self.commands:
            sig = signature(command)
            if not getattr(command, "__doc__", None):
                raise SwaggerClientInternalError("Command {} has no docstring".format(command))
            # Command docstrings are templates: {prog} is replaced with the program prefix.
            docstring = command.__doc__.format(prog=subparsers._prog_prefix)
            method_args = _parse_docstring(docstring)
            # NOTE(review): the add_parser(...) call on the next line is cut off after its
            # first argument.
            command_subparser = subparsers.add_parser(command.__name__.replace("_", "-"),
            if help_menu:
                required_group_parser = command_subparser.add_argument_group('Required Arguments')
            for param_name, param_data in sig.parameters.items():
                params = self._get_command_arg_settings(param_data)
                if help_menu:
                    command_subparser = required_group_parser if params.get('required', False) else command_subparser
                # NOTE(review): the add_argument(...) call below is cut off after ``help=...,``
                # (presumably ``params`` is expanded as its remaining keyword arguments).
                command_subparser.add_argument("--" + param_name.replace("_", "-"),
                                               help=method_args['params'].get(param_name, None),
            command_subparser.set_defaults(entry_point=self._command_arg_forwarder_factory(command, sig))

def _merge_dict(source, destination):
    """Recursive dict merge"""
    for key, value in source.items():
        if isinstance(value, dict):
            node = destination.setdefault(key, {})
            _merge_dict(value, node)
            destination[key] = value
    return destination