Holds a Tensor which a tf.function can capture.
tf.saved_model.experimental.TrackableResource(
    device=''
)
A TrackableResource is most useful for stateful Tensors that require initialization, such as tf.lookup.StaticHashTable. TrackableResources are discovered by traversing the graph of object attributes, e.g. during tf.saved_model.save.
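The idea of discovery by attribute traversal can be pictured with a small pure-Python sketch. This is only a conceptual model and not TensorFlow's implementation: the names `Resource`, `Node`, and `find_resources` are hypothetical stand-ins invented for illustration.

```python
class Resource:
    """Hypothetical stand-in for a resource-holding object (not a TensorFlow class)."""

    def __init__(self, name):
        self.name = name


class Node:
    """Hypothetical stand-in for a module that stores children as plain attributes."""


def find_resources(root):
    """Depth-first walk over instance attributes, collecting Resource objects."""
    found, seen, stack = [], set(), [root]
    while stack:
        obj = stack.pop()
        if id(obj) in seen:
            continue  # already visited; this keeps cycles from looping forever
        seen.add(id(obj))
        if isinstance(obj, Resource):
            found.append(obj)
            continue
        # Follow only objects that carry an instance __dict__.
        stack.extend(getattr(obj, "__dict__", {}).values())
    return found


root = Node()
root.table = Resource("table")
root.child = Node()
root.child.vocab = Resource("vocab")
root.child.parent = root  # a cycle in the attribute graph

names = sorted(r.name for r in find_resources(root))
print(names)  # ['table', 'vocab']
```

In this toy model a resource is found only if it is reachable from the root through attributes, which is why, in the real API, a resource is typically stored as an attribute of the object being saved.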
A TrackableResource has three methods to override:
- _create_resource should create the resource tensor handle.
- _initialize should initialize the resource held at self.resource_handle.
- _destroy_resource is called upon a TrackableResource's destruction and should decrement the resource's ref count. For most resources, this should be done with a call to tf.raw_ops.DestroyResourceOp.
Example usage:
import tensorflow as tf

class DemoResource(tf.saved_model.experimental.TrackableResource):
  def __init__(self):
    super().__init__()
    self._initialize()

  def _create_resource(self):
    # Create a handle for a float32 variable of shape [2].
    return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])

  def _initialize(self):
    # Assign the initial value [1., 1.] to the variable behind the handle.
    tf.raw_ops.AssignVariableOp(
        resource=self.resource_handle, value=tf.ones([2]))

  def _destroy_resource(self):
    tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)

class DemoModule(tf.Module):
  def __init__(self):
    self.resource = DemoResource()

  def increment(self, tensor):
    # Read the current value through the handle and add it to the input.
    return tensor + tf.raw_ops.ReadVariableOp(
        resource=self.resource.resource_handle, dtype=tf.float32)

demo = DemoModule()
demo.increment([5, 1])
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([6., 2.], dtype=float32)>
Attributes | |
---|---|
resource_handle | Returns the resource handle associated with this Resource. |
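
As a minimal sketch of what resource_handle holds, the snippet below inspects the handle of a small resource and reads its value back. The class name `VectorResource` is hypothetical and mirrors the DemoResource example above; it assumes TensorFlow 2.x running eagerly, and it calls `_initialize()` explicitly rather than from `__init__`.

```python
import tensorflow as tf


class VectorResource(tf.saved_model.experimental.TrackableResource):
    """Hypothetical resource wrapping a float32 variable of shape [2]."""

    def _create_resource(self):
        return tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=[2])

    def _initialize(self):
        tf.raw_ops.AssignVariableOp(
            resource=self.resource_handle, value=tf.ones([2]))

    def _destroy_resource(self):
        tf.raw_ops.DestroyResourceOp(resource=self.resource_handle)


res = VectorResource()
res._initialize()  # explicit here; the example above does this in __init__

handle = res.resource_handle
print(handle.dtype == tf.resource)  # True

# The handle is what raw ops consume to reach the underlying state.
value = tf.raw_ops.ReadVariableOp(resource=handle, dtype=tf.float32)
print(value.numpy())  # [1. 1.]
```

The handle itself carries no data; it is a tensor of dtype tf.resource that identifies the state, so reads and writes go through ops that accept it.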