Source code for gen3.auth

import base64
import hashlib
import json
from requests.auth import AuthBase
import os
import random
import requests
import time
from cdislogging import get_logger

from urllib.parse import urlparse
import backoff

from gen3.utils import (
    DEFAULT_BACKOFF_SETTINGS,
    raise_for_status_and_print_error,
    remove_trailing_whitespace_and_slashes_in_url,
)

logging = get_logger("__name__")


class Gen3AuthError(Exception):
    pass


def decode_token(token_str):
    """
    jq -r '.api_key' < ~/.gen3/qa-covid19.planx-pla.net.json | awk -F . '{ print $2 }' | base64 --decode | jq -r .
    """
    tokenParts = token_str.split(".")
    if len(tokenParts) < 3:
        raise Exception("Invalid JWT. Could not split into parts.")
    padding = "===="
    infoStr = tokenParts[1] + padding[0 : len(tokenParts[1]) % 4]
    jsonStr = base64.urlsafe_b64decode(infoStr)
    return json.loads(jsonStr)


def endpoint_from_token(token_str):
    """
    Extract the endpoint from a JWT issue ("iss" property)
    """
    info = decode_token(token_str)
    urlparts = urlparse(info["iss"])
    endpoint = urlparts.scheme + "://" + urlparts.hostname
    if urlparts.port:
        endpoint += ":" + str(urlparts.port)
    return remove_trailing_whitespace_and_slashes_in_url(endpoint)


def _handle_access_token_response(resp, token_key):
    """
    Shared helper for both get_access_token_with_key and get_access_token_from_wts
    """
    err_msg = "Failed to get an access token from {}:\n{}"
    if resp.status_code != 200:
        raise Gen3AuthError(err_msg.format(resp.url, resp.text))
    try:
        json_resp = resp.json()
        return json_resp[token_key]
    except ValueError:  # cannot parse JSON
        raise Gen3AuthError(err_msg.format(resp.url, resp.text))
    except KeyError:  # no access_token in JSON response
        raise Gen3AuthError(err_msg.format(resp.url, json_resp))


def get_access_token_with_key(api_key):
    """
    Try to fetch an access token given the api key
    """
    endpoint = endpoint_from_token(api_key["api_key"])
    # attempt to get a token from Fence
    auth_url = "{}/user/credentials/cdis/access_token".format(endpoint)
    resp = requests.post(auth_url, json=api_key)
    token_key = "access_token"
    return _handle_access_token_response(resp, token_key)


def get_access_token_with_client_credentials(endpoint, client_credentials, scopes):
    """
    Try to get an access token from Fence using client credentials

    Args:
        endpoint (str): URL of the Gen3 instance to get an access token for
        client_credentials ((str, str) tuple): (client ID, client secret) tuple
        scopes (str): space-delimited list of scopes to request
    """
    if not endpoint:
        raise ValueError("'endpoint' must be specified when using client credentials")
    url = f"{endpoint}/user/oauth2/token?grant_type=client_credentials&scope={scopes}"
    resp = requests.post(url, auth=client_credentials)
    return _handle_access_token_response(resp, "access_token")


def get_wts_endpoint(namespace=os.getenv("NAMESPACE", "default")):
    return "http://workspace-token-service.{}.svc.cluster.local".format(namespace)


def get_wts_idps(namespace=os.getenv("NAMESPACE", "default"), external_wts_host=None):
    wts_url = None
    if external_wts_host == None:
        wts_url = get_wts_endpoint(namespace)
    else:
        wts_url = external_wts_host
    url = wts_url.rstrip("/") + "/external_oidc/"
    resp = requests.get(url)
    raise_for_status_and_print_error(resp)
    return resp.json()


def get_token_cache_file_name(key):
    """Compute the path to the access-token cache file"""
    cache_folder = "{}/.cache/gen3/".format(os.path.expanduser("~"))
    if not os.path.isdir(cache_folder):
        os.makedirs(cache_folder)

    cache_prefix = cache_folder + "token_cache_"
    s = hashlib.sha256()
    s.update(key.encode("utf-8"))
    return cache_prefix + s.hexdigest()


[docs] class Gen3Auth(AuthBase): """Gen3 auth helper class for use with requests auth. Implements requests.auth.AuthBase in order to support JWT authentication. Generates access tokens from the provided refresh token file or string. Automatically refreshes access tokens when they expire. Args: refresh_file (str, opt): The file containing the downloaded JSON web token. Optional if working in a Gen3 Workspace. Defaults to (env["GEN3_API_KEY"] || "credentials") if refresh_token and idp not set. Includes ~/.gen3/ in search path if value does not include /. Interprets "idp://wts/<idp>" as an idp. Interprets "accesstoken:///<token>" as an access token refresh_token (str, opt): The JSON web token. Optional if working in a Gen3 Workspace. idp (str, opt): If working in a Gen3 Workspace, the IDP to use can be specified - "local" indicates the local environment fence idp client_credentials (tuple, opt): The (client_id, client_secret) credentials for an OIDC client that has the 'client_credentials' grant, allowing it to obtain access tokens. client_scopes (str, opt): Space-separated list of scopes requested for access tokens obtained from client credentials. Default: "user data openid" access_token (str, opt): provide an access token to override the use of any API key/refresh token. This is intended for cases where you may want to pass a token that was issued to a particular OIDC client (rather than acting on your own direct behalf, you can provide a key that was client-issued to your user) Examples: This generates the Gen3Auth class pointed at the sandbox commons while using the credentials.json downloaded from the commons profile page and installed in ~/.gen3/credentials.json >>> auth = Gen3Auth() or use ~/.gen3/crdc.json: >>> auth = Gen3Auth(refresh_file="crdc") or use some arbitrary file: >>> auth = Gen3Auth(refresh_file="./key.json") or set the GEN3_API_KEY environment variable rather than pass the refresh_file argument to the Gen3Auth constructor. If working with an OIDC client that has the 'client_credentials' grant, allowing it to obtain access tokens, provide the client ID and secret: Note: client secrets should never be hardcoded! >>> auth = Gen3Auth( endpoint="https://datacommons.example", client_credentials=("client ID", os.environ["GEN3_OIDC_CLIENT_CREDS_SECRET"]) ) If working in a Gen3 Workspace, initialize as follows: >>> auth = Gen3Auth() """ def __init__( self, endpoint=None, refresh_file=None, refresh_token=None, idp=None, client_credentials=None, client_scopes=None, access_token=None, ): logging.debug("Initializing auth..") self.endpoint = remove_trailing_whitespace_and_slashes_in_url(endpoint) # note - `_refresh_token` is not actually a JWT refresh token - it's a # gen3 api key with a token as the "api_key" property self._refresh_token = refresh_token self._access_token = access_token self._access_token_info = None self._wts_idp = idp or "local" self._wts_namespace = os.environ.get("NAMESPACE", "default") self._use_wts = False self._external_wts_host = None self._refresh_file = refresh_file self._client_credentials = client_credentials if self._client_credentials: self._client_scopes = client_scopes or "user data openid" elif client_scopes: raise ValueError( "'client_scopes' cannot be specified without 'client_credentials'" ) if refresh_file and refresh_token: raise ValueError( "Only one of 'refresh_file' and 'refresh_token' can be specified." ) if endpoint and idp: raise ValueError("Only one of 'endpoint' and 'idp' can be specified.") if not refresh_file and not refresh_token and not idp: refresh_file = os.getenv("GEN3_API_KEY", "credentials") if refresh_file and not idp: idp_prefix = "idp://wts/" access_token_prefix = "accesstoken:///" if refresh_file[0 : len(idp_prefix)] == idp_prefix: idp = refresh_file[len(idp_prefix) :] refresh_file = None elif refresh_file[0 : len(access_token_prefix)] == access_token_prefix: self._access_token = refresh_file[len(access_token_prefix) :] self._access_token_info = decode_token(self._access_token) refresh_file = None elif ( not os.path.isfile(refresh_file) and "/" not in refresh_file and "\\" not in refresh_file ): refresh_file = "{}/.gen3/{}".format( os.path.expanduser("~"), refresh_file ) if not os.path.isfile(refresh_file) and refresh_file[-5:] != ".json": refresh_file += ".json" if not os.path.isfile(refresh_file): logging.warning("Unable to find refresh_file") refresh_file = None if self._client_credentials: if not endpoint: raise ValueError( "'endpoint' must be specified when '_client_credentials' is specified" ) self._access_token = get_access_token_with_client_credentials( endpoint, self._client_credentials, self._client_scopes ) if not self._access_token: # at this point - refresh_file either exists or is None if not refresh_file and not refresh_token: # check if this is a Gen3 workspace environment # most production environments are in the "default" namespace # attempt to get a token from the workspace-token-service self._use_wts = True # hate calling a method from the constructor, but avoids copying code self.get_access_token() elif refresh_file: try: with open(refresh_file) as f: file_data = f.read() self._refresh_token = json.loads(file_data) except Exception as e: raise ValueError( "Couldn't load your refresh token file: {}\n{}".format( refresh_file, str(e) ) ) assert "api_key" in self._refresh_token # if both endpoint and refresh file are provided, compare endpoint with iss in refresh file # May need to use network wts endpoint if idp or ( endpoint and ( not endpoint.rstrip("/") == endpoint_from_token(self._refresh_token["api_key"]) ) ): try: logging.debug( "Switch to using WTS and set external WTS host url.." ) self._use_wts = True self._external_wts_host = ( endpoint_from_token(self._refresh_token["api_key"]) + "/wts/" ) self.get_access_token() except Gen3AuthError as g: logging.warning( "Could not obtain access token from WTS service." ) raise g if not self.endpoint: if self._access_token: self.endpoint = endpoint_from_token(self._access_token) else: self.endpoint = endpoint_from_token(self._refresh_token["api_key"]) @property def _token_info(self): """ Wrapper to fix intermittent errors when the token is being refreshed and `_access_token_info` == None """ if not self._access_token_info: self.refresh_access_token() return self._access_token_info def __call__(self, request): """Adds authorization header to the request This gets called by the python.requests package on outbound requests so that authentication can be added. Args: request (object): The incoming request object """ request.headers["Authorization"] = self._get_auth_value() request.register_hook("response", self._handle_401) return request def _handle_401(self, response, **kwargs): """Handles failed requests when authorization failed. This gets called after a failed request when an HTTP 401 error occurs. This then tries to refresh the access token in the event that it expired. Args: request (object): The failed request object """ if not response.status_code == 401 and not response.status_code == 403: return response # Free the original connection response.content response.close() # copy the request to resend newreq = response.request.copy() newreq.headers["Authorization"] = self._get_auth_value() _response = response.connection.send(newreq, **kwargs) _response.history.append(response) _response.request = newreq return _response
[docs] def refresh_access_token(self, endpoint=None): """Get a new access token""" if self._use_wts: self._access_token = self.get_access_token_from_wts(endpoint) elif self._client_credentials: self._access_token = get_access_token_with_client_credentials( endpoint, self._client_credentials, self._client_scopes ) elif self._refresh_token: self._access_token = get_access_token_with_key(self._refresh_token) else: logging.warning( f"Unable to refresh access token. " f"Authorized API calls will stop working when this token expires." ) self._access_token_info = decode_token(self._access_token) cache_file = None if self._use_wts: cache_file = get_token_cache_file_name(self._wts_idp) elif self._refresh_file: cache_file = get_token_cache_file_name(self._refresh_token["api_key"]) if cache_file: try: self._write_to_file(cache_file, self._access_token) except Exception as e: logging.warning( f"Unable to write access token to cache file. Exceeded number of retries. Details: {e}" ) return self._access_token
@backoff.on_exception( wait_gen=backoff.expo, exception=Exception, **DEFAULT_BACKOFF_SETTINGS ) def _write_to_file(self, cache_file, content): # write a temp file, then rename - to avoid # simultaneous writes to same file race condition temp = cache_file + ( ".tmp_eraseme_%d_%d" % (random.randrange(100000), time.time()) ) try: with open(temp, "w") as f: f.write(content) os.rename(temp, cache_file) return True except Exception as e: logging.warning("failed to write token cache file: " + cache_file) logging.warning(str(e)) raise e
[docs] def get_access_token(self): """Get the access token - auto refresh if within 5 minutes of expiration""" if not self._access_token: if self._use_wts == True: cache_file = get_token_cache_file_name(self._wts_idp) else: if self._refresh_token: cache_file = get_token_cache_file_name( self._refresh_token["api_key"] ) if cache_file and os.path.isfile(cache_file): try: # don't freak out on invalid cache with open(cache_file) as f: self._access_token = f.read() self._access_token_info = decode_token(self._access_token) except Exception as e: logging.warning("ignoring invalid token cache: " + cache_file) self._access_token = None self._access_token_info = None logging.warning(str(e)) need_new_token = ( not self._access_token or not self._access_token_info or time.time() + 300 > self._access_token_info["exp"] ) if need_new_token: return self.refresh_access_token( self.endpoint if hasattr(self, "endpoint") else None ) # use cache return self._access_token
def _get_auth_value(self): """Returns the Authorization header value for the request This gets called when added the Authorization header to the request. This fetches the access token from the refresh token if the access token is missing. """ return "bearer " + self.get_access_token()
[docs] def curl(self, path, request=None, data=None): """ Curl the given endpoint - ex: gen3 curl /user/user. Return requests.Response Args: path (str): path under the commons to curl (/user/user, /index/index, /authz/mapping, ...) request (str in GET|POST|PUT|DELETE): default to GET if data is not set, else default to POST data (str): json string or "@filename" of a json file """ if not request: request = "GET" if data: request = "POST" json_data = data output = None if data and data[0] == "@": with open(data[1:]) as f: json_data = f.read() if request == "GET": output = requests.get(self.endpoint + "/" + path, auth=self) elif request == "POST": output = requests.post( self.endpoint + "/" + path, json=json_data, auth=self ) elif request == "PUT": output = requests.put(self.endpoint + "/" + path, json=json_data, auth=self) elif request == "DELETE": output = requests.delete(self.endpoint + "/" + path, auth=self) else: raise Exception("Invalid request type: " + request) return output
[docs] def get_access_token_from_wts(self, endpoint=None): """ Try to fetch an access token for the given idp from the wts in the given namespace. If idp is not set, then default to "local" """ # attempt to get a token from the workspace-token-service logging.debug("getting access token from wts..") auth_url = get_wts_endpoint(self._wts_namespace) + "/token/" # If non "local" idp value exists, append to auth url # If user specified endpoint value, then first attempt to determine idp value. if self.endpoint or (self._wts_idp and self._wts_idp != "local"): # If user supplied endpoint value and not idp, figure out the idp value if self.endpoint: logging.debug( "First try to use the local WTS to figure out idp name for the supplied endpoint.." ) try: provider_List = get_wts_idps(self._wts_namespace) matchProviders = list( filter( lambda provider: provider["base_url"] == endpoint, provider_List["providers"], ) ) if len(matchProviders) == 1: logging.debug("Found matching idp from local WTS.") self._wts_idp = matchProviders[0]["idp"] elif len(matchProviders) > 1: raise ValueError( "Multiple idps matched with endpoint value provided." ) else: logging.debug("Could not find matching idp from local WTS.") except Exception as e: logging.debug( "Exception occured when making network call to local WTS." ) if not self._external_wts_host: raise e else: logging.debug("Since external WTS host exists, continuing on..") pass if self._wts_idp and self._wts_idp != "local": auth_url += "?idp={}".format(self._wts_idp) # If endpoint value exists, only get WTS token if idp value has been successfully determined # Otherwise skip to querying external WTS # This is to prevent local WTS from supplying an incorrect token to user if ( not self._external_wts_host or not self.endpoint or (self.endpoint and self._wts_idp != "local") ): try: logging.debug("Try to get access token from local WTS..") logging.debug(f"{auth_url=}") resp = requests.get(auth_url) if (resp and resp.status_code == 200) or (not self._external_wts_host): return _handle_access_token_response(resp, "token") except Exception as e: if not self._external_wts_host: raise e else: # Try to obtain token from external wts logging.debug("Could get obtain token from Local WTS.") pass # local workspace wts call failed, try using a network call # First get access token with WTS host logging.debug("Trying to get access token from external WTS Host..") wts_token = get_access_token_with_key(self._refresh_token) auth_url = self._external_wts_host + "token/" provider_List = get_wts_idps(self._wts_namespace, self._external_wts_host) # if user already supplied idp, use that if self._wts_idp and self._wts_idp != "local": matchProviders = list( filter( lambda provider: provider["idp"] == self._wts_idp, provider_List["providers"], ) ) elif endpoint: matchProviders = list( filter( lambda provider: provider["base_url"] == endpoint, provider_List["providers"], ) ) else: raise Exception( "Unable to generate matching identity providers (no IdP or endpoint provided)" ) if len(matchProviders) == 1: self._wts_idp = matchProviders[0]["idp"] logging.debug("Succesfully determined idp value: {}".format(self._wts_idp)) else: idp_list = "\n " if len(matchProviders) > 1: for idp in matchProviders: idp_list = ( idp_list + "idp name: " + idp["idp"] + " url: " + idp["base_url"] + "\n " ) raise ValueError( "Multiple idps matched with endpoint value provided." + idp_list + "Query /wts/external_oidc/ for more information." ) else: for idp in provider_List["providers"]: idp_list = ( idp_list + "idp name: " + idp["idp"] + " Endpoint url: " + idp["base_url"] + "\n " ) raise ValueError( "No idp matched with the endpoint or idp value provided.\n" + "Please make sure your endpoint or idp value matches exactly with the output below.\n" + "i.e. check trailing '/' character for the endpoint url\n" + "Available Idps:" + idp_list + "Query /wts/external_oidc/ for more information." ) logging.debug("Finally getting access token..") auth_url += "?idp={}".format(self._wts_idp) header = {"Authorization": "Bearer " + wts_token} resp = requests.get(auth_url, headers=header) err_msg = "Please make sure the target commons is connected on your profile page and that connection has not expired." if resp.status_code != 200: logging.warning(err_msg) return _handle_access_token_response(resp, "token")