# Source code for brainpy._src.math.op_register.numba_approach.cpu_translation

# -*- coding: utf-8 -*-

import ctypes

from jax import dtypes, numpy as jnp
from jax.core import ShapedArray
from jax.lib import xla_client

from brainpy._src.dependency_check import import_numba

# Numba is treated as optional: the helper is called with
# ``error_if_not_found=False`` and the result is compared against ``None``
# below (presumably ``None`` is returned when Numba is missing -- confirm
# against ``import_numba``).
numba = import_numba(error_if_not_found=False)
# Declare the C signature of ``PyCapsule_New`` for ctypes, so that the call
# made in ``_cpu_signature`` converts its three arguments as listed here and
# hands back a Python object.
ctypes.pythonapi.PyCapsule_New.argtypes = [
  ctypes.c_void_p,  # void* pointer
  ctypes.c_char_p,  # const char *name
  ctypes.c_void_p,  # PyCapsule_Destructor destructor
]
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object

# Names exported by ``from <module> import *``.
__all__ = [
  '_cpu_translation',
  'compile_cpu_signature_with_numba',
]

# ``types``, ``carray`` and ``cfunc`` are only bound when Numba was imported;
# ``_cpu_signature`` relies on all three.
if numba is not None:
  from numba import types, carray, cfunc


def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info):
  """Emit a CPU custom call for ``func`` on the builder ``c``.

  The target name, the operand list and the operand/result layouts are all
  obtained from ``compile_cpu_signature_with_numba``; the keyword arguments
  collected in ``info`` are forwarded to it as its ``description`` dict.

  Args:
    func: Function forwarded to ``compile_cpu_signature_with_numba``.
    abs_eval_fn: Abstract-evaluation function, forwarded unchanged.
    multiple_results: Forwarded unchanged.
    c: Builder object passed to ``xla_client.ops.CustomCallWithLayout``.
    *inputs: Operands of the call.
    **info: Extra keyword parameters, forwarded as a dict.

  Returns:
    The result of ``xla_client.ops.CustomCallWithLayout``.
  """
  compiled = compile_cpu_signature_with_numba(
    c, func, abs_eval_fn, multiple_results, inputs, info
  )
  call_target, operands, operand_layouts, result_layout = compiled
  return xla_client.ops.CustomCallWithLayout(
    c,
    call_target,
    operands=operands,
    operand_shapes_with_layout=operand_layouts,
    shape_with_layout=result_layout,
  )


def _cpu_signature(
    func,
    input_dtypes,
    input_shapes,
    output_dtypes,
    output_shapes,
    multiple_results: bool,
    debug: bool = False
):
  """Compile ``func`` into a C callback and register it as an XLA CPU custom call.

  The source of a small wrapper
  ``xla_cpu_custom_call_target(output_ptrs, input_ptrs)`` is generated as
  text, executed, and compiled with ``cfunc``. The wrapper views the raw
  pointers as arrays through ``carray`` and then calls
  ``func(args_out, args_in)``.

  Args:
    func: Callable invoked as ``func(args_out, args_in)`` inside the wrapper.
    input_dtypes: Per-input dtypes, indexed in parallel with ``input_shapes``.
    input_shapes: Per-input shapes. With more than one entry, ``args_in`` is a
      tuple of arrays; otherwise it is the single array built from entry 0.
    output_dtypes: Per-output dtypes, indexed in parallel with
      ``output_shapes``.
    output_shapes: Per-output shapes.
    multiple_results: If true, ``output_ptrs`` is indexed per output and
      ``args_out`` is a tuple of arrays; otherwise ``output_ptrs`` is used as
      a single pointer and ``args_out`` is one array.
    debug: If true, print the generated wrapper source.

  Returns:
    The registered target name: the compiled callback's ``native_name``
    encoded as ASCII bytes.

  Raises:
    ModuleNotFoundError: If Numba could not be imported.
  """
  # ``carray``, ``cfunc`` and ``types`` are only imported at module level when
  # Numba is available. Fail with an actionable message instead of letting the
  # first use below raise an opaque NameError.
  if numba is None:
    raise ModuleNotFoundError(
      'Numba is required to compile CPU custom-call targets, '
      'but it could not be imported. Please install numba.'
    )

  # Globals visible to the generated wrapper.
  code_scope = dict(
    func_to_call=func,
    input_shapes=input_shapes,
    input_dtypes=input_dtypes,
    output_shapes=output_shapes,
    output_dtypes=output_dtypes,
    carray=carray,
  )

  # inputs
  if len(input_shapes) > 1:
    args_in = [
      f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
      for i in range(len(input_shapes))
    ]
    args_in = '(\n    ' + "\n    ".join(args_in) + '\n  )'
  else:
    args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])'

  # outputs
  if multiple_results:
    args_out = [
      f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
      for i in range(len(output_shapes))
    ]
    args_out = '(\n    ' + "\n    ".join(args_out) + '\n  )'
  else:
    args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'

  # function body
  code_string = '''
def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
  args_out = {args_out}
  args_in = {args_in}
  func_to_call(args_out, args_in)
    '''.format(args_in=args_in,
               args_out=args_out)
  if debug:
    print(code_string)
  exec(compile(code_string.strip(), '', 'exec'), code_scope)

  new_f = code_scope['xla_cpu_custom_call_target']
  # The C signature differs only in the first argument: a pointer to an array
  # of pointers for several outputs, a plain pointer for a single output.
  if multiple_results:
    xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr),
                                  types.CPointer(types.voidptr)))(new_f)
  else:
    xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f)
  target_name = xla_c_rule.native_name.encode("ascii")
  capsule = ctypes.pythonapi.PyCapsule_New(
    xla_c_rule.address,  # address of the compiled callback
    b"xla._CUSTOM_CALL_TARGET",  # capsule name
    None  # no destructor
  )
  xla_client.register_custom_call_target(target_name, capsule, "cpu")
  return target_name


def compile_cpu_signature_with_numba(
    c,
    func,
    abs_eval_fn,
    multiple_results,
    inputs: tuple,
    description: dict = None,
):
  """Build and register the CPU custom-call target for ``func``.

  The shapes of ``inputs`` are read from the builder ``c``. Every value in
  ``description`` is additionally turned into a constant operand appended
  after ``inputs``. The output shapes and dtypes come from ``abs_eval_fn``,
  which is called with one ``ShapedArray`` per entry of ``inputs`` (the
  constant operands are excluded) and with ``description`` as keyword
  arguments.

  Args:
    c: Builder object; used for ``c.get_shape`` and for creating constants.
    func: Function handed to ``_cpu_signature`` for compilation.
    abs_eval_fn: Abstract-evaluation function. It must return a single
      ``ShapedArray`` when ``multiple_results`` is false and a sequence of
      them when it is true.
    multiple_results: Whether the operator has several outputs.
    inputs: Operands of the call.
    description: Optional extra parameters. Values must be ``int``, ``float``,
      ``tuple`` or ``list``.

  Returns:
    A 4-tuple ``(target_name, operands, input_layouts, output_layouts)``:
    the registered target name, ``inputs`` followed by the constant operands,
    a tuple with one shape-with-layout per operand, and the output
    shape-with-layout (a tuple shape when ``multiple_results`` is true).

  Raises:
    TypeError: If a value in ``description`` has an unsupported type.
    AssertionError: If the kind of result returned by ``abs_eval_fn`` does not
      agree with ``multiple_results``.
  """
  input_layouts = [c.get_shape(arg) for arg in inputs]
  info_inputs = []
  if description is None:
    description = dict()
  for name, v in description.items():
    if isinstance(v, (int, float)):
      # Python scalars become rank-0 constants.
      input_layouts.append(
        xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())
      )
      info_inputs.append(xla_client.ops.ConstantLiteral(c, v))
    elif isinstance(v, (tuple, list)):
      v = jnp.asarray(v)
      # Layout lists the dimensions from last to first.
      input_layouts.append(
        xla_client.Shape.array_shape(
          v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1))
        )
      )
      info_inputs.append(xla_client.ops.Constant(c, v))
    else:
      raise TypeError(
        f'Unsupported type for description entry {name!r}: '
        f'{type(v).__name__}. Only int, float, tuple and list values '
        f'are supported.'
      )
  input_layouts = tuple(input_layouts)
  input_dtypes = tuple(shape.element_type() for shape in input_layouts)
  input_dimensions = tuple(shape.dimensions() for shape in input_layouts)

  # Only the real operands take part in abstract evaluation; the description
  # values are passed by keyword instead.
  abstract_inputs = tuple(
    ShapedArray(shape.dimensions(), shape.element_type())
    for shape in input_layouts[:len(inputs)]
  )
  output_abstract_arrays = abs_eval_fn(*abstract_inputs, **description)
  if isinstance(output_abstract_arrays, ShapedArray):
    output_abstract_arrays = (output_abstract_arrays,)
    assert not multiple_results, (
      'abs_eval_fn returned a single ShapedArray but multiple_results is True'
    )
  else:
    assert multiple_results, (
      'abs_eval_fn returned several results but multiple_results is False'
    )
  output_shapes = tuple(array.shape for array in output_abstract_arrays)
  output_dtypes = tuple(array.dtype for array in output_abstract_arrays)

  target_name = _cpu_signature(func,
                               input_dtypes,
                               input_dimensions,
                               output_dtypes,
                               output_shapes,
                               multiple_results,
                               debug=False)

  # Same last-to-first dimension order as used for the inputs.
  output_layouts = [
    xla_client.Shape.array_shape(dtype, shape, tuple(range(len(shape) - 1, -1, -1)))
    for dtype, shape in zip(output_dtypes, output_shapes)
  ]
  output_layouts = (xla_client.Shape.tuple_shape(output_layouts)
                    if multiple_results else
                    output_layouts[0])
  return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts