tfp.experimental.auto_batching.dsl.ProgramBuilder

An auto-batching DSL context.

Auto-batching DSL operations are methods on the ProgramBuilder object. It's used like this:

import tensorflow_probability as tfp
dsl = tfp.experimental.auto_batching.dsl

ab = dsl.ProgramBuilder()

def fib_type(arg_types):
  return arg_types[0]

with ab.function(type_inference=fib_type) as fibonacci:
  n = ab.param('n')
  ab.var.cond = ab.primop(lambda n: n > 1)
  with ab.if_(ab.var.cond):
    ab.var.nm1 = ab.primop(lambda n: n - 1)
    ab.var.fibm1 = ab.call(fibonacci, [ab.var.nm1])
    ab.var.nm2 = ab.primop(lambda n: n - 2)
    ab.var.fibm2 = ab.call(fibonacci, [ab.var.nm2])
    ab.var.ans = ab.primop(lambda fibm1, fibm2: fibm1 + fibm2)
  with ab.else_():
    ab.var.ans = ab.const(1)
  ab.return_(ab.var.ans)

prog = ab.program(main=fibonacci)
# Now `prog` is a well-formed `instructions.Program`, and the context
# `ab` is no longer needed.

Note that the sequence of method calls on ProgramBuilder corresponds to the source code of the Program being defined, not to its runtime behavior. This is because (a) functions are defined with a context manager (rather than a Python function), which executes its body immediately and exactly once; and (b) function call instructions (and primitive operations) are only recorded, not entered recursively.
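As a rough mental model only, the toy recorder below shows both points: the with body runs once, and even a recursive call is merely appended to a trace rather than executed. It is pure Python; ToyBuilder, function, call, and trace are hypothetical names, not TFP's implementation.

```python
import contextlib


class ToyBuilder:
  """Hypothetical stand-in for a recording DSL; it records, never executes."""

  def __init__(self):
    self.trace = []  # recorded instructions, in source order

  @contextlib.contextmanager
  def function(self, name):
    self.trace.append(('begin', name))
    yield name  # the body of the `with` runs once, right here
    self.trace.append(('end', name))

  def call(self, fn_name):
    # Recording only: the callee's body is NOT re-entered, so a recursive
    # call cannot loop forever while the program is being built.
    self.trace.append(('call', fn_name))


tb = ToyBuilder()
with tb.function('fib') as fib:
  tb.call(fib)  # recursive call, merely recorded
  tb.call(fib)

print(tb.trace)
# [('begin', 'fib'), ('call', 'fib'), ('call', 'fib'), ('end', 'fib')]
```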

Attributes
var Auto-batching variables visible in the current scope.

Overrides __setattr__ and __getattr__ to provide a smooth interface for reading and defining variables:

  • ProgramBuilder.var.foo = ProgramBuilder.{call,const,primop} records an assignment to the auto-batched variable foo, possibly binding it, and

  • ProgramBuilder.var.foo reads from the auto-batched variable foo (if it is bound).

ab = dsl.ProgramBuilder()

ab.var.seven = ab.const(7)
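The attribute trick can be sketched in plain Python. In the hypothetical ToyVars class below (an illustration of the idea, not TFP's code), writes go through __setattr__ and bind a name, while reads go through __getattr__ and fail for names that were never bound. Returning the variable's string name mirrors the DSL's convention of representing variables as Python strings.

```python
class ToyVars:
  """Hypothetical namespace: attribute writes bind, attribute reads look up."""

  def __init__(self):
    # Bypass our own __setattr__ to create the backing dict.
    object.__setattr__(self, '_bound', {})

  def __setattr__(self, name, op):
    # `var.foo = op` records an assignment and binds `foo`.
    self._bound[name] = op

  def __getattr__(self, name):
    # Only called when normal lookup fails, i.e. for variable names.
    if name not in self._bound:
      raise ValueError('Variable {!r} read before it is bound'.format(name))
    return name  # refer to the variable by its string name


v = ToyVars()
v.seven = ('const', 7)
print(v.seven)  # seven
try:
  v.eight
except ValueError as e:
  print(e)  # Variable 'eight' read before it is bound
```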

Methods

call

Registers a function call instruction.

Example:

ab = dsl.ProgramBuilder()

# Define a function
with ab.function(...) as func:
  ...
  # Call it (recursively)
  ab.var.thing = ab.call(func, ...)
  ...

Args
function The instructions.Function object representing the function to call.
vars_in A list of Python strings giving the auto-batched variables to pass in as inputs.
vars_out A pattern of Python strings, giving the auto-batched variable(s) to which to write the result of the call. Defaults to the empty list.

Raises
ValueError If the call references undefined auto-batched variables.

Returns
op An instructions.FunctionCallOp representing the call. If one subsequently assigns this to a local, via ProgramBuilder.var.foo = op, that local gets added to the list of output variables.
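The implicit output binding described under Returns can be modeled as follows. This is a hypothetical sketch (ToyCallOp and ToyVars are invented names): the op is created with an empty output list, and assigning it to an attribute appends that attribute's name to the list.

```python
class ToyCallOp:
  """Hypothetical recorded call whose outputs can be filled in later."""

  def __init__(self, function, vars_in, vars_out=None):
    self.function = function
    self.vars_in = list(vars_in)
    self.vars_out = [] if vars_out is None else list(vars_out)


class ToyVars:
  """Assigning an op to an attribute adds that name to the op's outputs."""

  def __setattr__(self, name, op):
    op.vars_out.append(name)


var = ToyVars()
op = ToyCallOp('func', vars_in=['n'])  # recorded with no outputs yet
var.thing = op                          # `thing` now receives the result
print(op.vars_out)  # ['thing']
```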

const

Records a constant or set of constants.

Like primop, the output variables can be specified explicitly via the vars_out argument or implicitly by assigning the return value to some ProgramBuilder.var.foo.

Args
value A Python list of the constants to record.
vars_out A pattern of Python strings, giving the auto-batched variable(s) to which to write the constant(s). Defaults to the empty list.

Returns
op An instructions.PrimOp instance representing this operation. If one subsequently assigns this to a local, via ProgramBuilder.var.foo = op, that local gets added to the list of output variables.

declare_function

Forward-declares a function to be defined later with define_function.

This is useful for defining mutually recursive functions:

ab = dsl.ProgramBuilder()

foo = ab.declare_function(...)

with ab.function(...) as bar:
  ...
  ab.call(foo, ...)

with ab.define_function(foo):
  ...
  ab.call(bar, ...)

It is an error to call but never define a declared function.
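One way to picture the declare/define split, and the rule that every called function must eventually be defined, is a small registry that checks for missing definitions only when the module is finalized. ToyModule and its methods are hypothetical; the sketch also returns bodies in definition order, as the define_function documentation states for the compiled instructions.Program.

```python
class ToyModule:
  """Hypothetical registry that separates declaration from definition."""

  def __init__(self):
    self.declared = []     # names, in declaration order
    self.definitions = []  # names, in definition order
    self.called = set()

  def declare(self, name):
    self.declared.append(name)
    return name

  def define(self, name):
    if name not in self.declared:
      raise ValueError('{} was never declared'.format(name))
    if name in self.definitions:
      raise ValueError('{} is already defined'.format(name))
    self.definitions.append(name)

  def call(self, name):
    # Calling a declared-but-undefined function is fine at this point.
    self.called.add(name)

  def finalize(self):
    missing = sorted(self.called - set(self.definitions))
    if missing:
      raise ValueError('Called but never defined: {}'.format(missing))
    return list(self.definitions)  # bodies in definition order


m = ToyModule()
foo = m.declare('foo')
bar = m.declare('bar')
m.define(bar)
m.call(foo)  # bar's body calls foo before foo is defined
m.define(foo)
m.call(bar)  # foo's body calls bar
print(m.finalize())  # ['bar', 'foo']
```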

Args
name Optional string naming this function when the program is printed.
type_inference A Python callable giving the type signature of the function being defined. See function.

Returns
function An instructions.Function object representing the function being declared. It can be passed to call to call it, and to define_function to define it.

define_function

Registers a definition for a previously declared function.

Usually, one would use the function method to declare and define a function at the same time. Explicit use of define_function is only useful for mutual recursion or controlling code order separately from the call graph.

Example:

ab = dsl.ProgramBuilder()

foo = ab.declare_function(...)

with ab.function(...) as bar:
  ...
  ab.call(foo, ...)

with ab.define_function(foo):
  ...
  ab.call(bar, ...)

Function bodies appear in the compiled instructions.Program in order of definition, not declaration.

Note:

  • The formal parameters are given by calling ab.param inside the with block.
  • The body of the with block registers the body of the function being defined.
  • The last statement registered in the with block must be an ab.return_, or the Function will be malformed.

Args
function The function (from declare_function) to define.

Yields
function The function being defined, by symmetry with the context.function method.

Raises
ValueError If invoked while defining a function, if the function argument has already been defined, or if the function definition does not end in a return_.
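The body rules in the Note above can be checked mechanically. The hypothetical ToyDefiner below (not TFP's code) records a body inside a context manager, rejects a definition nested inside another, and raises ValueError when the recorded body does not end in a return.

```python
import contextlib


class ToyDefiner:
  """Hypothetical checker for the function-definition rules."""

  def __init__(self):
    self.body = None  # instructions of the function being defined, if any

  @contextlib.contextmanager
  def define(self, name):
    if self.body is not None:
      raise ValueError('Cannot define {} inside another function'.format(name))
    self.body = []
    try:
      yield name
      # Reached only if the `with` body finished without raising.
      if not self.body or self.body[-1][0] != 'return':
        raise ValueError('{} does not end in a return_'.format(name))
    finally:
      self.body = None

  def primop(self, text):
    self.body.append(('primop', text))

  def return_(self, var):
    self.body.append(('return', var))


d = ToyDefiner()
with d.define('f'):
  d.primop('ans = n + 1')
  d.return_('ans')  # well-formed: the last instruction is a return

try:
  with d.define('g'):
    d.primop('ans = n + 1')  # forgot the return_
except ValueError as e:
  print(e)  # g does not end in a return_
```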

else_

Records the false branch of a conditional operation.

The true branch must be recorded (by if_) as the immediately preceding operation at the same nesting depth.

Example:

ab = dsl.ProgramBuilder()

ab.var.false = ab.const(False)
with ab.if_(ab.var.false):
  ...
with ab.else_():
  ...  # The body of the `with` statement gives the `false` branch

Args
else_name Optional Python string naming the false branch when the program is printed. Overrides the else_name, if any, given in the corresponding if_.
continue_name Optional Python string naming the continuation after the if when the program is printed. Overrides the continue_name, if any, given in the corresponding if_.

Raises
ValueError If not immediately preceded by an if_.

Yields
Nothing.
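The ordering requirement can be illustrated with a toy recorder that keeps one instruction list per nesting depth and only accepts else_ when the last sibling at that depth is an if_ that has no else branch yet. ToyBranches and its list-based node layout are hypothetical, chosen for the sketch.

```python
import contextlib


class ToyBranches:
  """Hypothetical recorder enforcing that else_ directly follows an if_."""

  def __init__(self):
    self.blocks = [[]]  # stack of instruction lists, one per nesting depth

  @contextlib.contextmanager
  def if_(self, cond):
    node = ['if', cond, [], None]  # [tag, condition, then-body, else-body]
    self.blocks[-1].append(node)
    self.blocks.append(node[2])
    try:
      yield
    finally:
      self.blocks.pop()

  @contextlib.contextmanager
  def else_(self):
    siblings = self.blocks[-1]
    if not siblings or siblings[-1][0] != 'if' or siblings[-1][3] is not None:
      raise ValueError('else_ must immediately follow an if_')
    siblings[-1][3] = []
    self.blocks.append(siblings[-1][3])
    try:
      yield
    finally:
      self.blocks.pop()

  def op(self, text):
    self.blocks[-1].append(('op', text))


b = ToyBranches()
with b.if_('cond'):
  b.op('x = 1')
with b.else_():  # OK: the previous sibling is an if_ without an else_
  b.op('x = 2')
print(b.blocks[0])
# [['if', 'cond', [('op', 'x = 1')], [('op', 'x = 2')]]]

b.op('y = 3')
try:
  with b.else_():  # not immediately preceded by an if_
    pass
except ValueError as e:
  print(e)  # else_ must immediately follow an if_
```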

function

Registers a definition of an auto-batchable function.

Example:

ab = dsl.ProgramBuilder()

with ab.function(...) as f:
  ab.param('n')
  ...
  ab.return_(...)

Note:

  • The as clause (here f) binds an instructions.Function object representing the function being defined (see Yields).
  • The formal parameters are given by calling param inside the with block.
  • The body of the with block registers the body of the function being defined.
  • The last statement registered in the with block must be a call to return_, or the Function will be malformed.

The function method is shorthand for declare_function followed by define_function. The example is equivalent to:

ab = dsl.ProgramBuilder()

f = ab.declare_function(...)
with ab.define_function(f):
  ab.param('n')
  ...
  ab.return_(...)

Args
name Optional string naming this function when the program is printed.
type_inference A Python callable giving the type signature of the function being defined. The callable will be invoked with a single argument giving the list of instructions.Type objects describing the arguments at a particular call site, and must return a list of instructions.Type objects describing the values that call site will return.
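For a function with more than one return value, the callable maps a list of argument types to a list of return types. The sketch below uses a hypothetical ToyType(dtype, shape) namedtuple in place of instructions.Type, whose actual fields are not described on this page, and an invented function that returns a sum and an integer count.

```python
import collections

# Hypothetical stand-in for instructions.Type; the real class may differ.
ToyType = collections.namedtuple('ToyType', ['dtype', 'shape'])


def sum_and_count_type(arg_types):
  """type_inference for a hypothetical function (x, y) -> (x + y, count).

  Receives the argument types observed at one call site and returns one
  type per value the function returns, in order.
  """
  x_type, y_type = arg_types
  if x_type != y_type:
    raise ValueError('This sketch assumes both arguments share a type')
  return [x_type, ToyType(dtype='int32', shape=())]


arg_types = [ToyType('float32', (3,)), ToyType('float32', (3,))]
print(sum_and_count_type(arg_types))
# [ToyType(dtype='float32', shape=(3,)), ToyType(dtype='int32', shape=())]
```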

Raises
ValueError If invoked while defining a function, or if the function definition does not end in a return_.

Yields
function An instructions.Function object representing the function being defined. It can be passed to call to call it (including recursively). Note that Python scopes the as binding to the function (or module) enclosing the with statement, not to the with block itself, so a function bound this way can also be referred to after its body.
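The scoping point is ordinary Python behavior. It is shown here with a hypothetical stand-in context manager, toy_function, rather than the DSL itself: the as target is usable inside the body and stays bound after the block ends.

```python
import contextlib


@contextlib.contextmanager
def toy_function(name):
  # Hypothetical stand-in for ab.function(): yields a handle exactly once.
  yield {'name': name}


with toy_function('f') as f:
  seen_inside = f['name']  # the body can already use `f`, e.g. to recurse

# A `with` statement does not open a new scope, so `f` is still bound here.
print(seen_inside, f['name'])  # f f
```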

if_

Records a conditional operation and its true (first) branch.

The false branch, if present, must be recorded by an immediately following call to else_.

Example:

ab = dsl.ProgramBuilder()

ab.var.true = ab.const(True)
with ab.if_(ab.var.true):
  ...  # The body of the `with` statement gives the `true` branch
with ab.else_():  # The else_ clause is optional
  ...

Args
condition Python string giving the boolean variable that holds the branch condition.
then_name Optional Python string naming the true branch when the program is printed.
else_name Optional Python string naming the false branch when the program is printed.
continue_name Optional Python string naming the continuation after the if when the program is printed.

Yields
Nothing.

Raises
ValueError If trying to condition on a variable that has not been written to.

local

Declares a local variable in the current scope.

This should typically not be needed, because ProgramBuilder.var.foo = can bind variables; however, it may be helpful for a multi-value return (see primop or call).

Args
name Optional Python string to serve as a mnemonic name in later compiler stages. Variable names are automatically uniqued. This variable can later be referred to with ProgramBuilder.var.name, as well as through any Python binding of the returned value.
define Boolean giving whether to mark this variable defined on creation. Defaults to True. Setting False is useful for speculatively uniquing a variable on its first appearance, before knowing whether that appearance is a write (in which case the variable becomes defined) or a read (which raises an error).

Returns
var A Python string representing this variable. Suitable for passing to primop, call, if_, and return_.

locals_

Declares several variables at once.

This is a convenience method standing for several invocations of local.

Args
count Python int. The number of distinct variables to return.
name Optional Python string to serve as a mnemonic name in later compiler stages. Variable names are automatically uniqued.

Returns
vars A list of count Python strings representing these variables. Suitable for passing to primop, call, if_, and return_.
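To illustrate automatic uniquing, the define flag, and locals_ as repeated calls to local, here is a hypothetical ToyScope class. Its name, name_1, ... scheme is invented for the sketch; TFP's actual uniquing format is not specified on this page.

```python
class ToyScope:
  """Hypothetical variable namer: uniques names and tracks definedness."""

  def __init__(self):
    self.defined = {}  # unique name -> whether it has been written

  def local(self, name='var', define=True):
    unique, i = name, 0
    while unique in self.defined:
      i += 1
      unique = '{}_{}'.format(name, i)
    self.defined[unique] = define
    return unique

  def locals_(self, count, name='var'):
    # Convenience wrapper: `count` separate calls to local().
    return [self.local(name) for _ in range(count)]

  def read(self, var):
    if not self.defined.get(var, False):
      raise ValueError('{} read before being written'.format(var))
    return var


s = ToyScope()
print(s.local('x'), s.local('x'))  # x x_1
print(s.locals_(2, 'y'))           # ['y', 'y_1']
pending = s.local('z', define=False)  # uniqued, but not yet defined
```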

module

Returns the registered function definitions as an instructions.Module.

Example:

ab = dsl.ProgramBuilder()

with ab.function(...) as foo:
  ...  # Do stuff

module = ab.module()

Raises
ValueError If invoked inside a function definition.

Returns
module The instructions.Module corresponding to all the definitions accumulated in this ProgramBuilder.

param

Declares a parameter of the function currently being defined.

This makes a local variable, as local does, but also makes it an input of the nearest enclosing function (created by with ProgramBuilder.function()). This is a separate method from function because the DSL wants to create Python bindings for the function name itself and for all of its input parameters, and there is no way to make the with syntax do that.

Args
name Optional Python string to serve as a mnemonic name in later compiler stages. Variable names are automatically uniqued.

Returns
var A Python string representing this variable. Suitable for passing to primop, call, if_, and return_.

primop

Records a primitive operation.

Example:

ab = dsl.ProgramBuilder()

ab.var.five = ab.const(5)
# Implicit output binding
ab.var.ten = ab.primop(lambda five: five + five)
# Explicit output binding
ab.primop(lambda: (5, 10), vars_out=[ab.var.five, ab.var.ten])

Args
f A Python callable, the primitive operation to perform. Can be an inline lambda expression in simple cases. Must return a list or tuple of results, one for each intended output variable.
vars_in A list of Python strings, giving the auto-batched variables to pass into the callable when invoking it. If absent, primop will try to infer it by inspecting the argument list of the callable and matching against variables bound in the local scope.
vars_out A pattern of Python strings, giving the auto-batched variable(s) to which to write the result of the callable. Defaults to the empty list.

Raises
ValueError If the definition is invalid, if the primop references undefined auto-batched variables, or if auto-detection of input variables fails.

Returns
op An instructions.PrimOp instance representing this operation. If one subsequently assigns this to a local, via ProgramBuilder.var.foo = op, that local becomes the output pattern.
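The inference of vars_in described under Args can be approximated with the standard inspect module. In this hypothetical sketch, bound maps the names visible in the current scope to their (possibly uniqued) auto-batched variables; the lookup rule is an assumption, not TFP's exact algorithm.

```python
import inspect


def infer_vars_in(f, bound):
  """Hypothetical sketch: match the callable's parameter names to variables.

  `bound` maps user-visible names to the variables currently in scope.
  """
  names = list(inspect.signature(f).parameters)
  missing = [n for n in names if n not in bound]
  if missing:
    raise ValueError('Cannot infer inputs; unbound: {}'.format(missing))
  return [bound[n] for n in names]


bound = {'fibm1': 'fibm1', 'fibm2': 'fibm2', 'n': 'n_3'}
print(infer_vars_in(lambda fibm1, fibm2: fibm1 + fibm2, bound))
# ['fibm1', 'fibm2']
print(infer_vars_in(lambda n: n > 1, bound))  # ['n_3']
```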

program

Returns the registered program as an instructions.Program.

This is a helper method, equivalent to self.module().program(main).

Example:

ab = dsl.ProgramBuilder()

with ab.function(...) as main:
  ...  # Do the stuff

program = ab.program(main)

Args
main An instructions.Function object representing the main entry point.

Raises
ValueError If invoked inside a function definition, or if the intended main function was not defined.

Returns
program The instructions.Program corresponding to all the definitions accumulated in this ProgramBuilder.

return_

Records a function return instruction.

Example:

ab = dsl.ProgramBuilder()

with ab.function(...) as f:
  ...
  ab.var.result = ...
  ab.return_(ab.var.result)

A return_ command must occur at the top level of the function definition (not inside any if_), and must be the last statement therein. You can always achieve this by assigning the answer to a dedicated variable wherever you would otherwise return, and restructuring your control flow accordingly.
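The restructuring recommended here is the familiar single-exit rewrite. It is shown below in ordinary Python rather than DSL calls; in the DSL, each assignment to ans would be an ab.var.ans = ... inside an if_ or else_ branch, and the elif would become an if_ nested inside the else_.

```python
def sign_early(x):
  # Early returns: not expressible in the DSL, where return_ must come last.
  if x < 0:
    return -1
  if x == 0:
    return 0
  return 1


def sign_single_exit(x):
  # Same logic with a dedicated `ans` variable and one final return,
  # mirroring `ab.var.ans = ...` in each branch, then `ab.return_(ab.var.ans)`.
  if x < 0:
    ans = -1
  elif x == 0:  # in the DSL: an if_ nested inside the else_
    ans = 0
  else:
    ans = 1
  return ans


assert all(sign_early(x) == sign_single_exit(x) for x in (-5, 0, 7))
```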

Args
vars_out Pattern of Python strings giving the auto-batched variables to return.

Raises
ValueError If invoked more than once in a function body, or if trying to return variables that have not been written to.

__call__

Prepares a multi-value return.

Example:

ab = dsl.ProgramBuilder()

ab((ab.var.two, ab.var.four)).pattern = ab.const((2, 4))

The protocol is to create a magic pattern object by invoking the ProgramBuilder as a callable, passing the pattern to bind, and then to assign the operation whose values should be accepted to the pattern attribute of the returned object.

It works this way to get around a limitation of embedding a DSL in Python: the assignment syntax = can be overridden only for fields of objects, not for function calls. It would have been nicer to write ab.pattern(...) = ..., but that is syntactically invalid Python; hence the pattern token goes at the end of the phrase rather than at the beginning.

Args
pattern A pattern of variables (e.g., from ab.var.name) to bind.

Returns
pat_object A _MagicPattern instance representing the putative binding. Invoke the pattern = attribute setter on that instance to actually bind this pattern as the output of a primop, const, or call.
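The protocol works because Python permits assigning to an attribute of a call's result and routes that assignment through __setattr__. The hypothetical ToyBuilder and ToyMagicPattern pair below (illustrative names, not TFP's classes) uses exactly that hook to capture both the pattern and the operation.

```python
class ToyMagicPattern:
  """Hypothetical stand-in for the object returned by calling the builder."""

  def __init__(self, builder, pattern):
    object.__setattr__(self, '_builder', builder)
    object.__setattr__(self, '_pattern', pattern)

  def __setattr__(self, attr, op):
    if attr != 'pattern':
      raise AttributeError(attr)
    # `tb(...).pattern = op` lands here: bind the pattern as op's outputs.
    self._builder.bindings.append((self._pattern, op))


class ToyBuilder:
  def __init__(self):
    self.bindings = []

  def __call__(self, pattern):
    return ToyMagicPattern(self, pattern)


tb = ToyBuilder()
# `tb(('two', 'four')) = op` would be a SyntaxError; assigning to an
# attribute of the call's result is legal Python.
tb(('two', 'four')).pattern = ('const', (2, 4))
print(tb.bindings)  # [(('two', 'four'), ('const', (2, 4)))]
```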