"""
|
This generates .pyi stubs for the cffi Python bindings generated by regenerate.py
|
"""
|
import sys, re, itertools
|
sys.path.extend(['.', '..']) # for pycparser
|
|
from pycparser import c_ast, parse_file, CParser
|
import pycparser.plyparser
|
from pycparser.c_ast import PtrDecl, TypeDecl, FuncDecl, EllipsisParam, IdentifierType, Struct, Enum, Typedef
|
from typing import Tuple
|
|
# C scalar type name (as it appears in a pycparser IdentifierType) -> the
# Python annotation written into the stub. Lookups that miss fall back to an
# opaque `ffi.CData` annotation in format_type.
__c_type_to_python_type = {
    'void': 'None',
    '_Bool': 'bool',
    # All of these C integer types are annotated as plain int.
    **dict.fromkeys(
        (
            'char', 'short', 'int', 'long',
            'ptrdiff_t', 'size_t',
            'int8_t', 'uint8_t',
            'int16_t', 'uint16_t',
            'int32_t', 'uint32_t',
            'int64_t', 'uint64_t',
        ),
        'int',
    ),
    **dict.fromkeys(('float', 'double'), 'float'),
    # ggml_fp16_t is annotated as numpy's float16.
    'ggml_fp16_t': 'np.float16',
}
|
|
def format_type(t: TypeDecl) -> str:
    """
    Map a pycparser type node to the Python annotation used in the stub.

    :param t: a type node from pycparser's AST (TypeDecl, PtrDecl,
        IdentifierType, Struct, Enum, ...), e.g. a parameter's ``.type`` or a
        function's return type.
    :returns: annotation text such as ``'int'``, ``'float'``, ``'None'``, or
        ``'ffi.CData'`` for anything that is not a mapped scalar.
    """
    # Pointers, structs/unions passed by value, array parameters (which decay
    # to pointers) and bare function types (which decay to function pointers)
    # are all opaque CData objects.
    if isinstance(t, (PtrDecl, Struct, c_ast.Union, c_ast.ArrayDecl, FuncDecl)):
        return 'ffi.CData'
    if isinstance(t, Enum):
        return 'int'
    if isinstance(t, TypeDecl):
        return format_type(t.type)
    if isinstance(t, IdentifierType):
        # Multi-word specifiers such as `unsigned int`, `long long` or
        # `long double` arrive as several names. Signedness does not change the
        # Python type, and the last remaining word picks the table entry.
        names = [n for n in t.names if n not in ('signed', 'unsigned')]
        base = names[-1] if names else 'int'  # bare `signed` / `unsigned` means int
        return __c_type_to_python_type.get(base) or 'ffi.CData'
    # Unknown node kinds: use their name when they have one, otherwise treat
    # them as opaque CData instead of letting None render as 'None'.
    return getattr(t, 'name', None) or 'ffi.CData'
|
|
class PythonStubFuncDeclVisitor(c_ast.NodeVisitor):
    """
    AST visitor that collects ``.pyi`` stub members for the cffi ``lib`` object.

    After ``visit(ast)``, ``self.sigs`` maps each emitted name (a function or an
    enum constant) to its stub text, already indented to sit inside
    ``class lib:``.
    """

    def __init__(self):
        # name -> stub source for one member of `class lib`
        self.sigs = {}
        # source file path -> its lines, read once and reused for every snippet
        self.sources = {}

    def get_source_snippet_lines(self, coord: pycparser.plyparser.Coord) -> Tuple[list[str], list[str]]:
        """
        Return the comment block directly above a declaration and the declaration itself.

        :param coord: pycparser coordinate (file name and 1-based line) of the declaration.
        :returns: ``(comment_lines, decl_lines)``. ``comment_lines`` are the
            consecutive lines starting with ``//`` or ``/*`` immediately above
            ``coord.line`` (whitespace-stripped). ``decl_lines`` run from
            ``coord.line`` through the first line containing ``;`` or ``{``
            (right-stripped). Both lists are empty when the file cannot be read
            or does not contain ``coord.line``.
        """
        if coord.file not in self.sources:
            try:
                with open(coord.file, 'rt', encoding='utf-8', errors='replace') as f:
                    self.sources[coord.file] = f.readlines()
            except OSError:
                # E.g. '<input>' when the parsed text has no line markers that
                # point at a real header: emit the stub without a source excerpt
                # instead of aborting the whole generation.
                self.sources[coord.file] = []
        source_lines = self.sources[coord.file]

        decl_index = coord.line - 1  # 0-based index of the declaration's first line
        if not 0 <= decl_index < len(source_lines):
            return ([], [])

        # Walk upwards over the contiguous run of comment lines.
        start = decl_index
        while start > 0 and re.match(r'\s*(//|/\*)', source_lines[start - 1]):
            start -= 1
        comment_lines = [l.strip() for l in source_lines[start:decl_index]]

        # A declaration may span several lines. Stop at the terminating `;`
        # (prototype) or at the opening `{` (inline definition).
        decl_lines = []
        for line in source_lines[decl_index:]:
            decl_lines.append(line.rstrip())
            if ';' in line or '{' in line:
                break

        return (comment_lines, decl_lines)

    def visit_Enum(self, node: Enum):
        """Emit a read-only ``int`` property for every constant of an enum definition."""
        if node.values is None:
            return  # a reference such as `enum foo x;`, not a definition
        for e in node.values.enumerators:
            self.sigs[e.name] = f'  @property\n  def {e.name}(self) -> int: ...'

    def visit_Typedef(self, node: Typedef):
        """
        Do not emit functions from typedefs, only enum constants.

        A function-pointer typedef must not become a ``lib`` function. Enum
        constants defined inside a typedef (``typedef enum { ... } name;``) are
        still reachable on cffi's ``lib``, so they are collected here.
        """
        self._visit_enums_in(node)

    def _visit_enums_in(self, node):
        """Recursively apply visit_Enum to every Enum under ``node``, ignoring all other node kinds."""
        if isinstance(node, Enum):
            self.visit_Enum(node)
        for _, child in node.children():
            self._visit_enums_in(child)

    def visit_FuncDecl(self, node: FuncDecl):
        """
        Emit a stub method for one C function declaration.

        The method signature is derived from the C parameters via
        ``format_type``. Its docstring holds the comment block above the
        declaration plus the C declaration text. Parameters are not visited, so
        function-pointer parameters are not emitted as functions of their own.
        """
        import keyword

        # Strip pointer layers from the return type to reach the TypeDecl
        # that carries the function's name.
        ret_type = node.type
        is_ptr = False
        while isinstance(ret_type, PtrDecl):
            ret_type = ret_type.type
            is_ptr = True

        # Abstract declarators (e.g. a function-pointer type inside a cast) have
        # no name. Names starting with '__' are skipped (reserved identifiers,
        # presumably compiler/libc internals).
        fun_name = getattr(ret_type, 'declname', None)
        if not fun_name or fun_name.startswith('__'):
            return

        args = []
        argnames = []

        def gen_name(stem):
            """Return ``stem`` or, if taken, the first free ``stem2``, ``stem3``, ..."""
            i = 1
            while True:
                new_name = stem if i == 1 else f'{stem}{i}'
                if new_name not in argnames:
                    return new_name
                i += 1

        # `int f()` (no prototype) has no parameter list at all.
        params = node.args.params if node.args is not None else []
        for a in params:
            if isinstance(a, EllipsisParam):
                # C varargs become *args. Reuse the generated name; calling
                # gen_name again here would yield `*args2`.
                arg_name = gen_name('args')
                argnames.append(arg_name)
                args.append('*' + arg_name)
            elif format_type(a.type) == 'None':
                # The lone `void` in `f(void)`: no Python parameter.
                continue
            else:
                stem = a.name or 'arg'  # unnamed prototype parameters get arg, arg2, ...
                if keyword.iskeyword(stem):
                    stem += '_'  # e.g. a C parameter named `from` or `lambda`
                arg_name = gen_name(stem)
                argnames.append(arg_name)
                args.append(f'{arg_name}: {format_type(a.type)}')

        # A pointer return maps through the PtrDecl (opaque CData); otherwise
        # through the base type.
        ret = format_type(node.type if is_ptr else ret_type)

        comment_lines, decl_lines = self.get_source_snippet_lines(node.coord)

        # Members sit at 2 spaces inside `class lib:`; their bodies must be
        # indented deeper (4 spaces) for the stub to be valid Python.
        lines = [f'  def {fun_name}({", ".join(args)}) -> {ret}:']
        if not comment_lines and len(decl_lines) == 1:
            lines.append(f'    """{decl_lines[0]}"""')
        else:
            lines.append('    """')
            lines.extend(f'    {c.lstrip("/* ")}' for c in comment_lines)
            if comment_lines:
                lines.append('')
            lines.extend(f'    {d}' for d in decl_lines)
            lines.append('    """')
        lines.append('    ...')

        self.sigs[fun_name] = '\n'.join(lines)
|
|
def generate_stubs(header: str):
    """
    Build the text of a .pyi stub for the GGML cffi bindings from C header source.

    The file itself is not written; the caller receives the text.

    :param header: C source to parse with pycparser's ``CParser``.
    :returns: the stub text, i.e. an import preamble followed by ``class lib:``
        whose members (functions and enum constants) are sorted by name.
    """
    visitor = PythonStubFuncDeclVisitor()
    visitor.visit(CParser().parse(header, "<input>"))

    members = [visitor.sigs[name] for name in sorted(visitor.sigs)]
    if not members:
        # `class lib:` with an empty body would make the stub a syntax error.
        members = ['  ...']

    return '\n'.join([
        '# auto-generated file',
        'import ggml.ffi as ffi',
        'import numpy as np',
        'class lib:',
        *members,
    ])
|