import urllib
import requests
import json
import logging
logger = logging.getLogger(__name__)
import webbrowser
from .session_config import patch_session
import numpy as np
import datetime
import pandas as pd
[docs]class BaseEncoder(json.JSONEncoder):
[docs] def default(self, obj):
if isinstance(obj, np.ndarray) or isinstance(obj, pd.Series):
return obj.tolist()
if isinstance(obj, set):
return list(obj)
if isinstance(obj, np.uint64):
return int(obj)
if isinstance(obj, np.int64):
return int(obj)
if isinstance(obj, (datetime.datetime, datetime.date)):
return obj.isoformat()
return json.JSONEncoder.default(self, obj)
[docs]class AuthException(Exception):
pass
def _raise_for_status(r, log_warning=True):
http_error_msg = ""
if isinstance(r.reason, bytes):
# We attempt to decode utf-8 first because some servers
# choose to localize their reason strings. If the string
# isn't utf-8, we fall back to iso-8859-1 for all other
# encodings. (See PR #3538)
try:
reason = r.reason.decode("utf-8")
except UnicodeDecodeError:
reason = r.reason.decode("iso-8859-1")
else:
reason = r.reason
if 400 <= r.status_code < 500:
http_error_msg = "%s Client Error: %s for url: %s content: %s" % (
r.status_code,
reason,
r.url,
r.content,
)
json_data = None
if r.headers.get("content-type") == "application/json":
json_data = r.json()
if r.status_code == 403:
if json_data:
if "error" in json_data.keys():
if json_data["error"] == "missing_tos":
webbrowser.open(json_data["data"]["tos_form_url"])
elif 500 <= r.status_code < 600:
try:
d = json.loads(r.content)
reason = d.get("message", reason)
except json.decoder.JSONDecodeError:
pass
http_error_msg = "%s Server Error: %s for url: %s content:%s" % (
r.status_code,
reason,
r.url,
r.content,
)
if http_error_msg:
raise requests.HTTPError(http_error_msg, response=r)
if log_warning:
warning = r.headers.get("Warning")
if warning:
logger.warning(warning)
[docs]def handle_response(response, as_json=True, log_warning=True):
"""Deal with potential errors in endpoint response and return json for default case"""
_raise_for_status(response, log_warning=log_warning)
_check_authorization_redirect(response)
if as_json:
return response.json()
else:
return response
def _check_authorization_redirect(response):
if len(response.history) == 0:
pass
else:
first_url = response.history[0].url
urlp = urllib.parse.urlparse(first_url)
raise AuthException(
f"""You have not setup a token to access
{first_url}
with the current auth configuration.\n
Read the documentation at
https://caveclient.readthedocs.io/en/latest/guide/authentication.html
or follow instructions under
client.auth.get_new_token() for how to set a valid API token.
after initializing a global client with
client=CAVEclient(server_address="{urlp.scheme +"://"+ urlp.netloc}")"""
)
def _api_versions(server_name, server_address, endpoints_common, auth_header, verify=True):
"""Asks a server what API versions are available, if possible"""
url_mapping = {server_name: server_address}
url_base = endpoints_common.get("get_api_versions", None)
if url_base is not None:
url = url_base.format_map(url_mapping)
response = requests.get(url, headers=auth_header, verify=verify)
_raise_for_status(response)
return response.json()
else:
return None
def _api_endpoints(
api_version,
server_name,
server_address,
endpoints_common,
endpoint_versions,
auth_header,
fallback_version=None,
verify=True,
):
"Gets the latest client API version"
if api_version == "latest":
try:
avail_vs_server = _api_versions(
server_name, server_address, endpoints_common, auth_header, verify=verify
)
avail_vs_server = set(avail_vs_server)
except:
avail_vs_server = None
avail_vs_client = set(endpoint_versions.keys())
if avail_vs_server is None:
if fallback_version is None:
api_version = max(avail_vs_client)
else:
api_version = fallback_version
else:
api_version = max(avail_vs_client.intersection(avail_vs_server))
endpoints = endpoints_common.copy()
ep_to_add = endpoint_versions.get(api_version, None)
if ep_to_add is None:
raise ValueError("No corresponding API version")
endpoints.update(ep_to_add)
return endpoints, api_version
[docs]class ClientBase(object):
def __init__(
self,
server_address,
auth_header,
api_version,
endpoints,
server_name,
verify=True,
max_retries=None,
pool_maxsize=None,
pool_block=None,
over_client=None,
):
self._server_address = server_address
self._default_url_mapping = {server_name: self._server_address}
self.verify = verify
self.session = requests.Session()
patch_session(
self.session,
max_retries=max_retries,
pool_block=pool_block,
pool_maxsize=pool_maxsize,
)
self.session.verify = verify
head_val = auth_header.get("Authorization", None)
if head_val is not None:
token = head_val.split(" ")[1]
cookie_obj = requests.cookies.create_cookie(
name="middle_auth_token", value=token
)
self.session.cookies.set_cookie(cookie_obj)
self.session.headers.update(auth_header)
self._api_version = api_version
self._endpoints = endpoints
self._fc = over_client
@property
def fc(self):
return self._fc
@property
def default_url_mapping(self):
return self._default_url_mapping
@property
def server_address(self):
return self._server_address
@property
def api_version(self):
return self._api_version
[docs] @staticmethod
def raise_for_status(r, log_warning=True):
"""Raises :class:`HTTPError`, if one occurred."""
_raise_for_status(r, log_warning=log_warning)
[docs]class ClientBaseWithDataset(ClientBase):
def __init__(
self,
server_address,
auth_header,
api_version,
endpoints,
server_name,
dataset_name,
verify=True,
max_retries=None,
pool_maxsize=None,
pool_block=None,
over_client=None,
):
super(ClientBaseWithDataset, self).__init__(
server_address,
auth_header,
api_version,
endpoints,
server_name,
verify=verify,
max_retries=max_retries,
pool_maxsize=pool_maxsize,
pool_block=pool_block,
over_client=over_client,
)
self._dataset_name = dataset_name
@property
def dataset_name(self):
return self._dataset_name
[docs]class ClientBaseWithDatastack(ClientBase):
def __init__(
self,
server_address,
auth_header,
api_version,
endpoints,
server_name,
datastack_name,
verify=True,
max_retries=None,
pool_maxsize=None,
pool_block=None,
over_client=None,
):
super(ClientBaseWithDatastack, self).__init__(
server_address,
auth_header,
api_version,
endpoints,
server_name,
verify=verify,
max_retries=max_retries,
pool_maxsize=pool_maxsize,
pool_block=pool_block,
over_client=over_client,
)
self._datastack_name = datastack_name
@property
def datastack_name(self):
return self._datastack_name