# -*- coding: utf-8 -*-
import inspect
import re
from types import LambdaType
# Public API of this module: the names exported by a star-import.
# Every entry corresponds to a function defined further down in this file.
__all__ = [
# decorator / code-generation helpers
'copy_doc',
'code_lines_to_func',
# tools for code string
'get_identifiers',
'indent',
'deindent',
'word_replace',
# other tools
'is_lambda_function',
'get_main_code',
'get_func_source',
'change_func_name',
]
def copy_doc(source_f):
    """Build a decorator that copies the docstring of ``source_f``.

    Parameters
    ----------
    source_f : object
        Object whose ``__doc__`` attribute is read each time the returned
        decorator is applied.

    Returns
    -------
    callable
        A decorator that overwrites ``__doc__`` of its argument with the
        docstring of ``source_f`` and returns that same argument.
    """

    def copy(target_f):
        # The docstring is looked up at decoration time, not when
        # ``copy_doc`` itself is called.
        setattr(target_f, '__doc__', getattr(source_f, '__doc__'))
        return target_f

    return copy
def code_lines_to_func(lines, func_name, func_args, scope, remind=''):
    """Compile a list of source lines into a function defined inside ``scope``.

    The lines are wrapped in a ``try``/``except`` so that any exception raised
    while the generated function runs is re-raised as a ``ValueError`` that
    contains the numbered source of the generated function and ``remind``.

    Parameters
    ----------
    lines : list of str
        Body of the function, one statement per entry. Entries may carry their
        own relative indentation (e.g. the body of an ``if``).
    func_name : str
        Name of the generated function.
    func_args : list of str
        Argument names (joined with ``", "`` to form the signature).
    scope : dict
        Globals used to execute the generated code. It is mutated: the keys
        ``'code_for_debug'`` and ``'remind'`` are written, and the generated
        function is stored under ``func_name``.
    remind : str, optional
        Extra hint appended to the runtime error message.

    Returns
    -------
    code : str
        Source of the generated function.
    func : callable
        The compiled function, i.e. ``scope[func_name]``.

    Raises
    ------
    ValueError
        If the generated source cannot be compiled or executed.
    """
    # The body must sit one level deeper than the ``try:`` line, otherwise
    # the generated source is not valid Python.
    body = '\n'.join(f'    {line}' for line in lines)
    code = f'def {func_name}({", ".join(func_args)}):\n'
    code += '  try:\n'
    code += f'{body}\n'
    code += '  except Exception as e:\n'
    # Use the traceback attached to the exception so that the generated code
    # does not depend on ``sys`` being available in ``scope``.
    code += '    line_no = e.__traceback__.tb_lineno\n'
    code += '    raise ValueError(f"Error occurred in line {line_no}: {code_for_debug} {str(e)} {remind}")'

    # Numbered listing shown in the runtime error message; the numbers match
    # the ``line_no`` reported by the traceback.
    numbered = [f'[{i + 1:3d}] {line}' for i, line in enumerate(code.split('\n'))]
    code_for_debug = '\n'.join(numbered)

    # The generated ``except`` clause looks these two names up as globals.
    scope['code_for_debug'] = '\n\n' + code_for_debug + '\n\n'
    scope['remind'] = '\n' + remind + '\n'

    try:
        exec(compile(code, '', 'exec'), scope)
    except Exception as e:
        raise ValueError(f'Compilation function error: \n\n{code}') from e
    func = scope[func_name]
    return code, func
######################################
# String tools
######################################
# Words that look like identifiers but are Python operators / constants.
_IDENTIFIER_KEYWORDS = frozenset({'and', 'or', 'not', 'True', 'False'})
# A name may contain dots, so an attribute access such as ``a.b`` is one match.
_IDENTIFIER_PATTERN = re.compile(r'\b[A-Za-z_][A-Za-z0-9_.]*\b')
# Unsigned number literal that does not start inside a name or another number
# (the lookbehind keeps the ``0`` of ``x10`` from being reported as a number).
_NUMBER_PATTERN = re.compile(r'(?<![A-Za-z0-9_])[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?')


def get_identifiers(expr, include_numbers=False):
    """
    Return all the identifiers in a given string ``expr``, that is everything
    that matches a programming language variable like expression, which is
    here implemented as the regexp ``\\b[A-Za-z_][A-Za-z0-9_.]*\\b``. Because
    the dot is accepted inside a name, a dotted name such as ``a.b`` is
    returned as one identifier. The words ``and``, ``or``, ``not``, ``True``
    and ``False`` are never reported.

    Parameters
    ----------
    expr : str
        The string to analyze
    include_numbers : bool, optional
        Whether to include number literals in the output. Defaults to ``False``.

    Returns
    -------
    identifiers : set
        A set of all the identifiers (and, optionally, numbers) in `expr`.

    Examples
    --------
    >>> expr = '3-a*_b+c5+8+f(A - .3e-10, tau_2)*17'
    >>> ids = get_identifiers(expr)
    >>> print(sorted(list(ids)))
    ['A', '_b', 'a', 'c5', 'f', 'tau_2']
    >>> ids = get_identifiers(expr, include_numbers=True)
    >>> print(sorted(list(ids)))
    ['.3e-10', '17', '3', '8', 'A', '_b', 'a', 'c5', 'f', 'tau_2']
    """
    identifiers = set(_IDENTIFIER_PATTERN.findall(expr)) - _IDENTIFIER_KEYWORDS
    if not include_numbers:
        return identifiers
    # Only the number itself is captured, never a leading ``+`` or ``-``.
    return identifiers | set(_NUMBER_PATTERN.findall(expr))
def indent(text, num_tabs=1, spaces_per_tab=4, tab=None):
    """Prefix every line of ``text`` with ``num_tabs`` indentation units.

    Parameters
    ----------
    text : str
        The text to indent; lines are separated by ``'\\n'``.
    num_tabs : int, optional
        How many indentation units to prepend to each line.
    spaces_per_tab : int, optional
        Width of one unit in spaces, used only when ``tab`` is ``None``.
    tab : str, optional
        Explicit indentation unit that replaces the space-based default.

    Returns
    -------
    str
        The indented text.
    """
    unit = ' ' * spaces_per_tab if tab is None else tab
    prefix = unit * num_tabs
    return '\n'.join(prefix + row for row in text.split('\n'))
def deindent(text, num_tabs=None, spaces_per_tab=4, docstring=False):
    """Strip a common amount of leading indentation from ``text``.

    Tab characters are first expanded to ``spaces_per_tab`` spaces.

    Parameters
    ----------
    text : str
        The text to de-indent.
    num_tabs : int, optional
        If given, exactly ``num_tabs * spaces_per_tab`` leading characters are
        cut from each processed line. Otherwise the smallest indentation found
        among the non-blank processed lines is cut.
    spaces_per_tab : int, optional
        Number of spaces one tab stands for.
    docstring : bool, optional
        If ``True``, the first line is left untouched and is ignored when the
        common indentation is computed.

    Returns
    -------
    str
        The de-indented text.
    """
    text = text.replace('\t', ' ' * spaces_per_tab)
    rows = text.split('\n')

    # A docstring consisting of a single line has nothing to strip.
    if docstring and len(rows) < 2:
        return text

    first = 1 if docstring else 0
    head, tail = rows[:first], rows[first:]

    if num_tabs is None:
        # Blank lines do not take part in the search for the minimum.
        widths = [len(row) - len(row.lstrip()) for row in tail if row.strip()]
        cut = min(widths, default=0)
    else:
        cut = num_tabs * spaces_per_tab

    return '\n'.join(head + [row[cut:] for row in tail])
def word_replace(expr, substitutions, exclude_dot=True):
    """Applies a dict of word substitutions.

    The dict ``substitutions`` consists of pairs ``(word, rep)`` where each
    word ``word`` appearing in ``expr`` is replaced by ``rep``. Here a 'word'
    means anything matching the regexp ``\\bword\\b``, with ``word`` taken
    literally (regex metacharacters in it have no special meaning).

    Substitutions are applied one after another in the iteration order of
    ``substitutions``, so a later pair also sees the text inserted by an
    earlier one.

    Parameters
    ----------
    expr : str
        The expression to rewrite.
    substitutions : dict
        Mapping from word (``str``) to its replacement; the replacement is
        converted with ``str`` and inserted verbatim.
    exclude_dot : bool, optional
        If ``True`` (default), an occurrence directly preceded or followed by
        a dot (e.g. the ``a`` in ``x.a`` or ``a.x``) is left untouched.

    Returns
    -------
    str
        The expression with all substitutions applied.

    Examples
    --------
    >>> expr = 'a*_b+c5+8+f(A)'
    >>> print(word_replace(expr, {'a':'banana', 'f':'func'}))
    banana*_b+c5+8+func(A)
    """
    for var, replace_var in substitutions.items():
        # Escape the word so that e.g. the dot in ``a.b`` matches only a dot.
        word = re.escape(var)
        if exclude_dot:
            pattern = r'\b(?<!\.)' + word + r'\b(?!\.)'
        else:
            pattern = r'\b' + word + r'\b'
        replacement = str(replace_var)
        # A callable replacement is inserted as-is; a plain string would be
        # parsed as a template in which backslashes are special.
        expr = re.sub(pattern, lambda _match, _rep=replacement: _rep, expr)
    return expr
######################################
# Other tools
######################################
def change_func_name(f, name):
    """Rename ``f`` in place by assigning ``name`` to its ``__name__``.

    Parameters
    ----------
    f : callable
        The object to rename.
    name : str
        The new value of ``f.__name__``.

    Returns
    -------
    callable
        The same object ``f``.
    """
    setattr(f, '__name__', name)
    return f
def is_lambda_function(func):
    """Tell whether ``func`` is a ``lambda`` function.

    The check follows
    https://stackoverflow.com/questions/23852423/how-to-check-that-variable-is-a-lambda-function

    Parameters
    ----------
    func : callable function
        The object to inspect.

    Returns
    -------
    bool
        ``True`` if ``func`` is an instance of ``types.LambdaType`` whose
        ``__name__`` is ``"<lambda>"``, ``False`` otherwise.
    """
    if not isinstance(func, LambdaType):
        return False
    return func.__name__ == "<lambda>"
def get_func_source(func):
    """Return the source of ``func`` starting at its first ``'def '``.

    Anything that ``inspect.getsource`` returns before the first occurrence of
    ``'def '`` (such as decorator lines) is dropped. If the text contains no
    ``'def '`` at all, it is returned unchanged.

    Parameters
    ----------
    func : callable
        Object accepted by ``inspect.getsource``.

    Returns
    -------
    str
        The (possibly truncated) source text.
    """
    source = inspect.getsource(func)
    position = source.find('def ')
    # ``find`` yields -1 when the marker is missing: keep the text as it is.
    return source if position < 0 else source[position:]
def get_main_code(func, codes=None):
    """Get the main function code string, i.e. the body without the header.

    For a lambda function, return ``'return '`` followed by the text after the
    colon of the lambda. For any other function, return the source lines that
    follow the line closing the signature.

    Parameters
    ----------
    func : callable, optional
        The function to analyze. ``None`` yields an empty string.
    codes : str, optional
        Source text of ``func``. When omitted, the source is retrieved with
        ``get_func_source`` (lambda) or ``inspect.getsourcelines`` (other
        functions).

    Returns
    -------
    str
        The body of the function, or ``''`` when ``func`` is ``None``.

    Raises
    ------
    ValueError
        If ``func`` is neither ``None`` nor callable, if the lambda source
        does not contain exactly one ``':'``, or if no line containing
        ``'):'`` is found in the source of a regular function.
    """
    if func is None:
        return ''
    if not callable(func):
        raise ValueError(f'Unknown function type: {type(func)}.')

    if is_lambda_function(func):
        codes = (codes or get_func_source(func))
        # Exactly one colon is expected: the one separating the lambda
        # arguments from its expression.
        splits = codes.split(':')
        if len(splits) != 2:
            raise ValueError(f'Can not parse function: \n{codes}')
        return f'return {splits[1]}'

    # Keep the line endings in both cases so that joining the remaining lines
    # reproduces the body exactly as written.
    if codes is None:
        lines = inspect.getsourcelines(func)[0]
    else:
        lines = codes.splitlines(keepends=True)

    # NOTE: the end of the header is recognised by '):' (spaces ignored), so
    # a header line with a return annotation is not detected by this test.
    for idx, line in enumerate(lines, start=1):
        if '):' in line.replace(' ', ''):
            break
    else:
        code = "\n".join(lines)
        raise ValueError(f'Can not parse function: \n{code}')
    return ''.join(lines[idx:])