# Source code for brainpy._src.math.op_register.taichi_aot_based

import contextlib
import hashlib
import inspect
import io
import os
import pathlib
import platform
import re
import shutil
from functools import partial, reduce
from typing import Any, Sequence, Union

import jax.core
import numpy as np
from jax.interpreters import xla, mlir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call

from brainpy._src.dependency_check import (import_taichi,
                                           import_brainpylib_cpu_ops,
                                           import_brainpylib_gpu_ops)
from brainpy.errors import PackageMissingError
from .utils import _shape_to_layout

# Path of Taichi's own cache directory. It stays ``None`` until
# ``clear_taichi_aot_caches()`` is asked to clear everything, at which point it
# is filled in from ``import_ti_python_core().get_repo_dir()``.
taichi_cache_path = None

# --- UTILS ###

# get the path of home directory on Linux, Windows, Mac
def get_home_dir():
  """Return the current user's home directory as a plain string."""
  home = pathlib.Path.home()
  return str(home)

# encode a string with md5
def encode_md5(source: str) -> str:
  """Return the hexadecimal MD5 digest of ``source``.

  Parameters
  ----------
  source: str
    The text to hash. It is encoded as UTF-8 before hashing.

  Returns
  -------
  digest: str
    The 32-character lowercase hexadecimal MD5 digest.
  """
  # create md5 object
  md5 = hashlib.md5()

  # encode source
  source_encode = source.encode(encoding='utf-8')

  # update md5 object; without this step every input would produce the
  # digest of the empty string
  md5.update(source_encode)

  return md5.hexdigest()

# check kernels count
def count_taichi_aot_kernels() -> int:
  """
  Count the number of AOT compiled kernels.

  The cache root ``kernels_aot_path`` holds one directory per kernel name and
  every entry inside such a directory is counted as one compiled kernel.
  Entries of the cache root that are not directories are ignored.

  Returns
  -------
  kernels_count: int
    The number of AOT compiled kernels. ``0`` when the cache root is missing.
  """
  if not os.path.isdir(kernels_aot_path):
    return 0

  kernels_count = 0
  for name in os.listdir(kernels_aot_path):
    kernel_dir = os.path.join(kernels_aot_path, name)
    # a stray file in the cache root would make ``os.listdir`` raise
    if os.path.isdir(kernel_dir):
      kernels_count += len(os.listdir(kernel_dir))
  return kernels_count
def clear_taichi_aot_caches(kernels: Union[str, Sequence[str]] = None):
  """
  Clean the cache of the AOT compiled kernels.

  Parameters
  ----------
  kernels: str or list/tuple of str
    The name of the kernel to be cleaned. If None, all the kernels will be
    cleaned, together with the Taichi cache directory.

  Raises
  ------
  TypeError
    If ``kernels`` is neither ``None``, a ``str``, a ``list`` nor a ``tuple``.
  """
  global taichi_cache_path

  if kernels is None:
    if taichi_cache_path is None:
      # imported lazily so that taichi is only needed when everything is cleared
      from taichi._lib.utils import import_ti_python_core
      taichi_cache_path = import_ti_python_core().get_repo_dir()
    # clean taichi cache
    if os.path.exists(taichi_cache_path):
      shutil.rmtree(taichi_cache_path)
    # clean brainpy-taichi AOT cache
    if os.path.exists(kernels_aot_path):
      shutil.rmtree(kernels_aot_path)
    return

  if isinstance(kernels, str):
    kernels = [kernels]
  # the annotation promises ``Sequence[str]``, so tuples are accepted as well
  if not isinstance(kernels, (list, tuple)):
    raise TypeError(f'kernels_name must be a list of str, but got {type(kernels)}')

  # clear brainpy kernel cache
  for kernel_name in kernels:
    kernel_dir = os.path.join(kernels_aot_path, kernel_name)
    if os.path.exists(kernel_dir):
      shutil.rmtree(kernel_dir)
# TODO
# not a very good way
# get source with dependencies
def get_source_with_dependencies(func, visited=None):
  """Return the source of ``func`` followed by the source of what it calls.

  Every identifier that appears directly before ``(`` in the source text is
  looked up as an attribute of the module ``func`` belongs to; callables found
  that way are appended recursively.

  Parameters
  ----------
  func: callable
    The function whose source is collected.
  visited: set, optional
    Functions already collected. Used to stop at repeated or recursive calls.

  Returns
  -------
  source: str
    The concatenated source code, or ``''`` if ``func`` was already visited.
  """
  if visited is None:
    visited = set()
  if func in visited:
    return ''

  source = inspect.getsource(func)
  visited.add(func)
  module = inspect.getmodule(func)
  dependent_funcs = re.findall(r'(\w+)\(', source)

  for func_name in dependent_funcs:
    dependent_func = getattr(module, func_name, None)
    if callable(dependent_func):
      source += get_source_with_dependencies(dependent_func, visited)
  return source


# check if Metal is supported
def is_metal_supported():
  """Return True on macOS (``Darwin``) when the processor reports ``'arm'``."""
  # first check if we are on macOS
  if platform.system() != 'Darwin':
    return False
  if platform.processor() != 'arm':
    return False
  return True


# --- VARIABLES ###
home_path = get_home_dir()
kernels_aot_path = os.path.join(home_path, '.brainpy', 'kernels')
is_metal_device = is_metal_supported()


# check if a kernel exists in the database
def _check_kernel_exist(source_md5_encode: str) -> bool:
  """Return whether ``kernels_aot_path/source_md5_encode`` exists on disk."""
  # get the realpath of the kernel
  kernel_path = os.path.join(kernels_aot_path, source_md5_encode)

  # check whether the kernel exists
  return os.path.exists(kernel_path)


# --- KERNEL AOT BUILD ###


def _array_to_field(dtype, shape) -> Any:
  """Create a Taichi field with the Taichi dtype matching ``dtype``.

  Parameters
  ----------
  dtype:
    A NumPy dtype (or NumPy scalar type) describing the element type.
  shape:
    The shape passed on to ``ti.field``.

  Returns
  -------
  field:
    The result of ``ti.field(dtype=..., shape=shape)``.

  Raises
  ------
  NotImplementedError
    If ``dtype`` matches none of the supported NumPy types.
  """
  ti = import_taichi()
  # checked in order with ``==`` so that both ``np.dtype`` instances and
  # NumPy scalar types are matched
  supported = (
    (np.bool_, bool),
    (np.int8, ti.int8),
    (np.int16, ti.int16),
    (np.int32, ti.int32),
    (np.int64, ti.int64),
    (np.uint8, ti.uint8),
    (np.uint16, ti.uint16),
    (np.uint32, ti.uint32),
    (np.uint64, ti.uint64),
    (np.float16, ti.float16),
    (np.float32, ti.float32),
    (np.float64, ti.float64),
  )
  for np_type, ti_type in supported:
    if dtype == np_type:
      return ti.field(dtype=ti_type, shape=shape)
  raise NotImplementedError(f'Currently we do not support dtype {dtype} in Taichi. '
                            f'If you think it is necessary, please open an issue at '
                            f'https://github.com/brainpy/BrainPy/issues')


# build aot kernel
def _build_kernel(
    source_md5_encode: str,
    kernel: callable,
    ins: dict,
    outs: dict,
    device: str
):
  """Compile ``kernel`` ahead of time and save it into the kernel cache.

  Parameters
  ----------
  source_md5_encode: str
    Path of the kernel relative to ``kernels_aot_path``.
  kernel: callable
    The Taichi kernel to compile.
  ins: dict
    Maps each input argument name to a ``(dtype, shape)`` pair.
  outs: dict
    Maps each output argument name to a ``(dtype, shape)`` pair.
  device: str
    ``'cpu'`` or ``'gpu'``.

  Raises
  ------
  ValueError
    If ``device`` is neither ``'cpu'`` nor ``'gpu'``.
  RuntimeError
    If Taichi did not initialise with the requested arch.
  """
  ti = import_taichi()

  # init arch
  if device == 'cpu':
    if is_metal_device:
      arch = ti.arm64
      device = 'arm64'
    else:
      arch = ti.x64
  elif device == 'gpu':
    arch = ti.cuda
  else:
    raise ValueError(f'Unknown device: {device}')
  with contextlib.redirect_stdout(io.StringIO()):
    ti.init(arch=arch)

  # check arch is available
  if ti.lang.impl.current_cfg().arch != arch:
    raise RuntimeError(f"Arch {arch} is not available")

  # get kernel name
  kernel_name = kernel.__name__

  # replace the name of the func
  kernel.__name__ = f'taichi_kernel_{device}'
  try:
    # init template_args_dict
    template_args_dict = {}
    for key, value in ins.items():
      template_args_dict[key] = _array_to_field(value[0], value[1])
    for key, value in outs.items():
      template_args_dict[key] = _array_to_field(value[0], value[1])

    # make aot dir
    kernel_path = os.path.join(kernels_aot_path, source_md5_encode)
    os.makedirs(kernel_path, exist_ok=True)

    # compile kernel
    mod = ti.aot.Module(arch)
    mod.add_kernel(kernel, template_args=template_args_dict)
    # write the compiled module into the cache directory; without this the
    # directory stays empty while still being treated as a valid cache entry
    mod.save(kernel_path)
  finally:
    # rename kernel name, also when the build fails
    kernel.__name__ = kernel_name


# --- KERNEL CALL PREPROCESS ###

# convert type to number
type_number_map = {
  int: 0,
  float: 1,
  bool: 2,
  np.dtype('int32'): 0,
  np.dtype('float32'): 1,
  np.dtype('bool'): 2,
  np.dtype('uint8'): 3,
  np.dtype('uint16'): 4,
  np.dtype('uint32'): 5,
  np.dtype('uint64'): 6,
  np.dtype('int8'): 7,
  np.dtype('int16'): 8,
  np.dtype('int64'): 9,
  np.dtype('float16'): 10,
  np.dtype('float64'): 11,
}


# preprocess kernel call cpu
def _preprocess_kernel_call_cpu(
    source_md5_encode: str,
    ins: Sequence,
    outs: Sequence,
) -> list:
  """Build the metadata arrays passed to the CPU custom call.

  Parameters
  ----------
  source_md5_encode: str
    Path of the kernel relative to ``kernels_aot_path``.
  ins, outs: Sequence
    Abstract inputs/outputs exposing ``dtype``, ``ndim``, ``size``, ``shape``.

  Returns
  -------
  in_out_info: list
    Six NumPy arrays, in this order: ``[len(ins), len(outs), path length]``,
    the type codes, the dimension counts, the element counts, the shape table
    (one row per array, zero padded to the largest ``ndim``) and the
    null-terminated UTF-8 kernel path as ``uint8``.
  """
  arrays = [*ins, *outs]
  total = len(arrays)
  max_dim_count = max((value.ndim for value in arrays), default=0)

  # kernel_path
  kernel_path = os.path.join(kernels_aot_path, source_md5_encode)
  kernel_path = bytes(kernel_path, encoding='utf-8') + b'\0'
  kernel_path = np.array(list(kernel_path), dtype=np.uint8)

  # other args
  in_out_num = np.array([len(ins), len(outs), kernel_path.size], dtype=np.uint32)
  in_out_type_list = np.zeros((total,), dtype=np.uint32)
  in_out_dim_count_list = np.zeros((total,), dtype=np.uint32)
  in_out_elem_count_list = np.zeros((total,), dtype=np.uint32)
  in_out_shape_list = np.zeros((total, max_dim_count), dtype=np.uint32)

  # inputs come first, outputs follow
  for i, value in enumerate(arrays):
    in_out_type_list[i] = type_number_map[value.dtype]
    in_out_dim_count_list[i] = value.ndim
    in_out_elem_count_list[i] = value.size
    for j, dim in enumerate(value.shape):
      in_out_shape_list[i, j] = dim

  return [
    in_out_num,
    in_out_type_list,
    in_out_dim_count_list,
    in_out_elem_count_list,
    in_out_shape_list,
    kernel_path,
  ]


def _preprocess_kernel_call_gpu(
    source_md5_encode: str,
    ins: Sequence,
    outs: Sequence,
) -> bytes:
  """Build the opaque descriptor passed to the GPU custom call.

  Parameters
  ----------
  source_md5_encode: str
    Path of the kernel relative to ``kernels_aot_path``.
  ins, outs: Sequence
    Abstract inputs/outputs exposing ``dtype``, ``ndim``, ``size``, ``shape``.

  Returns
  -------
  opaque: bytes
    UTF-8 text made of ``;``-separated fields: the in/out counts, the type
    codes, the dimension counts, the element counts, the shape table (eight
    slots per array) and the kernel path. The first five fields are
    ``,``-separated integers.

  Raises
  ------
  ValueError
    If an array has more dimensions than the shape table has slots for it.
  """
  # if len(ins) + len(outs) > 8:
  #   raise ValueError('The number of ins and outs must be less than 8!')
  kernel_path = os.path.join(kernels_aot_path, source_md5_encode)

  # number of shape slots reserved for every array
  shape_slots = 8

  # other args
  arrays = [*ins, *outs]
  param_total_num = len(arrays)
  in_out_num = [len(ins), len(outs)]
  in_out_type_list = [0] * param_total_num
  in_out_dim_count_list = [0] * param_total_num
  in_out_elem_count_list = [0] * param_total_num
  in_out_shape_list = [0] * param_total_num * shape_slots

  # inputs come first, outputs follow
  for i, value in enumerate(arrays):
    if value.ndim > shape_slots:
      # more dims would spill into the slots of the next array
      raise ValueError(f'The shape table reserves {shape_slots} slots per array, '
                       f'but got an array with {value.ndim} dimensions.')
    in_out_type_list[i] = type_number_map[value.dtype]
    in_out_dim_count_list[i] = value.ndim
    in_out_elem_count_list[i] = value.size
    for j, dim in enumerate(value.shape):
      in_out_shape_list[i * shape_slots + j] = dim

  # convert to string
  fields = (
    in_out_num,
    in_out_type_list,
    in_out_dim_count_list,
    in_out_elem_count_list,
    in_out_shape_list,
  )
  encoded = [bytes(",".join(str(i) for i in field), encoding='utf-8') for field in fields]
  encoded.append(bytes(kernel_path, encoding='utf-8'))
  return b';'.join(encoded)


def _XlaOp_to_ShapedArray(c, xla_op):
  """Turn an XLA op into a ``jax.core.ShapedArray`` using ``c.get_shape``."""
  xla_op = c.get_shape(xla_op)
  return jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type())


def _mlir_to_ShapedArray(c, op):
  """Return ``op`` unchanged."""
  return op


def _kernel_to_code(kernel, abs_ins, abs_outs, platform):
  """Describe a kernel plus its argument signature as text.

  The text holds the platform, the kernel source with its dependencies, and
  the dtype and shape of every input and output.
  """
  codes = f'[taichi {platform} kernel]\n' + get_source_with_dependencies(kernel)
  codes += '\n[ins]: {}'.format("-".join([f'{v.dtype}[{v.shape}]' for v in abs_ins]))
  codes += '\n[outs]: {}'.format("-".join([f'{v.dtype}[{v.shape}]' for v in abs_outs]))
  return codes


def _compile_kernel(abs_ins, kernel, platform: str, **kwargs):
  """Make sure ``kernel`` is built and return the data for its custom call.

  Parameters
  ----------
  abs_ins:
    Abstract input values exposing ``dtype`` and ``shape``.
  kernel: callable
    The Taichi kernel. Its first ``len(abs_ins)`` parameters are treated as
    inputs and the remaining ones as outputs.
  platform: str
    ``'cpu'``, ``'gpu'`` or ``'cuda'``.
  **kwargs:
    Must contain ``outs``, the abstract output values.

  Returns
  -------
  info: bytes or list
    The opaque bytes for GPU platforms, the list of metadata arrays for CPU.

  Raises
  ------
  RuntimeError
    If the kernel cannot be built.
  ValueError
    If ``platform`` is unknown.
  """
  # input and output abstract information
  abs_outs = kwargs['outs']

  # kernel to code
  codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform)
  source_md5_encode = os.path.join(kernel.__name__, encode_md5(codes))

  # create ins, outs dict from kernel's args
  in_num = len(abs_ins)
  names = tuple(inspect.signature(kernel).parameters.keys())
  in_names, out_names = names[:in_num], names[in_num:]
  ins_dict = {key: (abs_ins[i].dtype, abs_ins[i].shape) for i, key in enumerate(in_names)}
  outs_dict = {key: (abs_outs[i].dtype, abs_outs[i].shape) for i, key in enumerate(out_names)}

  # build kernels
  if not _check_kernel_exist(source_md5_encode):
    # TODO: more checking
    try:
      _build_kernel(source_md5_encode, kernel, ins_dict, outs_dict, platform)
    except Exception as e:
      kernel_dir = os.path.join(kernels_aot_path, source_md5_encode)
      if not os.path.isdir(kernel_dir):
        # the failure happened before the cache directory was created
        raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e
      # drop the partial build so it is not mistaken for a valid cache entry
      shutil.rmtree(kernel_dir, ignore_errors=True)
      raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e

  # returns
  if platform in ['gpu', 'cuda']:
    import_brainpylib_gpu_ops()
    opaque = _preprocess_kernel_call_gpu(source_md5_encode, abs_ins, abs_outs)
    return opaque
  elif platform == 'cpu':
    import_brainpylib_cpu_ops()
    in_out_info = _preprocess_kernel_call_cpu(source_md5_encode, abs_ins, abs_outs)
    return in_out_info
  else:
    raise ValueError(f'Unknown platform: {platform}')


def _get_abs_ins(c, ins):
  """Convert every XLA operand in ``ins`` into a ``jax.core.ShapedArray``."""
  return [_XlaOp_to_ShapedArray(c, v) for v in ins]


def _taichi_xla_cpu_translation_rule(kernel, c, *ins, **kwargs):
  """XLA translation rule emitting the CPU custom call for ``kernel``.

  The metadata arrays are prepended to the operands as constants.
  """
  in_out_info = _compile_kernel(_get_abs_ins(c, ins), kernel, 'cpu', **kwargs)
  ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins)
  if is_metal_device:
    fn = b'taichi_kernel_aot_call_cpu_arm64'
  else:
    fn = b'taichi_kernel_aot_call_cpu'

  return xla_client.ops.CustomCallWithLayout(
    c,
    fn,
    operands=ins,
    operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
    shape_with_layout=xla_client.Shape.tuple_shape(
      [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
       for value in kwargs['outs']]
    ),
  )


def _taichi_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
  """XLA translation rule emitting the GPU custom call for ``kernel``.

  The kernel metadata is passed through the ``opaque`` argument.
  """
  opaque = _compile_kernel(_get_abs_ins(c, ins), kernel, 'gpu', **kwargs)

  return xla_client.ops.CustomCallWithLayout(
    c,
    b'taichi_kernel_aot_call_gpu',
    operands=ins,
    operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
    shape_with_layout=xla_client.Shape.tuple_shape(
      [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
       for value in kwargs['outs']]
    ),
    opaque=opaque,
  )


def register_taichi_aot_xla_cpu_translation_rule(primitive, cpu_kernel):
  """Register the XLA CPU translation rule of ``primitive`` for ``cpu_kernel``."""
  xla.backend_specific_translations['cpu'][primitive] = partial(_taichi_xla_cpu_translation_rule, cpu_kernel)


def register_taichi_aot_xla_gpu_translation_rule(primitive, gpu_kernel):
  """Register the XLA GPU translation rule of ``primitive`` for ``gpu_kernel``."""
  xla.backend_specific_translations['gpu'][primitive] = partial(_taichi_xla_gpu_translation_rule, gpu_kernel)


def _taichi_mlir_cpu_translation_rule(kernel, c, *ins, **kwargs):
  """MLIR lowering rule emitting the CPU custom call for ``kernel``.

  The metadata arrays are prepended to the operands as constants. The call
  target depends on ``is_metal_device`` and on whether there is exactly one
  output.
  """
  in_out_info = _compile_kernel(c.avals_in, kernel, 'cpu', **kwargs)
  ins = [mlir.ir_constant(v) for v in in_out_info] + list(ins)
  input_layouts = ([_shape_to_layout(arr.shape) for arr in in_out_info] +
                   [_shape_to_layout(a.shape) for a in c.avals_in])
  output_layouts = tuple([_shape_to_layout(out.shape) for out in c.avals_out])
  result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
  if is_metal_device:
    if len(output_layouts) == 1:
      fn = 'taichi_kernel_aot_call_cpu_arm64_single_result'
    else:
      fn = 'taichi_kernel_aot_call_cpu_arm64'
  else:
    if len(output_layouts) == 1:
      fn = 'taichi_kernel_aot_call_cpu_single_result'
    else:
      fn = 'taichi_kernel_aot_call_cpu'
  return custom_call(
    call_target_name=fn,
    operands=ins,
    operand_layouts=list(input_layouts),
    result_layouts=list(output_layouts),
    result_types=list(result_types),
    has_side_effect=False,
  ).results


def _taichi_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
  """MLIR lowering rule emitting the GPU custom call for ``kernel``.

  The kernel metadata is passed through ``backend_config``.
  """
  opaque = _compile_kernel(c.avals_in, kernel, 'gpu', **kwargs)
  input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
  result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
  output_layouts = [_shape_to_layout(out.shape) for out in c.avals_out]
  return custom_call(
    call_target_name='taichi_kernel_aot_call_gpu',
    operands=ins,
    operand_layouts=list(input_layouts),
    result_layouts=list(output_layouts),
    result_types=list(result_types),
    backend_config=opaque,
    has_side_effect=False,
  ).results


def register_taichi_aot_mlir_cpu_translation_rule(primitive, cpu_kernel):
  """Register the MLIR CPU lowering of ``primitive`` for ``cpu_kernel``.

  Raises
  ------
  PackageMissingError
    If taichi cannot be imported.
  """
  if import_taichi(error_if_not_found=False) is None:
    raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule')

  rule = partial(_taichi_mlir_cpu_translation_rule, cpu_kernel)
  mlir.register_lowering(primitive, rule, platform='cpu')


def register_taichi_aot_mlir_gpu_translation_rule(primitive, gpu_kernel):
  """Register the MLIR GPU lowering of ``primitive`` for ``gpu_kernel``.

  Raises
  ------
  PackageMissingError
    If taichi cannot be imported.
  """
  if import_taichi(error_if_not_found=False) is None:
    raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule')

  rule = partial(_taichi_mlir_gpu_translation_rule, gpu_kernel)
  mlir.register_lowering(primitive, rule, platform='gpu')