Class outlining the default Tracing Protocol for Scalarizer.
tf_agents.bandits.multi_objective.multi_objective_scalarizer.ScalarizerTraceType(
value
)
If a Scalarizer is passed as an argument, the corresponding tf.function always retraces on each call.
Derived classes can override this behavior by specifying their own Tracing Protocol.
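A toy model of a trace cache shows why this leads to retracing. This is a simplified pure-Python sketch, not TensorFlow's dispatch code. ToyFunction, AlwaysRetraceType and ExactValueType are hypothetical names. The sketch also assumes that the Scalarizer protocol never reports a subtype match, which is what the unnamed `_` parameter of is_subtype_of below suggests.

```python
from typing import Any, Callable, List, Tuple


class AlwaysRetraceType:
    """Hypothetical stand-in assumed to mimic the default Scalarizer protocol."""

    def is_subtype_of(self, _other: Any) -> bool:
        # Never compatible with an existing trace.
        return False


class ExactValueType:
    """Hypothetical trace type: compatible only with an identical value."""

    def __init__(self, value: Any):
        self.value = value

    def is_subtype_of(self, other: Any) -> bool:
        return isinstance(other, ExactValueType) and self.value == other.value


class ToyFunction:
    """Toy model of a tf.function-style cache keyed by trace types."""

    def __init__(self, fn: Callable[[Any], Any], make_type: Callable[[Any], Any]):
        self.fn = fn
        self.make_type = make_type
        self.traces: List[Tuple[Any, Callable[[Any], Any]]] = []

    def __call__(self, arg: Any) -> Any:
        arg_type = self.make_type(arg)
        for traced_type, traced_fn in self.traces:
            # Reuse an existing trace if the argument's type is a subtype of it.
            if arg_type.is_subtype_of(traced_type):
                return traced_fn(arg)
        # No compatible trace: "trace" a new one and cache it.
        self.traces.append((arg_type, self.fn))
        return self.fn(arg)


def double(x: Any) -> Any:
    return x * 2


always = ToyFunction(double, lambda v: AlwaysRetraceType())
exact = ToyFunction(double, ExactValueType)
for _ in range(3):
    always(5)
    exact(5)
print(len(always.traces), len(exact.traces))  # 3 1
```

No cached trace is ever compatible with AlwaysRetraceType, so its cache grows by one entry per call. The value-based type, by contrast, reuses its first trace.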
Methods
cast
cast(
value, cast_context
) -> Any
Cast value to this type.
| Args | |
|---|---|
| value | An input value belonging to this TraceType. |
| cast_context | A context reserved for internal/future usage. |

| Returns | |
|---|---|
| The value cast to this TraceType. |

| Raises | |
|---|---|
| AssertionError | When _cast is not overridden in a subclass, the value is returned directly and must equal self.placeholder_value(); otherwise this error is raised. |
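The default behavior described above amounts to an equality check against the placeholder value. The sketch below follows that description and does not reproduce TensorFlow's source. ConstantTraceType is a hypothetical type whose placeholder is simply its stored value.

```python
from typing import Any


class ConstantTraceType:
    """Hypothetical trace type whose placeholder is its stored value."""

    def __init__(self, value: Any):
        self._value = value

    def placeholder_value(self, placeholder_context: Any = None) -> Any:
        return self._value

    def cast(self, value: Any, cast_context: Any) -> Any:
        # No conversion happens: the value is returned as-is, so it must
        # already equal the placeholder value.
        del cast_context  # Reserved for internal/future usage.
        if value != self.placeholder_value():
            raise AssertionError(f"Cannot cast {value!r} to {self!r}")
        return value


trace_type = ConstantTraceType(3)
print(trace_type.cast(3, cast_context=None))  # 3
try:
    trace_type.cast(4, cast_context=None)
except AssertionError as err:
    print("AssertionError:", err)
```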
flatten
flatten() -> List['TraceType']
Returns a list of TensorSpecs, one for each tensor that to_tensors produces.
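The pure-Python sketch below shows how flatten lines up with to_tensors. It uses a hypothetical PairTraceType for pairs of floats. As assumptions for the sake of a self-contained example, (shape, dtype) tuples stand in for tf.TensorSpec and plain floats stand in for tensors. This is a model, not a real TraceType subclass.

```python
from typing import List, Tuple

Spec = Tuple[Tuple[int, ...], str]  # (shape, dtype) stand-in for a TensorSpec


class PairTraceType:
    """Hypothetical trace type for a pair of scalar floats."""

    def to_tensors(self, value: Tuple[float, float]) -> List[float]:
        # One "tensor" per element of the pair.
        return [float(value[0]), float(value[1])]

    def flatten(self) -> List[Spec]:
        # One spec per tensor returned by to_tensors, in the same order.
        return [((), "float32"), ((), "float32")]


pair_type = PairTraceType()
print(len(pair_type.flatten()) == len(pair_type.to_tensors((1.0, 2.0))))  # True
```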
from_tensors
from_tensors(
tensors: Iterator[core.Tensor]
) -> Any
Generates a value of this type from Tensors.
Must consume the same fixed number of tensors that to_tensors produces.
| Args | |
|---|---|
| tensors | An iterator from which the tensors can be pulled. |

| Returns | |
|---|---|
| A value of this type. |
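The fixed count matters when several values are decoded from one shared iterator. Each from_tensors call must pull exactly as many tensors as to_tensors emitted, or the next value would start at the wrong position. This is a pure-Python sketch with a hypothetical PairTraceType, and floats stand in for tensors.

```python
from typing import Iterator, List, Tuple


class PairTraceType:
    """Hypothetical trace type for a pair of scalar floats."""

    def to_tensors(self, value: Tuple[float, float]) -> List[float]:
        return [value[0], value[1]]

    def from_tensors(self, tensors: Iterator[float]) -> Tuple[float, float]:
        # Pull exactly two items, matching what to_tensors produces.
        return (next(tensors), next(tensors))


pair_type = PairTraceType()
values = [(1.0, 2.0), (3.0, 4.0)]
# Flatten several values into one stream, then rebuild them in order.
stream = iter([t for v in values for t in pair_type.to_tensors(v)])
rebuilt = [pair_type.from_tensors(stream) for _ in values]
print(rebuilt)  # [(1.0, 2.0), (3.0, 4.0)]
```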
is_subtype_of
is_subtype_of(
_
)
Returns True if self is a subtype of other.
For example, tf.function uses subtyping for dispatch: if a.is_subtype_of(b) is True, then an argument of TraceType a can be used as an argument to a ConcreteFunction traced with TraceType b.
| Args | |
|---|---|
| other | A TraceType object to be compared against. |
Example:
```python
from typing import Optional

import tensorflow as tf


class Dimension(tf.types.experimental.TraceType):
    def __init__(self, value: Optional[int]):
        self.value = value

    def is_subtype_of(self, other):
        # Either the value is the same or other has a generalized value that
        # can represent any specific ones.
        return (self.value == other.value) or (other.value is None)
    # Other abstract TraceType methods omitted for brevity.
```
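The dispatch rule can be tried without TensorFlow. In this pure-Python sketch, Dimension mirrors the example but does not subclass TraceType. find_compatible is a hypothetical helper that returns the first previously traced type of which the argument is a subtype.

```python
from typing import List, Optional


class Dimension:
    def __init__(self, value: Optional[int]):
        self.value = value

    def is_subtype_of(self, other: "Dimension") -> bool:
        # Same value, or other is the generalized (None) dimension.
        return self.value == other.value or other.value is None


def find_compatible(arg: Dimension, traced: List[Dimension]) -> Optional[Dimension]:
    """Hypothetical helper: return a traced type that can accept arg."""
    for candidate in traced:
        if arg.is_subtype_of(candidate):
            return candidate
    return None


traced = [Dimension(3), Dimension(None)]
print(find_compatible(Dimension(3), traced).value)  # 3
print(find_compatible(Dimension(7), traced).value)  # None (the generalized trace)
print(Dimension(None).is_subtype_of(Dimension(3)))  # False
```

Subtyping is directional. A specific Dimension(7) can use the trace for Dimension(None). The reverse does not hold.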
most_specific_common_supertype
most_specific_common_supertype(
_
)
Returns the most specific supertype of self and others, if one exists.
The returned TraceType is a supertype of self and others; that is, they are all subtypes (see is_subtype_of) of it.
It is also the most specific such type: it has no subtype that is also a common supertype of self and others.
If self and others have no common supertype, this returns None.
| Args | |
|---|---|
| others | A sequence of TraceTypes. |
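Example:
```python
from typing import Optional, Sequence

import tensorflow as tf


class Dimension(tf.types.experimental.TraceType):
    def __init__(self, value: Optional[int]):
        self.value = value

    def most_specific_common_supertype(self, others: Sequence["Dimension"]):
        # If every value matches, self is already the most specific supertype.
        if all(self.value == other.value for other in others):
            return self
        # Otherwise only the generalized value (None) can represent them all.
        return Dimension(None)
    # Other abstract TraceType methods omitted for brevity.
```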
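Here is a TensorFlow-free variant of the Dimension example. It follows the documented signature, which takes a sequence named others, and it uses is_subtype_of to check the supertype property. This Dimension is a standalone stand-in rather than a TraceType subclass.

```python
from typing import Optional, Sequence


class Dimension:
    def __init__(self, value: Optional[int]):
        self.value = value

    def is_subtype_of(self, other: "Dimension") -> bool:
        return self.value == other.value or other.value is None

    def most_specific_common_supertype(
        self, others: Sequence["Dimension"]
    ) -> Optional["Dimension"]:
        # All values agree: self already covers every type.
        if all(self.value == other.value for other in others):
            return self
        # Otherwise only the generalized dimension covers them all.
        return Dimension(None)


dims = [Dimension(2), Dimension(5)]
sup = dims[0].most_specific_common_supertype(dims[1:])
print(sup.value)  # None
print(all(d.is_subtype_of(sup) for d in dims))  # True
```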
placeholder_value
placeholder_value(
placeholder_context=None
)
Creates a placeholder for tracing.
tf.function traces with the placeholder value rather than the actual value. A single placeholder value can stand in for many different actual values, so the trace generated from it is more general and can be reused, which avoids expensive retracing.
| Args | |
|---|---|
| placeholder_context | A context reserved for internal/future usage. |
Consider the Fruit example from the base TraceType documentation. In that example, Apple and Mango are subclasses of Fruit, and both are passed to a tf.function named get_mixed_flavor. Implementing:
```python
class FruitTraceType:
    def placeholder_value(self, placeholder_context):
        return Fruit()
```
instructs tf.function to trace with Fruit() objects instead of the actual Apple() and Mango() objects when it receives a call to get_mixed_flavor(Apple(), Mango()). Similarly, Tensor arguments are replaced with Tensors of similar shape and dtype, output from a Placeholder op.
More generally, placeholder values are the arguments of a tf.function as seen from the function's body:
```python
import tensorflow as tf

@tf.function
def foo(x):
    # Here `x` is the placeholder value (a symbolic Tensor during tracing)
    ...

x = tf.constant([1.0, 2.0])
foo(x)  # Here `x` is the actual value
```
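The Fruit scenario can be modeled in plain Python with a toy trace step. The step runs the function body once, using each argument's placeholder value, and records what the body saw. Fruit, Apple, Mango, FruitTraceType, toy_trace and the body of get_mixed_flavor are all hypothetical stand-ins. Real tf.function tracing also records the operations into a graph, which this sketch does not attempt.

```python
from typing import Any, Callable, List


class Fruit:
    flavor = "generic"


class Apple(Fruit):
    flavor = "sweet"


class Mango(Fruit):
    flavor = "tropical"


class FruitTraceType:
    def placeholder_value(self, placeholder_context: Any = None) -> Fruit:
        # Trace with a generic Fruit instead of the concrete subclass.
        return Fruit()


def toy_trace(fn: Callable[..., Any], *args: Any) -> Any:
    """Hypothetical tracer: call fn with placeholders instead of actual args."""
    # Assumes every argument is a Fruit, so FruitTraceType applies to all.
    placeholders = [FruitTraceType().placeholder_value() for _ in args]
    return fn(*placeholders)


seen: List[str] = []


def get_mixed_flavor(fruit_a: Fruit, fruit_b: Fruit) -> str:
    seen.append(type(fruit_a).__name__)
    seen.append(type(fruit_b).__name__)
    return fruit_a.flavor + "+" + fruit_b.flavor


print(toy_trace(get_mixed_flavor, Apple(), Mango()))  # generic+generic
print(seen)  # ['Fruit', 'Fruit']
```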
to_tensors
to_tensors(
value: Any
) -> List[core.Tensor]
Breaks down a value of this type into Tensors.
For a given TraceType instance, the number of tensors generated for any value of that type must be constant.
| Args | |
|---|---|
| value | A value belonging to this TraceType. |

| Returns | |
|---|---|
| List of Tensors. |
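Under this constraint, values that would produce different numbers of tensors need different TraceType instances. This is a minimal pure-Python sketch. FixedListTraceType is hypothetical, floats stand in for tensors, and raising ValueError on a mismatch is an illustrative design choice rather than documented behavior.

```python
from typing import List


class FixedListTraceType:
    """Hypothetical trace type for Python lists of a fixed length."""

    def __init__(self, length: int):
        self.length = length

    def to_tensors(self, value: List[float]) -> List[float]:
        # The tensor count is fixed by the trace type, not by the value.
        if len(value) != self.length:
            raise ValueError(
                f"expected {self.length} elements, got {len(value)}")
        return [float(v) for v in value]


list3 = FixedListTraceType(3)
print(list3.to_tensors([1, 2, 3]))  # [1.0, 2.0, 3.0]
```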
__eq__
__eq__(
_
)
Return self==value.
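TraceTypes can also act as hashable keys for exact-match lookups, so __eq__ must agree with __hash__. The sketch below is a simplified pure-Python model, not tf.function's cache. NeverEqualType, ValueType and count_traces are hypothetical. The sketch assumes that ScalarizerTraceType.__eq__ ignores its argument and returns False, as its unnamed `_` parameter suggests. Under that assumption, a lookup keyed on a freshly built trace type never hits, which is consistent with the always-retrace behavior stated for this class.

```python
from typing import Any, Callable, Dict


class NeverEqualType:
    """Hypothetical trace type whose __eq__ always returns False."""

    def __eq__(self, _other: object) -> bool:
        return False

    def __hash__(self) -> int:
        # Defining __eq__ removes the inherited __hash__, so restore one.
        return 0


class ValueType:
    """Hypothetical trace type that is equal when the wrapped values match."""

    def __init__(self, value: Any):
        self.value = value

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ValueType) and self.value == other.value

    def __hash__(self) -> int:
        return hash(self.value)


def count_traces(make_type: Callable[[], Any], calls: int) -> int:
    """Toy exact-match cache: a fresh trace type is built on every call."""
    cache: Dict[Any, str] = {}
    traces = 0
    for _ in range(calls):
        key = make_type()
        if key not in cache:
            traces += 1
            cache[key] = "concrete function"
    return traces


print(count_traces(NeverEqualType, 3))        # 3
print(count_traces(lambda: ValueType(1), 3))  # 1
```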