"""PyChunkedgraph service python interface"""
import datetime
import json
import logging
from typing import Iterable, Tuple, Union
from urllib.parse import urlencode
import networkx as nx
import numpy as np
import pandas as pd
import pytz
from .auth import AuthClient
from .base import BaseEncoder, ClientBase, _api_endpoints, handle_response
from .endpoints import (
chunkedgraph_api_versions,
chunkedgraph_endpoints_common,
default_global_server_address,
)
SERVER_KEY = "cg_server_address"
logger = logging.getLogger(__name__)
[docs]def package_bounds(bounds):
if bounds.shape != (3, 2):
raise ValueError(
"Bounds must be a 3x2 matrix (x,y,z) x (min,max) in chunkedgraph resolution voxel units"
)
bounds_str = []
for b in bounds:
bounds_str.append("-".join(str(b2) for b2 in b))
bounds_str = "_".join(bounds_str)
return bounds_str
[docs]def package_timestamp(timestamp, name="timestamp"):
if timestamp is None:
query_d = {}
else:
if timestamp.tzinfo is None:
timestamp = pytz.UTC.localize(timestamp)
else:
timestamp = timestamp.astimezone(datetime.timezone.utc)
query_d = {name: timestamp.timestamp()}
return query_d
[docs]def package_split_data(
root_id, source_points, sink_points, source_supervoxels, sink_supervoxels
):
"""Create the data for preview or executed split operations"""
categories = ["sources", "sinks"]
pts = [source_points, sink_points]
svs = [source_supervoxels, sink_supervoxels]
for pt_list, sv_list in zip(pts, svs):
if sv_list is not None:
if len(pt_list) != len(sv_list):
raise ValueError(
"If supervoxels are provided, they must have the same length as points"
)
data = {}
for cat, pt_list, sv_list in zip(categories, pts, svs):
if sv_list is None:
sv_list = [None] * len(pt_list)
sv_list = [x if x is not None else root_id for x in sv_list]
out = []
for svid, pt in zip(sv_list, pt_list):
out.append([svid, pt[0], pt[1], pt[2]])
data[cat] = out
return data
[docs]def root_id_int_list_check(
root_id,
make_unique=False,
):
if isinstance(root_id, (int, np.uint64, np.int64)):
root_id = [root_id]
elif isinstance(root_id, str):
try:
root_id = np.uint64(root_id)
except ValueError:
raise ValueError(
"When passing a string for 'root_id' make sure the string can be converted to a uint64"
)
elif isinstance(root_id, np.ndarray) or isinstance(root_id, list):
if make_unique:
root_id = np.unique(root_id).astype(np.uint64)
else:
root_id = np.array(root_id, dtype=np.uint64)
else:
raise ValueError("root_id has to be list or uint64")
return root_id
[docs]def ChunkedGraphClient(
server_address=None,
table_name=None,
auth_client=None,
api_version="latest",
timestamp=None,
verify=True,
max_retries=None,
pool_maxsize=None,
pool_block=None,
over_client=None,
):
if server_address is None:
server_address = default_global_server_address
if auth_client is None:
auth_client = AuthClient()
auth_header = auth_client.request_header
endpoints, api_version = _api_endpoints(
api_version,
SERVER_KEY,
server_address,
chunkedgraph_endpoints_common,
chunkedgraph_api_versions,
auth_header,
verify=verify,
)
ChunkedClient = client_mapping[api_version]
return ChunkedClient(
server_address,
auth_header,
api_version,
endpoints,
SERVER_KEY,
timestamp=timestamp,
table_name=table_name,
verify=verify,
max_retries=max_retries,
pool_maxsize=pool_maxsize,
pool_block=pool_block,
over_client=over_client,
)
[docs]class ChunkedGraphClientV1(ClientBase):
"""ChunkedGraph Client for the v1 API"""
def __init__(
self,
server_address,
auth_header,
api_version,
endpoints,
server_key=SERVER_KEY,
timestamp=None,
table_name=None,
verify=True,
max_retries=None,
pool_maxsize=None,
pool_block=None,
over_client=None,
):
super(ChunkedGraphClientV1, self).__init__(
server_address,
auth_header,
api_version,
endpoints,
server_key,
verify=verify,
max_retries=max_retries,
pool_maxsize=pool_maxsize,
pool_block=pool_block,
over_client=over_client,
)
self._default_url_mapping["table_id"] = table_name
self._default_timestamp = timestamp
self._table_name = table_name
self._segmentation_info = None
@property
def default_url_mapping(self):
return self._default_url_mapping.copy()
@property
def table_name(self):
return self._table_name
def _process_timestamp(self, timestamp):
"""Process timestamp with default logic"""
if timestamp is None:
if self._default_timestamp is not None:
return self._default_timestamp
else:
return datetime.datetime.now(datetime.timezone.utc)
else:
return timestamp
[docs] def get_roots(self, supervoxel_ids, timestamp=None, stop_layer=None) -> np.ndarray:
"""Get the root ID for a list of supervoxels.
Parameters
----------
supervoxel_ids : list or np.array of int
Supervoxel IDs to look up.
timestamp : datetime.datetime, optional
UTC datetime to specify the state of the chunkedgraph at which to query, by
default None. If None, uses the current time.
stop_layer : int or None, optional
If True, looks up IDs only up to a given stop layer. Default is None.
Returns
-------
np.array of np.uint64
Root IDs containing each supervoxel.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["get_roots"].format_map(endpoint_mapping)
query_d = package_timestamp(self._process_timestamp(timestamp))
if stop_layer is not None:
query_d["stop_layer"] = stop_layer
data = np.array(supervoxel_ids, dtype=np.uint64).tobytes()
response = self.session.post(url, data=data, params=query_d)
handle_response(response, as_json=False)
return np.frombuffer(response.content, dtype=np.uint64)
[docs] def get_root_id(self, supervoxel_id, timestamp=None, level2=False) -> np.int64:
"""Get the root ID for a specified supervoxel.
Parameters
----------
supervoxel_id : int
Supervoxel id value
timestamp : datetime.datetime, optional
UTC datetime to specify the state of the chunkedgraph at which to query, by
default None. If None, uses the current time.
Returns
-------
np.int64
Root ID containing the supervoxel.
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["supervoxel_id"] = supervoxel_id
url = self._endpoints["handle_root"].format_map(endpoint_mapping)
query_d = package_timestamp(self._process_timestamp(timestamp))
if level2:
query_d["stop_layer"] = 2
response = self.session.get(url, params=query_d)
return np.int64(handle_response(response, as_json=True)["root_id"])
[docs] def get_merge_log(self, root_id) -> list:
"""Get the merge log (splits and merges) for an object.
Parameters
----------
root_id : int
Object root ID to look up.
Returns
-------
list
List of merge events in the history of the object.
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_id"] = root_id
url = self._endpoints["merge_log"].format_map(endpoint_mapping)
response = self.session.get(url)
return handle_response(response)
[docs] def get_change_log(self, root_id, filtered=True) -> dict:
"""Get the change log (splits and merges) for an object.
Parameters
----------
root_id : int
Object root ID to look up.
filtered : bool
Whether to filter the change log to only include splits and merges which
affect the final state of the object (`filtered=True`), as opposed to
including edit history for objects which as some point were split from
the query object `root_id` (`filtered=False`). Defaults to True.
Returns
-------
dict
Dictionary summarizing split and merge events in the object history,
containing the following keys:
"n_merges": int
Number of merges
"n_splits": int
Number of splits
"operations_ids": list of int
Identifiers for each operation
"past_ids": list of int
Previous root ids for this object
"user_info": dict of dict
Dictionary keyed by user (string) to a dictionary specifying how many
merges and splits that user performed on this object
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_id"] = root_id
url = self._endpoints["change_log"].format_map(endpoint_mapping)
params = {"filtered": filtered}
response = self.session.get(url, params=params)
return handle_response(response)
[docs] def get_user_operations(
self,
user_id: int,
timestamp_start: datetime.datetime,
include_undo: bool = True,
timestamp_end: datetime.datetime = None,
) -> pd.DataFrame:
"""
Get operation details for a user ID. Currently, this is only available to
admins.
Parameters
----------
user_id : int
User ID to query (use 0 for all users (admin only)).
timestamp_start : datetime.datetime, optional
Timestamp to start filter (UTC).
include_undo : bool, optional
Whether to include undos. Defaults to True.
timestamp_end : datetime.datetime, optional
Timestamp to end filter (UTC). Defaults to now.
Returns
-------
pd.DataFrame
DataFrame including the following columns:
"operation_id": int
Identifier for the operation.
"timestamp": datetime.datetime
Timestamp of the operation.
"user_id": int
User who performed the operation.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["user_operations"].format_map(endpoint_mapping)
params = {"include_undo": include_undo}
if user_id > 0:
params = {"user_id": user_id}
if timestamp_start is not None:
params.update(
package_timestamp(
self._process_timestamp(timestamp_start), "start_time"
)
)
if timestamp_end is not None:
params.update(
package_timestamp(self._process_timestamp(timestamp_end), "end_time")
)
response = self.session.get(url, params=params)
d = handle_response(response)
df = pd.DataFrame(d)
df["timestamp"] = df["timestamp"].map(
lambda x: datetime.datetime.fromtimestamp(x / 1000, pytz.UTC)
)
return df
[docs] def get_tabular_change_log(self, root_ids, filtered=True) -> dict:
"""Get a detailed changelog for neurons.
Parameters
----------
root_ids : list of int
Object root IDs to look up.
filtered : bool
Whether to filter the change log to only include splits and merges which
affect the final state of the object (`filtered=True`), as opposed to
including edit history for objects which as some point were split from
the query objects in `root_ids` (`filtered=False`). Defaults to True.
Returns
-------
dict of pd.DataFrame
The keys are the root IDs, and the values are DataFrames with the
following columns and datatypes:
"operation_id": int
Identifier for the operation.
"timestamp": int
Timestamp of the operation, provided in *milliseconds*. To convert to
datetime, use ``datetime.datetime.utcfromtimestamp(timestamp/1000)``.
"user_id": int
User who performed the operation.
"before_root_ids: list of int
Root IDs of objects that existed before the operation.
"after_root_ids: list of int
Root IDs of objects created by the operation. Note that this only
records the root id that was kept as part of the query object, so there
will only be one in this list.
"is_merge": bool
Whether the operation was a merge.
"user_name": str
Name of the user who performed the operation.
"user_affiliation": str
Affiliation of the user who performed the operation.
"""
root_ids = [int(r) for r in np.unique(root_ids)]
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_ids"] = root_ids
url = self._endpoints["tabular_change_log"].format_map(endpoint_mapping)
params = {"filtered": filtered}
data = json.dumps({"root_ids": root_ids}, cls=BaseEncoder)
response = self.session.get(url, data=data, params=params)
res_dict = handle_response(response)
changelog_dict = {}
for k in res_dict.keys():
changelog_dict[int(k)] = pd.DataFrame(json.loads(res_dict[k]))
return changelog_dict
[docs] def get_leaves(self, root_id, bounds=None, stop_layer: int = None) -> np.ndarray:
"""Get all supervoxels for a root ID.
Parameters
----------
root_id : int
Root ID to query.
bounds: np.array or None, optional
If specified, returns supervoxels within a 3x2 numpy array of bounds
``[[minx,maxx],[miny,maxy],[minz,maxz]]``. If None, finds all supervoxels.
stop_layer: int, optional
If specified, returns chunkedgraph nodes at layer `stop_layer`
default will be `stop_layer=1` (supervoxels).
Returns
-------
np.array of np.int64
Array of supervoxel IDs (or node ids if `stop_layer>1`).
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_id"] = root_id
url = self._endpoints["leaves_from_root"].format_map(endpoint_mapping)
query_d = {}
if bounds is not None:
query_d["bounds"] = package_bounds(bounds)
if stop_layer is not None:
query_d["stop_layer"] = int(stop_layer)
response = self.session.get(url, params=query_d)
return np.int64(handle_response(response)["leaf_ids"])
[docs] def do_merge(self, supervoxels, coords, resolution=(4, 4, 40)) -> None:
"""Perform a merge on the chunked graph.
Parameters
----------
supervoxels : iterable
An N-long list of supervoxels to merge.
coords : np.array
An Nx3 array of coordinates of the supervoxels in units of `resolution`.
resolution : tuple, optional
What to multiply `coords` by to get nanometers. Defaults to (4,4,40).
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["do_merge"].format_map(endpoint_mapping)
data = []
for svid, coor in zip(supervoxels, coords):
pos = np.array(coor) * resolution
row = [svid, pos[0], pos[1], pos[2]]
data.append(row)
params = {"priority": False}
response = self.session.post(
url,
data=json.dumps(data, cls=BaseEncoder),
params=params,
headers={"Content-Type": "application/json"},
)
handle_response(response)
[docs] def undo_operation(self, operation_id) -> dict:
"""Undo an operation.
Parameters
----------
operation_id : int
Operation ID to undo.
Returns
-------
dict
"""
# TODO clarify what the return is here
endpoint_mapping = self.default_url_mapping
url = self._endpoints["undo"].format_map(endpoint_mapping)
data = {"operation_id": operation_id}
params = {"priority": False}
response = self.session.post(
url,
data=json.dumps(data, cls=BaseEncoder),
params=params,
headers={"Content-Type": "application/json"},
)
r = handle_response(response)
return r
[docs] def execute_split(
self,
source_points,
sink_points,
root_id,
source_supervoxels=None,
sink_supervoxels=None,
) -> Tuple[int, list]:
"""Execute a multicut split based on points or supervoxels.
Parameters
----------
source_points : array or list
Nx3 list or array of 3d points in nm coordinates for source points (red).
sink_points : array or list
Mx3 list or array of 3d points in nm coordinates for sink points (blue).
root_id : int
Root ID of object to do split preview.
source_supervoxels : array, list or None, optional
If providing source supervoxels, an N-length array of supervoxel IDs or
Nones matched to source points. If None, treats as a full array of Nones.
By default None.
sink_supervoxels : array, list or None, optional
If providing sink supervoxels, an M-length array of supervoxel IDs or Nones
matched to source points. If None, treats as a full array of Nones.
By default None.
Returns
-------
operation_id : int
Unique ID of the split operation
new_root_ids : list of int
List of new root IDs resulting from the split operation.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["execute_split"].format_map(endpoint_mapping)
data = package_split_data(
root_id, source_points, sink_points, source_supervoxels, sink_supervoxels
)
params = {"priority": False}
response = self.session.post(
url,
data=json.dumps(data, cls=BaseEncoder),
params=params,
headers={"Content-Type": "application/json"},
)
r = handle_response(response)
return r["operation_id"], r["new_root_ids"]
[docs] def preview_split(
self,
source_points,
sink_points,
root_id,
source_supervoxels=None,
sink_supervoxels=None,
return_additional_ccs=False,
) -> Tuple[list, list, bool, list]:
"""Get supervoxel connected components from a preview multicut split.
Parameters
----------
source_points : array or list
Nx3 list or array of 3d points in nm coordinates for source points (red).
sink_points : array or list
Mx3 list or array of 3d points in nm coordinates for sink points (blue).
root_id : int
Root ID of object to do split preview.
source_supervoxels : array, list or None, optional
If providing source supervoxels, an N-length array of supervoxel IDs or
Nones matched to source points. If None, treats as a full array of Nones.
By default None.
sink_supervoxels : array, list or None, optional
If providing sink supervoxels, an M-length array of supervoxel IDs or Nones
matched to source points. If None, treats as a full array of Nones.
By default None.
return_additional_ccs : bool, optional
If True, returns any additional connected components beyond the ones with
source and sink points. In most situations, this can be ignored.
By default, False.
Returns
-------
source_connected_component : list
Supervoxel IDs in the component with the most source points.
sink_connected_component : list
Supervoxel IDs in the component with the most sink points.
successful_split : bool
True if the split worked.
other_connected_components (optional) : list of lists of int
List of lists of supervoxel IDs for any other resulting connected components.
Only returned if `return_additional_ccs` is True.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["preview_split"].format_map(endpoint_mapping)
data = package_split_data(
root_id, source_points, sink_points, source_supervoxels, sink_supervoxels
)
response = self.session.post(
url,
data=json.dumps(data, cls=BaseEncoder),
headers={"Content-Type": "application/json"},
)
r = handle_response(response)
source_cc = r["supervoxel_connected_components"][0]
sink_cc = r["supervoxel_connected_components"][1]
if len(r["supervoxel_connected_components"]) == 2:
other_ccs = []
else:
other_ccs = r["supervoxel_connected_components"][2:]
success = not r["illegal_split"]
if return_additional_ccs:
return source_cc, sink_cc, success, other_ccs
else:
return source_cc, sink_cc, success
[docs] def get_children(self, node_id) -> np.ndarray:
"""Get the children of a node in the chunked graph hierarchy.
Parameters
----------
node_id : int
Node ID to query.
Returns
-------
np.array of np.int64
IDs of child nodes.
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_id"] = node_id
url = self._endpoints["handle_children"].format_map(endpoint_mapping)
response = self.session.get(url)
return np.array(handle_response(response)["children_ids"], dtype=np.int64)
[docs] def find_path(
self, root_id, src_pt, dst_pt, precision_mode=False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Find a path between two locations on a root ID using the level 2 chunked
graph.
Parameters
----------
root_id : int
Root ID to query.
src_pt : np.array
3-element array of xyz coordinates in nm for the source point.
dst_pt : np.array
3-element array of xyz coordinates in nm for the destination point.
precision_mode : bool, optional
Whether to perform the search in precision mode. Defaults to False.
Returns
-------
centroids_list : np.array
Array of centroids along the path.
l2_path : np.array of int
Array of level 2 chunk IDs along the path.
failed_l2_ids : np.array of int
Array of level 2 chunk IDs that failed to find a path.
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_id"] = root_id
url = self._endpoints["find_path"].format_map(endpoint_mapping)
query_d = {}
query_d["precision_mode"] = precision_mode
nodes = [[root_id] + src_pt.tolist(), [root_id] + dst_pt.tolist()]
response = self.session.post(
url,
data=json.dumps(nodes, cls=BaseEncoder),
params=query_d,
headers={"Content-Type": "application/json"},
)
resp_d = handle_response(response)
centroids = np.array(resp_d["centroids_list"])
failed_l2_ids = np.array(resp_d["failed_l2_ids"], dtype=np.uint64)
l2_path = np.array(resp_d["l2_path"])
return centroids, l2_path, failed_l2_ids
[docs] def get_subgraph(
self, root_id, bounds
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Get subgraph of root id within a bounding box.
Parameters
----------
root_id : int
Root (or any node ID) of chunked graph to query.
bounds : np.array
3x2 bounding box (x,y,z) x (min,max) in chunked graph coordinates.
Returns
-------
np.array of np.int64
Node IDs in the subgraph.
np.array of np.double
Affinities of edges in the subgraph.
np.array of np.int32
Areas of nodes in the subgraph.
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_id"] = root_id
url = self._endpoints["get_subgraph"].format_map(endpoint_mapping)
query_d = {}
if bounds is not None:
query_d["bounds"] = package_bounds(bounds)
response = self.session.get(url, params=query_d)
rd = handle_response(response)
return np.int64(rd["nodes"]), np.double(rd["affinities"]), np.int32(rd["areas"])
[docs] def level2_chunk_graph(self, root_id) -> list:
"""
Get graph of level 2 chunks, the smallest agglomeration level above supervoxels.
Parameters
----------
root_id : int
Root id of object
Returns
-------
list of list
Edge list for level 2 chunked graph. Each element of the list is an edge,
and each edge is a list of two node IDs (source and target).
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_id"] = root_id
url = self._endpoints["lvl2_graph"].format_map(endpoint_mapping)
r = handle_response(self.session.get(url))
return r["edge_graph"]
[docs] def remesh_level2_chunks(self, chunk_ids) -> None:
"""Submit specific level 2 chunks to be remeshed in case of a problem.
Parameters
----------
chunk_ids : list
List of level 2 chunk IDs.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["remesh_level2_chunks"].format_map(endpoint_mapping)
data = {"new_lvl2_ids": [int(x) for x in chunk_ids]}
r = self.session.post(url, json=data)
r.raise_for_status()
[docs] def get_operation_details(self, operation_ids: Iterable[int]) -> dict:
"""Get the details of a list of operations.
Parameters
----------
operation_ids: Iterable of int
List/array of operation IDs.
Returns
-------
dict of str to dict
A dict of dicts of operation info, keys are operation IDs (as strings),
values are a dictionary of operation info for the operation. These
dictionaries contain the following keys:
"added_edges"/"removed_edges": list of list of int
List of edges added (if a merge) or removed (if a split) by this
operation. Each edge is a list of two supervoxel IDs (source and
target).
"roots": list of int
List of root IDs that were created by this operation.
"sink_coords": list of list of int
List of sink coordinates for this operation. The sink is one of the
points placed by the user when specifying the operation. Each sink
coordinate is a list of three integers (x, y, z), corresponding to
spatial coordinates in segmentation voxel space.
"source_coords": list of list of int
List of source coordinates for this operation. The source is one of the
points placed by the user when specifying the operation. Each source
coordinate is a list of three integers (x, y, z), corresponding to
spatial coordinates in segmentation voxel space.
"timestamp": str
Timestamp of the operation.
"user": str
User ID number who performed the operation (as a string).
"""
if isinstance(operation_ids, np.ndarray):
operation_ids = operation_ids.tolist()
endpoint_mapping = self.default_url_mapping
url = self._endpoints["operation_details"].format_map(endpoint_mapping)
query_d = {"operation_ids": operation_ids}
query_str = urlencode(query_d)
url = url + "?" + query_str
r = self.session.get(url)
r.raise_for_status()
return r.json()
[docs] def get_lineage_graph(
self,
root_id,
timestamp_past=None,
timestamp_future=None,
as_nx_graph=False,
exclude_links_to_future=False,
exclude_links_to_past=False,
) -> Union[dict, nx.DiGraph]:
"""
Returns the lineage graph for a root ID, optionally cut off in the past or
the future.
Each change in the chunked graph creates a new root ID for the object after
that change. This function returns a graph of all root IDs for a given object,
tracing the history of the object in terms of merges and splits.
Parameters
----------
root_id : int
Object root ID.
timestamp_past : datetime.datetime or None, optional
Cutoff for the lineage graph backwards in time. By default, None.
timestamp_future : datetime.datetime or None, optional
Cutoff for the lineage graph going forwards in time. By default, None.
as_nx_graph: bool
If True, a NetworkX graph is returned.
exclude_links_to_future: bool
If True, links from nodes before `timestamp_future` to after
`timestamp_future` are removed. If False, the link(s) which has one node
before timestamp and one node after timestamp is kept.
exclude_links_to_past: bool
If True, links from nodes before `timestamp_past` to after `timestamp_past`
are removed. If False, the link(s) which has one node before timestamp and
one node after timestamp is kept.
Returns
-------
dict
Dictionary describing the lineage graph and operations for the root ID. Not
returned if `as_nx_graph` is True. The dictionary contains the following
keys:
"directed" : bool
Whether the graph is directed.
"graph" : dict
Dictionary of graph attributes.
"links" : list of dict
Each element of the list is a dictionary describing an edge in the
lineage graph as "source" and "target" keys.
"multigraph" : bool
Whether the graph is a multigraph.
"nodes" : list of dict
Each element of the list is a dictionary describing a node in the
lineage graph, usually with "id", "timestamp", and "operation_id"
keys.
nx.DiGraph
NetworkX directed graph of the lineage graph. Only returned if `as_nx_graph`
is True.
"""
root_id = root_id_int_list_check(root_id, make_unique=True)
endpoint_mapping = self.default_url_mapping
params = {}
if timestamp_past is not None:
params.update(package_timestamp(timestamp_past, name="timestamp_past"))
if timestamp_future is not None:
params.update(package_timestamp(timestamp_future, name="timestamp_future"))
url = self._endpoints["handle_lineage_graph"].format_map(endpoint_mapping)
data = json.dumps({"root_ids": root_id}, cls=BaseEncoder)
r = handle_response(self.session.post(url, data=data, params=params))
if exclude_links_to_future or exclude_links_to_past:
bad_ids = []
for node in r["nodes"]:
node_ts = datetime.datetime.fromtimestamp(node["timestamp"])
node_ts = node_ts.astimezone(datetime.timezone.utc)
if (
exclude_links_to_past and (node_ts < timestamp_past)
if timestamp_past is not None
else False
):
bad_ids.append(node["id"])
if (
exclude_links_to_future and (node_ts > timestamp_future)
if timestamp_future is not None
else False
):
bad_ids.append(node["id"])
r["nodes"] = [node for node in r["nodes"] if node["id"] not in bad_ids]
r["links"] = [
link
for link in r["links"]
if link["source"] not in bad_ids and link["target"] not in bad_ids
]
if as_nx_graph:
return nx.node_link_graph(r)
else:
return r
[docs] def get_latest_roots(
self, root_id, timestamp=None, timestamp_future=None
) -> np.ndarray:
"""
Returns root IDs that are related to the given `root_id` at a given
timestamp. Can be used to find the "latest" root IDs associated with an object.
Parameters
----------
root_id : int
Object root ID.
timestamp : datetime.datetime or None, optional
Timestamp of where to query IDs from. If None then will assume you want
till now.
timestamp_future : datetime.datetime or None, optional
DEPRECATED name, use `timestamp` instead.
Timestamp to suggest IDs from (note can be in the past relative to the
root). By default, None.
Returns
-------
np.ndarray
1d array with all latest successors.
"""
root_id = root_id_int_list_check(root_id, make_unique=True)
timestamp_root = self.get_root_timestamps(root_id).min()
if timestamp_future is not None:
logger.warning("timestamp_future is deprecated, use timestamp instead")
timestamp = timestamp_future
if timestamp is None:
timestamp = datetime.datetime.now(datetime.timezone.utc)
elif timestamp.tzinfo is None:
timestamp = timestamp.replace(tzinfo=datetime.timezone.utc)
# or if timestamp_root is less than timestamp_future
if (timestamp is None) or (timestamp_root < timestamp):
lineage_graph = self.get_lineage_graph(
root_id,
timestamp_past=timestamp_root,
timestamp_future=timestamp,
exclude_links_to_future=True,
as_nx_graph=True,
)
# then we want the leaves of the tree
out_degree_dict = dict(lineage_graph.out_degree)
nodes = np.array(list(out_degree_dict.keys()))
out_degrees = np.array(list(out_degree_dict.values()))
return nodes[out_degrees == 0]
else:
# then timestamp is in fact in the past
lineage_graph = self.get_lineage_graph(
root_id,
timestamp_future=timestamp_root,
timestamp_past=timestamp,
as_nx_graph=True,
)
in_degree_dict = dict(lineage_graph.in_degree)
nodes = np.array(list(in_degree_dict.keys()))
in_degrees = np.array(list(in_degree_dict.values()))
return nodes[in_degrees == 0]
[docs] def get_original_roots(self, root_id, timestamp_past=None) -> np.ndarray:
"""Returns root IDs that are the latest successors of a given root ID.
Parameters
----------
root_id : int
Object root ID.
timestamp_past : datetime.datetime or None, optional
Cutoff for the search going backwards in time. By default, None.
Returns
-------
np.ndarray
1d array with all latest successors.
"""
root_id = root_id_int_list_check(root_id, make_unique=True)
timestamp_future = self.get_root_timestamps(root_id).max()
lineage_graph = self.get_lineage_graph(
root_id,
timestamp_past=timestamp_past,
timestamp_future=timestamp_future,
as_nx_graph=True,
)
in_degree_dict = dict(lineage_graph.in_degree)
nodes = np.array(list(in_degree_dict.keys()))
in_degrees = np.array(list(in_degree_dict.values()))
return nodes[in_degrees == 0]
[docs] def is_latest_roots(self, root_ids, timestamp=None) -> np.ndarray:
"""Check whether these root IDs are still a root at this timestamp.
Parameters
----------
root_ids : list or array of int
Root IDs to check.
timestamp : datetime.datetime, optional
Timestamp to check whether these IDs are valid root IDs in the chunked
graph. Defaults to None (assumes now).
Returns
-------
np.array of bool
Array of whether these are valid root IDs.
"""
root_ids = root_id_int_list_check(root_ids, make_unique=False)
endpoint_mapping = self.default_url_mapping
url = self._endpoints["is_latest_roots"].format_map(endpoint_mapping)
if timestamp is None:
timestamp = self._default_timestamp
if timestamp is not None:
query_d = package_timestamp(self._process_timestamp(timestamp))
else:
query_d = None
data = {"node_ids": root_ids}
r = handle_response(
self.session.post(
url, data=json.dumps(data, cls=BaseEncoder), params=query_d
)
)
return np.array(r["is_latest"], bool)
[docs] def suggest_latest_roots(
self,
root_id,
timestamp=None,
stop_layer=None,
return_all=False,
return_fraction_overlap=False,
):
"""
Suggest latest roots for a given root id, based on overlap of component
chunk IDs. Note that edits change chunk IDs, and so this effectively measures
the fraction of unchanged chunks at a given chunk layer, which sets the size
scale of chunks. Higher layers are coarser.
Parameters
----------
root_id : int
Root ID of the potentially outdated object.
timestamp : datetime, optional
Datetime at which "latest" roots are being computed, by default None. If
None, the current time is used. Note that this has to be a timestamp after
the creation of the `root_id`.
stop_layer : int, optional
Chunk level at which to compute overlap, by default None.
No value will take the 4th from the top layer, which emphasizes speed and
works well for larger objects. Lower values are slower but more
fine-grained. Values under 2 (i.e. supervoxels) are not recommended except
in extremely fine grained scenarios.
return_all : bool, optional
If True, return all current IDs sorted from most overlap to least, by
default False. If False, only the top is returned.
return_fraction_overlap : bool, optional
If True, return all fractions sorted by most overlap to least, by default
False. If False, only the top value is returned.
"""
curr_ids = self.get_latest_roots(root_id, timestamp=timestamp)
if root_id in curr_ids:
if return_all:
if return_fraction_overlap:
return [root_id], [1]
else:
return [root_id]
else:
if return_fraction_overlap:
return root_id, 1
else:
return root_id
delta_layers = 4
if stop_layer is None:
stop_layer = (
self.segmentation_info.get("graph", {}).get("n_layers", 6)
- delta_layers
)
stop_layer = max(1, stop_layer)
chunks_orig = self.get_leaves(root_id, stop_layer=stop_layer)
while len(chunks_orig) == 0:
stop_layer -= 1
if stop_layer == 1:
raise ValueError(
f"There were no children for root_id={root_id} at level 2, something is wrong with the chunkedgraph"
)
chunks_orig = self.get_leaves(root_id, stop_layer=stop_layer)
chunk_list = np.array(
[
len(
np.intersect1d(
chunks_orig,
self.get_leaves(oid, stop_layer=stop_layer),
assume_unique=True,
)
)
/ len(chunks_orig)
for oid in curr_ids
]
)
order = np.argsort(chunk_list)[::-1]
if not return_all:
order = order[0]
if return_fraction_overlap:
return curr_ids[order], chunk_list[order]
else:
return curr_ids[order]
[docs] def is_valid_nodes(
self, node_ids, start_timestamp=None, end_timestamp=None
) -> np.ndarray:
"""Check whether nodes are valid for given timestamp range.
Valid is defined as existing in the chunked graph. This makes no statement
about these IDs being roots, supervoxel or anything in-between. It also
does not take into account whether a root ID has since been edited.
Parameters
----------
node_ids : list or array of int
Node IDs to check.
start_timestamp : datetime.datetime, optional
Timestamp to check whether these IDs were valid after this timestamp.
Defaults to None (assumes now).
end_timestamp : datetime.datetime, optional
Timestamp to check whether these IDs were valid before this timestamp.
Defaults to None (assumes now).
Returns
-------
np.array of bool
Array of whether these are valid IDs.
"""
node_ids = root_id_int_list_check(node_ids, make_unique=False)
endpoint_mapping = self.default_url_mapping
url = self._endpoints["valid_nodes"].format_map(endpoint_mapping)
if end_timestamp is None:
end_timestamp = self._default_timestamp
if start_timestamp is None:
start_timestamp = datetime.datetime(2000, 1, 1)
if start_timestamp is not None:
query_d = package_timestamp(
self._process_timestamp(start_timestamp), name="start_timestamp"
)
else:
query_d = {}
if end_timestamp is not None:
query_d.update(
package_timestamp(
self._process_timestamp(end_timestamp), name="end_timestamp"
)
)
data = {"node_ids": node_ids}
r = handle_response(
self.session.get(
url, data=json.dumps(data, cls=BaseEncoder), params=query_d
)
)
valid_ids = np.array(r["valid_roots"], np.uint64)
return np.isin(node_ids, valid_ids)
[docs] def get_root_timestamps(self, root_ids) -> np.ndarray:
"""Retrieves timestamps when roots where created.
Parameters
----------
root_ids: Iterable of int
Iterable of root IDs to query.
Returns
-------
np.array of datetime.datetime
Array of timestamps when `root_ids` were created.
"""
root_ids = root_id_int_list_check(root_ids, make_unique=False)
endpoint_mapping = self.default_url_mapping
url = self._endpoints["root_timestamps"].format_map(endpoint_mapping)
data = {"node_ids": root_ids}
r = handle_response(
self.session.post(url, data=json.dumps(data, cls=BaseEncoder))
)
return np.array(
[datetime.datetime.fromtimestamp(ts, pytz.UTC) for ts in r["timestamp"]]
)
[docs] def get_past_ids(
self, root_ids, timestamp_past=None, timestamp_future=None
) -> dict:
"""
For a set of root IDs, get the list of IDs at a past or future time point
that could contain parts of the same object.
Parameters
----------
root_ids : Iterable of int
Iterable of root IDs to query.
timestamp_past : datetime.datetime or None, optional
Time of a point in the past for which to look up root ids. Default is None.
timestamp_future : datetime.datetime or None, optional
Time of a point in the future for which to look up root ids. Not
implemented on the server currently. Default is None.
Returns
-------
dict
Dict with keys "future_id_map" and "past_id_map". Each is a dict whose keys
are the supplied `root_ids` and whose values are the list of related
root IDs at `timestamp_past`/`timestamp_future`.
"""
root_ids = root_id_int_list_check(root_ids, make_unique=True)
endpoint_mapping = self.default_url_mapping
params = {}
if timestamp_past is not None:
params.update(package_timestamp(timestamp_past, name="timestamp_past"))
if timestamp_future is not None:
params.update(package_timestamp(timestamp_future, name="timestamp_future"))
data = {"root_ids": np.array(root_ids, dtype=np.uint64)}
url = self._endpoints["past_id_mapping"].format_map(endpoint_mapping)
r = handle_response(
self.session.get(url, data=json.dumps(data, cls=BaseEncoder), params=params)
)
# Convert id keys as strings to ints
past_keys = list(r["past_id_map"].keys())
for k in past_keys:
dat = r["past_id_map"].pop(k)
r["past_id_map"][int(k)] = dat
fut_keys = list(r["future_id_map"].keys())
for k in fut_keys:
dat = r["future_id_map"].pop(k)
r["future_id_map"][int(k)] = dat
return r
[docs] def get_delta_roots(
self,
timestamp_past: datetime.datetime,
timestamp_future: datetime.datetime = datetime.datetime.now(
datetime.timezone.utc
),
) -> Tuple[np.ndarray, np.ndarray]:
"""
Get the list of roots that have changed between `timetamp_past` and
`timestamp_future`.
Parameters
----------
timestamp_past : datetime.datetime
Past timepoint to query
timestamp_future : datetime.datetime, optional
Future timepoint to query. Defaults to
``datetime.datetime.now(datetime.timezone.utc)``.
Returns
-------
old_roots : np.ndarray of np.int64
Roots that have expired in that interval.
new_roots : np.ndarray of np.int64
Roots that are new in that interval.
"""
endpoint_mapping = self.default_url_mapping
params = package_timestamp(timestamp_past, name="timestamp_past")
params.update(package_timestamp(timestamp_future, name="timestamp_future"))
url = self._endpoints["delta_roots"].format_map(endpoint_mapping)
r = handle_response(self.session.get(url, params=params))
return np.array(r["old_roots"]), np.array(r["new_roots"])
[docs] def get_oldest_timestamp(self) -> datetime.datetime:
"""Get the oldest timestamp in the database.
Returns
-------
datetime.datetime
Oldest timestamp in the database.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["oldest_timestamp"].format_map(endpoint_mapping)
response = handle_response(self.session.get(url))
return datetime.datetime.fromisoformat(response["iso"])
@property
def cloudvolume_path(self):
return self._endpoints["cloudvolume_path"].format_map(self.default_url_mapping)
@property
def segmentation_info(self):
"""Complete segmentation metadata"""
if self._segmentation_info is None:
url = self._endpoints["info"].format_map(self.default_url_mapping)
response = self.session.get(url)
self._segmentation_info = handle_response(response)
return self._segmentation_info
@property
def base_resolution(self):
"""MIP 0 resolution for voxels assumed by the ChunkedGraph
Returns
-------
list
3-long list of x/y/z voxel dimensions in nm
"""
return self.segmentation_info["scales"][0].get("resolution")
client_mapping = {
1: ChunkedGraphClientV1,
"latest": ChunkedGraphClientV1,
}