from typing import List, Optional, Tuple
import torch
import torch.utils._pytree as pytree
from torch import Tensor
def _pytreeify(cls):
r"""A pytree is Python nested data structure. It is a tree in the sense
that nodes are Python collections (e.g., list, tuple, dict) and the leaves
are Python values.
pytrees are useful for working with nested collections of Tensors. For
example, one can use `tree_map` to map a function over all Tensors inside
some nested collection of Tensors and `tree_unflatten` to get a flat list
of all Tensors inside some nested collection.
"""
assert issubclass(cls, torch.autograd.Function)
# The original functions we will replace with flattened versions:
orig_fw = cls.forward
orig_bw = cls.backward
orig_apply = cls.apply
def new_apply(*inputs):
flat_inputs, struct = pytree.tree_flatten(inputs)
out_struct_holder = []
flat_out = orig_apply(struct, out_struct_holder, *flat_inputs)
assert len(out_struct_holder) == 1
return pytree.tree_unflatten(flat_out, out_struct_holder[0])
def new_forward(ctx, struct, out_struct_holder, *flat_inputs):
inputs = pytree.tree_unflatten(flat_inputs, struct)
out = orig_fw(ctx, *inputs)
flat_out, out_struct = pytree.tree_flatten(out)
ctx._inp_struct = struct
ctx._out_struct = out_struct
out_struct_holder.append(out_struct)
return tuple(flat_out)
def new_backward(ctx, *flat_grad_outputs):
outs_grad = pytree.tree_unflatten(flat_grad_outputs, ctx._out_struct)
if not isinstance(outs_grad, tuple):
outs_grad = (outs_grad,)
grad_inputs = orig_bw(ctx, *outs_grad)
flat_grad_inputs, grad_inputs_struct = pytree.tree_flatten(grad_inputs)
return (None, None) + tuple(flat_grad_inputs)
cls.apply = new_apply
cls.forward = new_forward
cls.backward = new_backward
return cls
@_pytreeify
class GroupedMatmul(torch.autograd.Function):
    r"""Autograd function around the ``torch.ops.pyg.grouped_matmul`` kernel.

    Because of the ``_pytreeify`` decorator, ``apply`` is called with a single
    tuple that holds all left operands followed by all right operands, and it
    returns a tuple with one output per operand pair.
    """
    @staticmethod
    def forward(ctx, args: Tuple[Tensor]) -> Tuple[Tensor]:
        r"""Multiplies the first half of :obj:`args` with the second half,
        pair by pair.

        Args:
            ctx: The autograd context.
            args: All left operands followed by all right operands.

        Returns:
            One output tensor per operand pair.
        """
        ctx.save_for_backward(*args)

        num_groups = len(args) // 2
        inputs: List[Tensor] = list(args[:num_groups])
        others: List[Tensor] = list(args[num_groups:])
        outs = torch.ops.pyg.grouped_matmul(inputs, others)

        # NOTE Autograd doesn't set `out[i].requires_grad = True` automatically
        for x, other, out in zip(inputs, others, outs):
            if x.requires_grad or other.requires_grad:
                out.requires_grad = True

        return tuple(outs)

    @staticmethod
    def backward(ctx, *outs_grad: Tuple[Tensor]) -> Tuple[Tensor]:
        r"""Computes the gradients of all left and right operands.

        Args:
            ctx: The autograd context holding the operands saved in
                :meth:`forward`.
            *outs_grad: One incoming gradient per output.

        Returns:
            The gradients of the left operands followed by the gradients of
            the right operands. A whole half is filled with :obj:`None` when
            none of its tensors requires a gradient.
        """
        args = ctx.saved_tensors
        num_groups = len(outs_grad)
        inputs: List[Tensor] = list(args[:num_groups])
        others: List[Tensor] = list(args[num_groups:])

        # Decide which halves need gradients *before* building transposed
        # views. Previously `others` was rebound to `other.t()` first; views
        # created while gradient recording is disabled (the usual state
        # inside `backward`) do not report `requires_grad`, so the gradients
        # of `others` were silently dropped whenever an input required grad.
        needs_inputs_grad = any(x.requires_grad for x in inputs)
        needs_others_grad = any(other.requires_grad for other in others)

        if needs_inputs_grad:
            # d(out)/d(x) contribution: grad_out @ other^T
            others_t = [other.t() for other in others]
            inputs_grad = torch.ops.pyg.grouped_matmul(outs_grad, others_t)
        else:
            inputs_grad = [None] * num_groups

        if needs_others_grad:
            # d(out)/d(other) contribution: x^T @ grad_out
            inputs_t = [x.t() for x in inputs]
            others_grad = torch.ops.pyg.grouped_matmul(inputs_t, outs_grad)
        else:
            others_grad = [None] * num_groups

        return tuple(inputs_grad) + tuple(others_grad)
def grouped_matmul(
    inputs: List[Tensor],
    others: List[Tensor],
    biases: Optional[List[Tensor]] = None,
) -> List[Tensor]:
    r"""Performs dense-dense matrix multiplication according to groups,
    utilizing dedicated kernels that effectively parallelize over groups.

    .. code-block:: python

        inputs = [torch.randn(5, 16), torch.randn(3, 32)]
        others = [torch.randn(16, 32), torch.randn(32, 64)]

        outs = pyg_lib.ops.grouped_matmul(inputs, others)
        assert len(outs) == 2
        assert outs[0].size() == (5, 32)
        assert torch.allclose(outs[0], inputs[0] @ others[0])
        assert outs[1].size() == (3, 64)
        assert torch.allclose(outs[1], inputs[1] @ others[1])

    Args:
        inputs: List of left operand 2D matrices of shapes :obj:`[N_i, K_i]`.
        others: List of right operand 2D matrices of shapes :obj:`[K_i, M_i]`.
        biases: Optional bias terms to apply for each element.

    Returns:
        List of 2D output matrices of shapes :obj:`[N_i, M_i]`.

    Raises:
        ValueError: If :obj:`inputs` and :obj:`others` differ in length.
    """
    # Both lists are concatenated and split in half again inside the autograd
    # function, so a length mismatch would silently pair the wrong operands.
    if len(inputs) != len(others):
        raise ValueError(
            f"'inputs' and 'others' must have the same length (got "
            f"{len(inputs)} and {len(others)})")

    # Combine inputs into a single tuple for autograd:
    outs = list(GroupedMatmul.apply(tuple(inputs) + tuple(others)))

    if biases is not None:
        for i, bias in enumerate(biases):
            outs[i] = outs[i] + bias

    return outs
def segment_matmul(
    inputs: Tensor,
    ptr: Tensor,
    other: Tensor,
    bias: Optional[Tensor] = None,
) -> Tensor:
    r"""Multiplies each row segment of :obj:`inputs` (as delimited by
    :obj:`ptr` along the first dimension) with its own matrix in
    :obj:`other`, utilizing dedicated kernels that effectively parallelize
    over groups.

    .. code-block:: python

        inputs = torch.randn(8, 16)
        ptr = torch.tensor([0, 5, 8])
        other = torch.randn(2, 16, 32)

        out = pyg_lib.ops.segment_matmul(inputs, ptr, other)
        assert out.size() == (8, 32)
        assert torch.allclose(out[0:5], inputs[0:5] @ other[0])
        assert torch.allclose(out[5:8], inputs[5:8] @ other[1])

    Args:
        inputs: The left operand 2D matrix of shape :obj:`[N, K]`.
        ptr: Compressed vector of shape :obj:`[B + 1]`, holding the boundaries
            of segments. For best performance, given as a CPU tensor.
        other: The right operand 3D tensor of shape :obj:`[B, K, M]`.
        bias: The bias term of shape :obj:`[B, M]`.

    Returns:
        The 2D output matrix of shape :obj:`[N, M]`.
    """
    out = torch.ops.pyg.segment_matmul(inputs, ptr, other)
    if bias is None:
        return out

    # Add the bias of every segment in place onto the rows of that segment.
    num_segments = ptr.numel() - 1
    for seg in range(num_segments):
        start, stop = ptr[seg], ptr[seg + 1]
        out[start:stop] += bias[seg]
    return out
def sampled_add(
    left: Tensor,
    right: Tensor,
    left_index: Optional[Tensor] = None,
    right_index: Optional[Tensor] = None,
) -> Tensor:
    r"""Adds the entries of :obj:`left` and :obj:`right` selected by
    :obj:`left_index` and :obj:`right_index` (sampled **addition**).

    .. math::
        \textrm{out} = \textrm{left}[\textrm{left_index}] +
        \textrm{right}[\textrm{right_index}]

    Indexing and addition are fused into one operation, which is more
    runtime and memory-efficient than performing them separately.

    Args:
        left: The left tensor.
        right: The right tensor.
        left_index: The values to sample from the :obj:`left` tensor.
        right_index: The values to sample from the :obj:`right` tensor.

    Returns:
        The output tensor.
    """
    return torch.ops.pyg.sampled_op(left, right, left_index, right_index,
                                    'add')
def sampled_sub(
    left: Tensor,
    right: Tensor,
    left_index: Optional[Tensor] = None,
    right_index: Optional[Tensor] = None,
) -> Tensor:
    r"""Subtracts the entries of :obj:`right` selected by :obj:`right_index`
    from the entries of :obj:`left` selected by :obj:`left_index` (sampled
    **subtraction**).

    .. math::
        \textrm{out} = \textrm{left}[\textrm{left_index}] -
        \textrm{right}[\textrm{right_index}]

    Indexing and subtraction are fused into one operation, which is more
    runtime and memory-efficient than performing them separately.

    Args:
        left: The left tensor.
        right: The right tensor.
        left_index: The values to sample from the :obj:`left` tensor.
        right_index: The values to sample from the :obj:`right` tensor.

    Returns:
        The output tensor.
    """
    return torch.ops.pyg.sampled_op(left, right, left_index, right_index,
                                    'sub')
def sampled_mul(
    left: Tensor,
    right: Tensor,
    left_index: Optional[Tensor] = None,
    right_index: Optional[Tensor] = None,
) -> Tensor:
    r"""Multiplies the entries of :obj:`left` and :obj:`right` selected by
    :obj:`left_index` and :obj:`right_index` (sampled **multiplication**).

    .. math::
        \textrm{out} = \textrm{left}[\textrm{left_index}] *
        \textrm{right}[\textrm{right_index}]

    Indexing and multiplication are fused into one operation, which is more
    runtime and memory-efficient than performing them separately.

    Args:
        left: The left tensor.
        right: The right tensor.
        left_index: The values to sample from the :obj:`left` tensor.
        right_index: The values to sample from the :obj:`right` tensor.

    Returns:
        The output tensor.
    """
    return torch.ops.pyg.sampled_op(left, right, left_index, right_index,
                                    'mul')
def sampled_div(
    left: Tensor,
    right: Tensor,
    left_index: Optional[Tensor] = None,
    right_index: Optional[Tensor] = None,
) -> Tensor:
    r"""Divides the entries of :obj:`left` selected by :obj:`left_index` by
    the entries of :obj:`right` selected by :obj:`right_index` (sampled
    **division**).

    .. math::
        \textrm{out} = \textrm{left}[\textrm{left_index}] /
        \textrm{right}[\textrm{right_index}]

    Indexing and division are fused into one operation, which is more
    runtime and memory-efficient than performing them separately.

    Args:
        left: The left tensor.
        right: The right tensor.
        left_index: The values to sample from the :obj:`left` tensor.
        right_index: The values to sample from the :obj:`right` tensor.

    Returns:
        The output tensor.
    """
    return torch.ops.pyg.sampled_op(left, right, left_index, right_index,
                                    'div')
def index_sort(
    inputs: Tensor,
    max_value: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Sorts the elements of the :obj:`inputs` tensor in ascending order.

    It is expected that :obj:`inputs` is one-dimensional and that it only
    contains positive integer values. If :obj:`max_value` is given, it can be
    used by the underlying algorithm for better performance.

    .. note::
        This operation is optimized only for tensors associated with the CPU
        device. Tensors on any other device are handed to :func:`torch.sort`.

    Args:
        inputs: A vector with positive integer values.
        max_value: The maximum value stored inside :obj:`inputs`. This value
            can be an estimation, but needs to be greater than or equal to the
            real maximum.

    Returns:
        A tuple containing sorted values and indices of the elements in the
        original :obj:`inputs` tensor.
    """
    if inputs.is_cpu:
        return torch.ops.pyg.index_sort(inputs, max_value)

    # No dedicated kernel outside of the CPU: fall back to the generic sort.
    return torch.sort(inputs)
def softmax_csr(
    src: Tensor,
    ptr: Tensor,
    dim: int = 0,
) -> Tensor:
    r"""Computes a sparsely evaluated softmax.

    The values of :attr:`src` are first grouped along dimension :attr:`dim`
    according to the boundaries given in :attr:`ptr`; the softmax is then
    computed individually for each group.

    Examples:
        >>> src = torch.randn(4, 4)
        >>> ptr = torch.tensor([0, 4])
        >>> softmax_csr(src, ptr)
        tensor([[0.0157, 0.0984, 0.1250, 0.4523],
                [0.1453, 0.2591, 0.5907, 0.2410],
                [0.0598, 0.2923, 0.1206, 0.0921],
                [0.7792, 0.3502, 0.1638, 0.2145]])

    Args:
        src: The source tensor.
        ptr: Groups defined by CSR representation.
        dim: The dimension in which to normalize.
    """
    # A negative `dim` counts from the last dimension of `src`; convert it to
    # the equivalent non-negative axis before calling the kernel.
    if dim < 0:
        dim = dim + src.dim()
    return torch.ops.pyg.softmax_csr(src, ptr, dim)
def scatter_sum(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
) -> Tensor:
    r"""Sum-reduces the values of :obj:`src` into :obj:`out` at the positions
    given by :obj:`index` along the axis :obj:`dim`.

    Without :obj:`out`, a new zero-initialized tensor is allocated. With
    :obj:`out`, the values are **accumulated** into it (no zero-init).

    Args:
        src: The source tensor.
        index: The indices of elements to scatter.
        dim: The axis along which to index.
        out: The destination tensor.
        dim_size: If :obj:`out` is not given, automatically create output with
            size :obj:`dim_size` at dimension :obj:`dim`. If :obj:`dim_size` is
            also :obj:`None`, a minimal sized output is returned.

    Returns:
        The reduced tensor.
    """
    reduced = torch.ops.pyg.scatter_sum(src, index, dim, out, dim_size)
    return reduced


# Second public name for the same reduction.
scatter_add = scatter_sum
def scatter_mul(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
) -> Tensor:
    r"""Product-reduces the values of :obj:`src` into :obj:`out` at the
    positions given by :obj:`index` along the axis :obj:`dim`.

    Without :obj:`out`, a new tensor initialized to ones (the multiplicative
    identity) is allocated. With :obj:`out`, the values are **multiplied**
    into it (no ones-init).

    Args:
        src: The source tensor.
        index: The indices of elements to scatter.
        dim: The axis along which to index.
        out: The destination tensor.
        dim_size: If :obj:`out` is not given, automatically create output with
            size :obj:`dim_size` at dimension :obj:`dim`.

    Returns:
        The reduced tensor.
    """
    reduced = torch.ops.pyg.scatter_mul(src, index, dim, out, dim_size)
    return reduced
def scatter_mean(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
) -> Tensor:
    r"""Mean-reduces the values of :obj:`src` into :obj:`out` at the positions
    given by :obj:`index` along the axis :obj:`dim`. Empty buckets yield zero.

    For floating-point inputs, division is exact. For integer inputs,
    floor-division is used (matching upstream ``pytorch_scatter``).

    Args:
        src: The source tensor.
        index: The indices of elements to scatter.
        dim: The axis along which to index.
        out: The destination tensor.
        dim_size: If :obj:`out` is not given, automatically create output with
            size :obj:`dim_size` at dimension :obj:`dim`.

    Returns:
        The reduced tensor.
    """
    reduced = torch.ops.pyg.scatter_mean(src, index, dim, out, dim_size)
    return reduced
def scatter_min(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Min-reduces the values of :obj:`src` into :obj:`out` at the positions
    given by :obj:`index` along the axis :obj:`dim`.

    The result is a tuple ``(values, argindex)`` in which ``argindex[i]`` is
    the index along ``dim`` of the source element that contributed to
    ``values[i]``. Empty buckets yield value ``0`` and argindex
    ``src.size(dim)`` (sentinel). Only ``values`` is differentiable.

    Args:
        src: The source tensor.
        index: The indices of elements to scatter.
        dim: The axis along which to index.
        out: The destination tensor for the values.
        dim_size: If :obj:`out` is not given, automatically create output with
            size :obj:`dim_size` at dimension :obj:`dim`.

    Returns:
        Tuple ``(values, argindex)``.
    """
    values_and_argindex = torch.ops.pyg.scatter_min(src, index, dim, out,
                                                    dim_size)
    return values_and_argindex
def scatter_max(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Max-reduces the values of :obj:`src` into :obj:`out` at the positions
    given by :obj:`index` along the axis :obj:`dim`.

    The result is a tuple ``(values, argindex)`` in which ``argindex[i]`` is
    the index along ``dim`` of the source element that contributed to
    ``values[i]``. Empty buckets yield value ``0`` and argindex
    ``src.size(dim)`` (sentinel). Only ``values`` is differentiable.

    Args:
        src: The source tensor.
        index: The indices of elements to scatter.
        dim: The axis along which to index.
        out: The destination tensor for the values.
        dim_size: If :obj:`out` is not given, automatically create output with
            size :obj:`dim_size` at dimension :obj:`dim`.

    Returns:
        Tuple ``(values, argindex)``.
    """
    values_and_argindex = torch.ops.pyg.scatter_max(src, index, dim, out,
                                                    dim_size)
    return values_and_argindex
def segment_sum_coo(
    src: Tensor,
    index: Tensor,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
) -> Tensor:
    r"""Sum-reduces the values of :obj:`src` into :obj:`out` according to a
    sorted COO :obj:`index`.

    In contrast to :func:`scatter_sum`, the reduction axis is always
    ``index.dim() - 1`` (so no :obj:`dim` argument is exposed), and
    :obj:`index` must be sorted in ascending order along that axis.

    Without :obj:`out`, a new zero-initialized tensor is allocated. With
    :obj:`out`, the values are **accumulated** into it (no zero-init).

    Args:
        src: The source tensor.
        index: Sorted COO indices.
        out: The destination tensor.
        dim_size: If :obj:`out` is not given, the output size along the
            reduction axis. Auto-inferred from :obj:`index.max() + 1` if
            :obj:`None`.

    Returns:
        The reduced tensor.
    """
    reduced = torch.ops.pyg.segment_sum_coo(src, index, out, dim_size)
    return reduced


# Second public name for the same reduction.
segment_add_coo = segment_sum_coo
def segment_mean_coo(
    src: Tensor,
    index: Tensor,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
) -> Tensor:
    r"""Mean-reduces the values of :obj:`src` into :obj:`out` according to a
    sorted COO :obj:`index`. Empty buckets yield zero.

    The reduction axis is ``index.dim() - 1``. For integer dtypes,
    floor-division is used.

    Args:
        src: The source tensor.
        index: Sorted COO indices.
        out: The destination tensor.
        dim_size: If :obj:`out` is not given, the output size along the
            reduction axis. Auto-inferred from :obj:`index.max() + 1` if
            :obj:`None`.

    Returns:
        The reduced tensor.
    """
    reduced = torch.ops.pyg.segment_mean_coo(src, index, out, dim_size)
    return reduced
def segment_min_coo(
    src: Tensor,
    index: Tensor,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Min-reduces the values of :obj:`src` into :obj:`out` according to a
    sorted COO :obj:`index`.

    The result is a tuple ``(values, argindex)``. The reduction axis is
    ``index.dim() - 1``. Empty buckets yield value ``0`` and argindex
    ``src.size(dim)`` (sentinel). Only ``values`` is differentiable.

    Args:
        src: The source tensor.
        index: Sorted COO indices.
        out: The destination tensor for the values.
        dim_size: If :obj:`out` is not given, the output size along the
            reduction axis.

    Returns:
        Tuple ``(values, argindex)``.
    """
    values_and_argindex = torch.ops.pyg.segment_min_coo(
        src, index, out, dim_size)
    return values_and_argindex
def segment_max_coo(
    src: Tensor,
    index: Tensor,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Max-reduces the values of :obj:`src` into :obj:`out` according to a
    sorted COO :obj:`index`.

    The result is a tuple ``(values, argindex)``. The reduction axis is
    ``index.dim() - 1``. Empty buckets yield value ``0`` and argindex
    ``src.size(dim)`` (sentinel). Only ``values`` is differentiable.

    Args:
        src: The source tensor.
        index: Sorted COO indices.
        out: The destination tensor for the values.
        dim_size: If :obj:`out` is not given, the output size along the
            reduction axis.

    Returns:
        Tuple ``(values, argindex)``.
    """
    values_and_argindex = torch.ops.pyg.segment_max_coo(
        src, index, out, dim_size)
    return values_and_argindex
def gather_coo(
    src: Tensor,
    index: Tensor,
    out: Optional[Tensor] = None,
) -> Tensor:
    r"""Collects the values of :obj:`src` at the positions given by
    :obj:`index` along the axis ``index.dim() - 1``. This is the symmetric
    inverse of :func:`segment_sum_coo`.

    Args:
        src: The source tensor.
        index: The COO indices to gather.
        out: The destination tensor (overwritten if given).

    Returns:
        ``out[i] = src[index[i]]`` along the gather axis.
    """
    gathered = torch.ops.pyg.gather_coo(src, index, out)
    return gathered
def segment_sum_csr(
    src: Tensor,
    indptr: Tensor,
    out: Optional[Tensor] = None,
) -> Tensor:
    r"""Sum-reduces the values of :obj:`src` into :obj:`out` according to a
    CSR :obj:`indptr`.

    The reduction axis is ``indptr.dim() - 1`` and the output size along that
    axis is ``indptr.size(-1) - 1`` (no explicit ``dim_size`` arg).

    Without :obj:`out`, a new zero-initialized tensor is allocated. With
    :obj:`out`, the values are **accumulated** into it.

    Args:
        src: The source tensor.
        indptr: The CSR row pointer of shape :obj:`[..., R+1]`.
        out: The destination tensor.

    Returns:
        The reduced tensor.
    """
    reduced = torch.ops.pyg.segment_sum_csr(src, indptr, out)
    return reduced


# Second public name for the same reduction.
segment_add_csr = segment_sum_csr
def segment_mean_csr(
    src: Tensor,
    indptr: Tensor,
    out: Optional[Tensor] = None,
) -> Tensor:
    r"""Mean-reduces the values of :obj:`src` into :obj:`out` according to a
    CSR :obj:`indptr`. Empty rows yield zero.

    Args:
        src: The source tensor.
        indptr: The CSR row pointer of shape :obj:`[..., R+1]`.
        out: The destination tensor.

    Returns:
        The reduced tensor.
    """
    reduced = torch.ops.pyg.segment_mean_csr(src, indptr, out)
    return reduced
def segment_min_csr(
    src: Tensor,
    indptr: Tensor,
    out: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Min-reduces the values of :obj:`src` into :obj:`out` according to a
    CSR :obj:`indptr`.

    The result is a tuple ``(values, argindex)``. Empty rows yield value ``0``
    and argindex ``src.size(dim)`` (sentinel). Only ``values`` is
    differentiable.

    Args:
        src: The source tensor.
        indptr: The CSR row pointer of shape :obj:`[..., R+1]`.
        out: The destination tensor for the values.

    Returns:
        Tuple ``(values, argindex)``.
    """
    values_and_argindex = torch.ops.pyg.segment_min_csr(src, indptr, out)
    return values_and_argindex
def segment_max_csr(
    src: Tensor,
    indptr: Tensor,
    out: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Max-reduces the values of :obj:`src` into :obj:`out` according to a
    CSR :obj:`indptr`.

    The result is a tuple ``(values, argindex)``. Empty rows yield value ``0``
    and argindex ``src.size(dim)`` (sentinel). Only ``values`` is
    differentiable.

    Args:
        src: The source tensor.
        indptr: The CSR row pointer of shape :obj:`[..., R+1]`.
        out: The destination tensor for the values.

    Returns:
        Tuple ``(values, argindex)``.
    """
    values_and_argindex = torch.ops.pyg.segment_max_csr(src, indptr, out)
    return values_and_argindex
def gather_csr(
    src: Tensor,
    indptr: Tensor,
    out: Optional[Tensor] = None,
) -> Tensor:
    r"""Expands the values of :obj:`src` to all positions encoded by the CSR
    :obj:`indptr`: for each row ``r``, ``src[r]`` is broadcast to all output
    positions in ``[indptr[r], indptr[r+1])``. Symmetric inverse of
    :func:`segment_sum_csr`.

    Args:
        src: The source tensor.
        indptr: The CSR row pointer of shape :obj:`[..., R+1]`.
        out: The destination tensor (overwritten if given).

    Returns:
        Broadcast tensor of shape ``src.shape`` with ``indptr.dim()-1`` axis
        replaced by ``indptr[-1]``.
    """
    gathered = torch.ops.pyg.gather_csr(src, indptr, out)
    return gathered
def _broadcast(src: Tensor, other: Tensor, dim: int) -> Tensor:
r"""Broadcasts a 1-D :obj:`src` to the shape of :obj:`other` along
:obj:`dim`. Used by the Python composites that need to lift a 1-D
per-bucket result back to per-element positions for ``gather`` /
arithmetic broadcasting.
Ports ``torch_scatter/utils.py:broadcast``.
"""
if src.dim() == 1:
for _ in range(dim if dim >= 0 else other.dim() + dim):
src = src.unsqueeze(0)
while src.dim() < other.dim():
src = src.unsqueeze(-1)
return src.expand_as(other)
def scatter(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
    reduce: str = 'sum',
) -> Tensor:
    r"""Polymorphic scatter dispatcher.

    Selects the typed scatter op from :obj:`reduce`:

    * ``"sum"`` / ``"add"``: :func:`scatter_sum`
    * ``"mul"``: :func:`scatter_mul`
    * ``"mean"``: :func:`scatter_mean`
    * ``"min"``: :func:`scatter_min` (values only)
    * ``"max"``: :func:`scatter_max` (values only)

    Args:
        src: The source tensor.
        index: The indices of elements to scatter.
        dim: The axis along which to index.
        out: The destination tensor.
        dim_size: The output size at dimension :obj:`dim` if :obj:`out` is
            not given.
        reduce: The name of the reduction.

    Returns:
        The reduced tensor.

    Raises:
        ValueError: If :obj:`reduce` is none of the names listed above.
    """
    if reduce in ('sum', 'add'):
        result = scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mul':
        result = scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        result = scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        # Drop the argindex half of the `(values, argindex)` pair.
        result = scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        result = scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError(f'Unknown reduce: {reduce!r}')
    return result
def segment_coo(
    src: Tensor,
    index: Tensor,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
    reduce: str = 'sum',
) -> Tensor:
    r"""Polymorphic COO segment dispatcher.

    Selects the typed ``segment_*_coo`` op from :obj:`reduce` (``"sum"`` /
    ``"add"``, ``"mean"``, ``"min"`` or ``"max"``). For ``"min"`` and
    ``"max"``, only the value tensor is returned.

    Args:
        src: The source tensor.
        index: Sorted COO indices.
        out: The destination tensor.
        dim_size: The output size along the reduction axis if :obj:`out` is
            not given.
        reduce: The name of the reduction.

    Returns:
        The reduced tensor.

    Raises:
        ValueError: If :obj:`reduce` is none of the names listed above.
    """
    if reduce in ('sum', 'add'):
        result = segment_sum_coo(src, index, out, dim_size)
    elif reduce == 'mean':
        result = segment_mean_coo(src, index, out, dim_size)
    elif reduce == 'min':
        # Drop the argindex half of the `(values, argindex)` pair.
        result = segment_min_coo(src, index, out, dim_size)[0]
    elif reduce == 'max':
        result = segment_max_coo(src, index, out, dim_size)[0]
    else:
        raise ValueError(f'Unknown reduce: {reduce!r}')
    return result
def segment_csr(
    src: Tensor,
    indptr: Tensor,
    out: Optional[Tensor] = None,
    reduce: str = 'sum',
) -> Tensor:
    r"""Polymorphic CSR segment dispatcher.

    Selects the typed ``segment_*_csr`` op from :obj:`reduce` (``"sum"`` /
    ``"add"``, ``"mean"``, ``"min"`` or ``"max"``). For ``"min"`` and
    ``"max"``, only the value tensor is returned.

    Args:
        src: The source tensor.
        indptr: The CSR row pointer.
        out: The destination tensor.
        reduce: The name of the reduction.

    Returns:
        The reduced tensor.

    Raises:
        ValueError: If :obj:`reduce` is none of the names listed above.
    """
    if reduce in ('sum', 'add'):
        result = segment_sum_csr(src, indptr, out)
    elif reduce == 'mean':
        result = segment_mean_csr(src, indptr, out)
    elif reduce == 'min':
        # Drop the argindex half of the `(values, argindex)` pair.
        result = segment_min_csr(src, indptr, out)[0]
    elif reduce == 'max':
        result = segment_max_csr(src, indptr, out)[0]
    else:
        raise ValueError(f'Unknown reduce: {reduce!r}')
    return result
def scatter_softmax(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    dim_size: Optional[int] = None,
) -> Tensor:
    r"""Per-bucket softmax composite. Float inputs only.

    The values are shifted by the maximum of their bucket before being
    exponentiated, and then divided by the sum of the exponentials of their
    bucket.

    Args:
        src: The source tensor.
        index: The bucket of each element.
        dim: The axis along which to index.
        dim_size: The number of buckets passed on to the scatter ops.

    Returns:
        A tensor with one softmax value per element of :obj:`src`.

    Raises:
        ValueError: If :obj:`src` is not a floating-point tensor.
    """
    if not src.is_floating_point():
        raise ValueError(
            'scatter_softmax requires a floating-point src tensor (got '
            f'{src.dtype})', )

    # Per-element view of `index`, used to read per-bucket results back.
    gather_index = _broadcast(index, src, dim)

    bucket_max = scatter_max(src, index, dim, dim_size=dim_size)[0]
    exp_shifted = (src - bucket_max.gather(dim, gather_index)).exp()

    bucket_sum = scatter_sum(exp_shifted, index, dim, dim_size=dim_size)
    normalizer = bucket_sum.gather(dim, gather_index)
    return exp_shifted / normalizer
def scatter_log_softmax(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    dim_size: Optional[int] = None,
    eps: float = 1e-12,
) -> Tensor:
    r"""Per-bucket log-softmax composite. Float inputs only.

    The values are shifted by the maximum of their bucket; the logarithm of
    the per-bucket sum of exponentials (plus :obj:`eps`) is then subtracted.

    Args:
        src: The source tensor.
        index: The bucket of each element.
        dim: The axis along which to index.
        dim_size: The number of buckets passed on to the scatter ops.
        eps: Added to the per-bucket sum before taking the logarithm.

    Returns:
        A tensor with one log-softmax value per element of :obj:`src`.

    Raises:
        ValueError: If :obj:`src` is not a floating-point tensor.
    """
    if not src.is_floating_point():
        raise ValueError(
            'scatter_log_softmax requires a floating-point src tensor (got '
            f'{src.dtype})', )

    # Per-element view of `index`, used to read per-bucket results back.
    gather_index = _broadcast(index, src, dim)

    bucket_max = scatter_max(src, index, dim, dim_size=dim_size)[0]
    shifted = src - bucket_max.gather(dim, gather_index)

    bucket_sum = scatter_sum(shifted.exp(), index, dim, dim_size=dim_size)
    normalizer = bucket_sum.gather(dim, gather_index)
    return shifted - (normalizer + eps).log()
def scatter_std(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
    unbiased: bool = True,
) -> Tensor:
    r"""Per-bucket standard deviation composite. Float inputs only.

    Ports ``torch_scatter.composite.scatter_std``: both passes (the mean and
    the sum of squared deviations) go through :func:`scatter_sum`, which
    avoids the integer floor-division coupling that would appear if
    :func:`scatter_mean` were used.

    Args:
        src: The source tensor.
        index: The bucket of each element.
        dim: The axis along which to index.
        out: Destination passed to the final :func:`scatter_sum`; if given,
            its size along :obj:`dim` overrides :obj:`dim_size`.
        dim_size: The number of buckets.
        unbiased: If :obj:`True`, divide the squared deviations by
            ``N - 1`` (clamped to at least ``1``) instead of ``N``.

    Returns:
        The per-bucket standard deviation.

    Raises:
        ValueError: If :obj:`src` is not a floating-point tensor.
    """
    if not src.is_floating_point():
        raise ValueError(
            'scatter_std requires a floating-point src tensor (got '
            f'{src.dtype})', )

    if out is not None:
        dim_size = out.size(dim)

    expanded_index = _broadcast(index, src, dim)

    # First pass: bucket sizes and bucket means.
    bucket_sizes = scatter_sum(torch.ones_like(src), expanded_index, dim,
                               dim_size=dim_size)
    bucket_totals = scatter_sum(src, expanded_index, dim, dim_size=dim_size)
    # Clamp so that empty buckets do not divide by zero.
    nonzero_sizes = bucket_sizes.clamp(min=1)
    bucket_means = bucket_totals / nonzero_sizes

    # Second pass: sum of squared deviations from the bucket mean.
    deviations = src - bucket_means.gather(dim, expanded_index)
    squared_sums = scatter_sum(deviations * deviations, expanded_index, dim,
                               out, dim_size)

    divisor = (bucket_sizes - 1).clamp(min=1) if unbiased else nonzero_sizes
    return (squared_sums / divisor).sqrt()
def scatter_logsumexp(
    src: Tensor,
    index: Tensor,
    dim: int = -1,
    out: Optional[Tensor] = None,
    dim_size: Optional[int] = None,
    eps: float = 1e-12,
) -> Tensor:
    r"""Per-bucket log-sum-exp composite. Float inputs only.

    Numerically stable: the values are recentered by the maximum of their
    bucket before being exponentiated.

    Without :obj:`out`, every non-finite result (such as the ``-inf`` of an
    empty bucket) is mapped to ``0``. With :obj:`out`, positions that ended up
    non-finite keep the value the caller supplied in :obj:`out`; all other
    positions of :obj:`out` are overwritten with the result.

    Args:
        src: The source tensor.
        index: The bucket of each element.
        dim: The axis along which to index.
        out: Optional destination tensor; if given, its size along :obj:`dim`
            overrides :obj:`dim_size`.
        dim_size: The number of buckets. Taken from ``index.max() + 1`` (or
            ``0`` for an empty :obj:`index`) if neither :obj:`out` nor
            :obj:`dim_size` is given.
        eps: Added to the per-bucket sum before taking the logarithm.

    Returns:
        The per-bucket log-sum-exp (:obj:`out` itself if it was given).

    Raises:
        ValueError: If :obj:`src` is not a floating-point tensor.
    """
    if not src.is_floating_point():
        raise ValueError(
            'scatter_logsumexp requires a floating-point src tensor (got '
            f'{src.dtype})', )

    if out is not None:
        dim_size = out.size(dim)

    # Number of buckets along `dim`:
    if dim_size is not None:
        num_buckets = dim_size
    elif index.numel() == 0:
        num_buckets = 0
    else:
        num_buckets = int(index.max().item()) + 1

    # Start the per-bucket maximum at -inf so that empty buckets stay at -inf
    # and are dealt with at the very end.
    max_shape = list(src.size())
    max_shape[dim] = num_buckets
    bucket_max = torch.full(
        max_shape,
        float('-inf'),
        dtype=src.dtype,
        device=src.device,
    )
    # The return value is not used: the result is read from `bucket_max`,
    # which is handed over as destination tensor.
    scatter_max(src, index, dim, bucket_max, dim_size)

    # `gather` only reads buckets that occur in `index`, so the -inf slots of
    # empty buckets are never read here.
    shifted = src - bucket_max.gather(dim, _broadcast(index, src, dim))
    # `inf - inf` (or `-inf - -inf`) gives NaN; treat those entries as -inf.
    shifted = torch.where(
        torch.isnan(shifted),
        torch.full_like(shifted, float('-inf')),
        shifted,
    )

    bucket_sum = scatter_sum(shifted.exp(), index, dim, dim_size=dim_size)
    result = bucket_max + (bucket_sum + eps).log()

    if out is not None:
        # Keep the caller's value wherever the result is non-finite (e.g.
        # empty buckets) and write the result everywhere else.
        previous = out.clone()
        non_finite = ~torch.isfinite(result)
        out.copy_(torch.where(non_finite, previous, result))
        return out

    # Empty buckets had max=-inf and sum=0, i.e. -inf + log(eps) = -inf.
    # Every non-finite output is mapped to 0.
    return result.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)
def spline_basis(
    pseudo: Tensor,
    kernel_size: Tensor,
    is_open_spline: Tensor,
    degree: int = 1,
) -> Tuple[Tensor, Tensor]:
    r"""Evaluates the B-spline basis functions.

    Args:
        pseudo: Pseudo-coordinates of shape :obj:`[E, D]`.
        kernel_size: Kernel size in each dimension of shape :obj:`[D]`.
        is_open_spline: Whether to use open B-splines of shape :obj:`[D]`.
        degree: B-spline degree (1, 2, or 3).

    Returns:
        Basis values of shape :obj:`[E, S]` and weight indices of shape
        :obj:`[E, S]`.
    """
    basis_and_index = torch.ops.pyg.spline_basis(pseudo, kernel_size,
                                                 is_open_spline, degree)
    return basis_and_index
def spline_weighting(
    x: Tensor,
    weight: Tensor,
    basis: Tensor,
    weight_index: Tensor,
) -> Tensor:
    r"""Applies the spline weighting to the input features.

    Args:
        x: Input features of shape :obj:`[E, M_in]`.
        weight: Weight tensor of shape :obj:`[K, M_in, M_out]`.
        basis: B-spline basis values of shape :obj:`[E, S]`.
        weight_index: Weight indices of shape :obj:`[E, S]`.

    Returns:
        Output features of shape :obj:`[E, M_out]`.
    """
    weighted = torch.ops.pyg.spline_weighting(x, weight, basis, weight_index)
    return weighted
def grid_cluster(
    pos: Tensor,
    size: Tensor,
    start: Optional[Tensor] = None,
    end: Optional[Tensor] = None,
) -> Tensor:
    r"""Groups all points in :obj:`pos` into voxels of size :obj:`size`.

    Every point receives the cluster index of the voxel it falls into. The
    voxel grid is defined by :obj:`size` and optionally bounded by
    :obj:`start` and :obj:`end`.

    Args:
        pos: Point positions of shape :obj:`[N, D]`.
        size: Voxel size in each dimension of shape :obj:`[D]`.
        start: Start of the voxel grid in each dimension of shape :obj:`[D]`.
            If :obj:`None`, uses the minimum of :obj:`pos`.
        end: End of the voxel grid in each dimension of shape :obj:`[D]`.
            If :obj:`None`, uses the maximum of :obj:`pos`.

    Returns:
        Cluster index for each point of shape :obj:`[N]`.
    """
    cluster = torch.ops.pyg.grid_cluster(pos, size, start, end)
    return cluster
def fps(
    src: Tensor,
    ptr: Tensor,
    ratio: float = 0.5,
    random_start: bool = True,
) -> Tensor:
    r"""Runs greedy farthest point sampling.

    Beginning with a random point (or the first point), the point that is
    farthest from the already selected set is selected iteratively.

    Args:
        src: Point positions of shape :obj:`[N, D]`.
        ptr: Batch boundaries as a CSR pointer of shape :obj:`[B + 1]`.
        ratio: Fraction of points to sample from each batch (in :obj:`(0, 1]`).
        random_start: If :obj:`True`, starts from a random point.

    Returns:
        Indices of the sampled points of shape :obj:`[M]`.
    """
    sampled = torch.ops.pyg.fps(src, ptr, ratio, random_start)
    return sampled
def knn(
    x: Tensor,
    y: Tensor,
    k: int = 1,
    ptr_x: Optional[Tensor] = None,
    ptr_y: Optional[Tensor] = None,
    cosine: bool = False,
    num_workers: int = 1,
) -> Tensor:
    r"""Searches, for each element in :obj:`y`, the :obj:`k` nearest points in
    :obj:`x`.

    Args:
        x: Reference points of shape :obj:`[N, D]`.
        y: Query points of shape :obj:`[M, D]`.
        k: Number of nearest neighbors.
        ptr_x: Batch boundaries for :obj:`x` as a CSR pointer.
        ptr_y: Batch boundaries for :obj:`y` as a CSR pointer.
        cosine: If :obj:`True`, uses cosine distance (CUDA only).
        num_workers: Number of workers (unused, for API compat).

    Returns:
        Edge indices of shape :obj:`[2, M*k]` where row 0 is query indices
        and row 1 is reference indices.
    """
    # The op is called with both pointers *before* `k`, unlike this
    # function's own signature.
    edge_index = torch.ops.pyg.knn(x, y, ptr_x, ptr_y, k, cosine, num_workers)
    return edge_index
def radius(
    x: Tensor,
    y: Tensor,
    r: float = 1.0,
    ptr_x: Optional[Tensor] = None,
    ptr_y: Optional[Tensor] = None,
    max_num_neighbors: int = 32,
    num_workers: int = 1,
    ignore_same_index: bool = False,
) -> Tensor:
    r"""Searches all points in :obj:`x` that lie within distance :obj:`r` of
    the points in :obj:`y`.

    Args:
        x: Reference points of shape :obj:`[N, D]`.
        y: Query points of shape :obj:`[M, D]`.
        r: Radius.
        ptr_x: Batch boundaries for :obj:`x` as a CSR pointer.
        ptr_y: Batch boundaries for :obj:`y` as a CSR pointer.
        max_num_neighbors: Maximum number of neighbors per query point.
        num_workers: Number of workers (unused, for API compat).
        ignore_same_index: If :obj:`True`, ignores pairs with same index.

    Returns:
        Edge indices of shape :obj:`[2, E]` where row 0 is query indices
        and row 1 is reference indices.
    """
    # The op is called with both pointers *before* `r`, unlike this
    # function's own signature.
    edge_index = torch.ops.pyg.radius(x, y, ptr_x, ptr_y, r,
                                      max_num_neighbors, num_workers,
                                      ignore_same_index)
    return edge_index
def nearest(
    x: Tensor,
    y: Tensor,
    ptr_x: Optional[Tensor] = None,
    ptr_y: Optional[Tensor] = None,
) -> Tensor:
    r"""Searches, for each point in :obj:`x`, the nearest point in :obj:`y`.

    Args:
        x: Query points of shape :obj:`[N, D]`.
        y: Reference points of shape :obj:`[M, D]`.
        ptr_x: Batch boundaries for :obj:`x` as a CSR pointer.
        ptr_y: Batch boundaries for :obj:`y` as a CSR pointer.

    Returns:
        Index tensor of shape :obj:`[N]` with the index of the nearest
        point in :obj:`y` for each point in :obj:`x`.
    """
    nearest_index = torch.ops.pyg.nearest(x, y, ptr_x, ptr_y)
    return nearest_index
def graclus_cluster(
    rowptr: Tensor,
    col: Tensor,
    weight: Optional[Tensor] = None,
) -> Tensor:
    r"""Runs a greedy graph clustering via the Graclus algorithm.

    Nodes are matched greedily in random order. A matched pair (u, v) gets
    the cluster ID :obj:`min(u, v)`; an unmatched node gets its own index as
    cluster ID.

    Args:
        rowptr: CSR row pointer of shape :obj:`[N + 1]`.
        col: Column indices of shape :obj:`[E]`.
        weight: Optional edge weights of shape :obj:`[E]`.

    Returns:
        Cluster assignment of shape :obj:`[N]`.
    """
    cluster = torch.ops.pyg.graclus_cluster(rowptr, col, weight)
    return cluster
def edge_sample(
    start: Tensor,
    rowptr: Tensor,
    count: int = 0,
    factor: float = 1.0,
) -> Tensor:
    r"""Draws a sample of the edges incident to the given start nodes.

    Up to :obj:`count` edges are sampled per start node. If
    :obj:`count < 1`, :obj:`ceil(factor * degree)` edges are sampled instead.

    Args:
        start: Start node indices of shape :obj:`[S]`.
        rowptr: CSR row pointer of shape :obj:`[N + 1]`.
        count: Fixed number of edges to sample per node. If :obj:`< 1`,
            uses :obj:`factor` instead.
        factor: Fraction of edges to sample when :obj:`count < 1`.

    Returns:
        Sampled edge indices (into the edge list).
    """
    sampled_edges = torch.ops.pyg.edge_sample(start, rowptr, count, factor)
    return sampled_edges
# Public API of this module: the names exported by a star-import. The
# underscore-prefixed helpers and the `GroupedMatmul` autograd function are
# deliberately not listed.
__all__ = [
'grouped_matmul',
'segment_matmul',
'sampled_add',
'sampled_sub',
'sampled_mul',
'sampled_div',
'index_sort',
'softmax_csr',
'scatter_sum',
'scatter_add',
'scatter_mul',
'scatter_mean',
'scatter_min',
'scatter_max',
'segment_sum_coo',
'segment_add_coo',
'segment_mean_coo',
'segment_min_coo',
'segment_max_coo',
'gather_coo',
'segment_sum_csr',
'segment_add_csr',
'segment_mean_csr',
'segment_min_csr',
'segment_max_csr',
'gather_csr',
'scatter',
'segment_coo',
'segment_csr',
'scatter_softmax',
'scatter_log_softmax',
'scatter_std',
'scatter_logsumexp',
'spline_basis',
'spline_weighting',
'grid_cluster',
'fps',
'knn',
'radius',
'nearest',
'graclus_cluster',
'edge_sample',
]