Implements dp_query for accurate range queries using tree aggregation.
Inherits From: SumAggregationDPQuery, DPQuery
tf_privacy.TreeRangeSumQuery(
    inner_query: tf_privacy.SumAggregationDPQuery,
    arity: int = 2
)
Implements a variant of the tree aggregation protocol from "Is Interaction Necessary for Distributed Private Learning?" by Adam Smith, Abhradeep Thakurta, and Jalaj Upadhyay. The query builds a tree on top of the input record and adds noise to the tree for differential privacy. Any range query can then be decomposed into the sum of O(log(n)) tree nodes, compared to O(n) bins when using a histogram, which improves efficiency and reduces the noise scale.
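To make the O(log(n)) claim concrete, here is a minimal pure-Python sketch of the decomposition on a noiseless binary tree (the noise step is deliberately omitted). It assumes the leaves are zero-padded to a power of two and that levels are stored root first; the helpers build_binary_tree, range_nodes and range_sum are hypothetical and not part of tf_privacy.

```python
def build_binary_tree(leaves):
    """Return a complete binary tree over `leaves` as a list of levels, root first."""
    # Zero-pad the leaves up to the next power of two (assumption).
    size = 1
    while size < len(leaves):
        size *= 2
    level = list(leaves) + [0] * (size - len(leaves))
    levels = [level]
    while len(level) > 1:
        # Each parent is the sum of its two children.
        level = [level[i] + level[i + 1] for i in range(0, len(level), 2)]
        levels.append(level)
    return levels[::-1]


def range_nodes(levels, start, end):
    """Return (depth, index) pairs whose node sums add up to leaves[start:end]."""
    nodes = []

    def visit(depth, index, lo, hi):
        # Node (depth, index) covers the leaves in [lo, hi).
        if end <= lo or hi <= start:
            return  # No overlap with the query range.
        if start <= lo and hi <= end:
            nodes.append((depth, index))  # Fully inside: use this node's sum.
            return
        mid = (lo + hi) // 2  # Partial overlap: descend into both children.
        visit(depth + 1, 2 * index, lo, mid)
        visit(depth + 1, 2 * index + 1, mid, hi)

    visit(0, 0, 0, len(levels[-1]))
    return nodes


def range_sum(levels, start, end):
    """Answer the range query from tree nodes instead of the raw leaves."""
    return sum(levels[d][i] for d, i in range_nodes(levels, start, end))


levels = build_binary_tree([1, 2, 3, 4, 5, 6, 7, 8])
print(range_nodes(levels, 1, 7))  # [(3, 1), (2, 1), (2, 2), (3, 6)]
print(range_sum(levels, 1, 7))  # 27
```

For the leaves 1 through 8, the range [1, 7) is covered by four tree nodes instead of six leaves. In a binary tree at most two nodes per level are needed, so the gap widens as the histogram grows.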
Args | |
---|---|
inner_query | The inner DPQuery that adds noise to the tree. |
arity | The branching factor of the tree (i.e. the number of children each internal node has). Defaults to 2. |
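As a rough illustration of what arity controls, the sketch below computes the height and per-level node counts of a complete tree over a given number of leaves. It assumes, hypothetically, that the leaf count is padded up to the next power of arity; tree_shape is an illustrative helper, not a library function.

```python
def tree_shape(num_leaves, arity=2):
    """Return (height, level_sizes) of a complete tree, level sizes listed root first."""
    if arity < 2:
        raise ValueError("this sketch assumes arity >= 2")
    # Hypothetical padding rule: round the leaf count up to a power of `arity`.
    padded = 1
    while padded < num_leaves:
        padded *= arity
    level_sizes = [padded]
    while level_sizes[-1] > 1:
        # Every `arity` nodes on one level share a single parent on the level above.
        level_sizes.append(level_sizes[-1] // arity)
    level_sizes.reverse()
    return len(level_sizes) - 1, level_sizes


for arity in (2, 4):
    height, sizes = tree_shape(16, arity)
    print(arity, height, sizes, sum(sizes))
# 2 4 [1, 2, 4, 8, 16] 31
# 4 2 [1, 4, 16] 21
```

For 16 leaves, arity 2 gives a tree of height 4 with 31 nodes, while arity 4 gives height 2 with 21 nodes. A larger arity makes the tree shallower, but a range query may then need more nodes from each level.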
Methods
accumulate_preprocessed_record
accumulate_preprocessed_record(
sample_state, preprocessed_record
)
Implements tensorflow_privacy.DPQuery.accumulate_preprocessed_record.
accumulate_record
accumulate_record(
params, sample_state, record
)
Accumulates a single record into the sample state.
This is a helper method that simply delegates to preprocess_record and accumulate_preprocessed_record for the common case when both of those functions run on a single device. Typically this will be a simple sum.
Args | |
---|---|
params | The parameters for the sample. In standard DP-SGD training, the clipping norm for the sample's microbatch gradients (i.e., a maximum norm magnitude to which each gradient is clipped). |
sample_state | The current sample state. In standard DP-SGD training, the accumulated sum of previous clipped microbatch gradients. |
record | The record to accumulate. In standard DP-SGD training, the gradient computed for the examples in one microbatch, which may be the gradient for just one example (for size 1 microbatches). |
Returns | |
---|---|
The updated sample state. In standard DP-SGD training, the accumulated sum of previous microbatch gradients plus the record argument. |
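The delegation described above can be sketched with a toy query class. ToySumQuery is hypothetical and deliberately simplified (plain Python lists, with an L2 clipping norm as params); it only mirrors the method names of tensorflow_privacy.DPQuery and does not subclass it.

```python
import math


class ToySumQuery:
    """Hypothetical, simplified stand-in that mirrors DPQuery method names."""

    def initial_sample_state(self, template):
        # A zero vector shaped like the template record.
        return [0.0] * len(template)

    def preprocess_record(self, params, record):
        # Here `params` is an L2 clipping norm: scale the record down if it is too long.
        norm = math.sqrt(sum(x * x for x in record))
        scale = min(1.0, params / norm) if norm > 0 else 1.0
        return [x * scale for x in record]

    def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
        # Accumulation is an element-wise sum.
        return [s + x for s, x in zip(sample_state, preprocessed_record)]

    def accumulate_record(self, params, sample_state, record):
        # The single-device helper: preprocess, then accumulate.
        preprocessed = self.preprocess_record(params, record)
        return self.accumulate_preprocessed_record(sample_state, preprocessed)

    def merge_sample_states(self, sample_state_1, sample_state_2):
        # Partial sums computed elsewhere combine by addition as well.
        return [a + b for a, b in zip(sample_state_1, sample_state_2)]


query = ToySumQuery()
state = query.initial_sample_state([0.0, 0.0])
state = query.accumulate_record(1.0, state, [3.0, 4.0])  # clipped to roughly [0.6, 0.8]
state = query.accumulate_record(1.0, state, [0.5, 0.0])  # within the norm, unchanged
print(state)
```

Splitting preprocessing from accumulation is what lets the two steps run on different devices; accumulate_record is only the shortcut for when they do not.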
build_central_gaussian_query
@classmethod
build_central_gaussian_query(l2_norm_clip: float, stddev: float, arity: int = 2)
Returns TreeRangeSumQuery with central Gaussian noise.
Args | |
---|---|
l2_norm_clip | Each record should be clipped so that it has L2 norm at most l2_norm_clip. |
stddev | Stddev of the central Gaussian noise. |
arity | The branching factor of the tree (i.e. the number of children each internal node has). Defaults to 2. |
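For intuition about the central Gaussian mechanism, here is a generic sketch on plain vectors: clip each record to L2 norm l2_norm_clip, sum the clipped records, and add Gaussian noise with standard deviation stddev once to the aggregate. The function names are hypothetical, and the sketch does not model how TreeRangeSumQuery applies the clip and noise to the flattened tree.

```python
import math
import random


def clip_l2(record, l2_norm_clip):
    """Scale `record` so that its L2 norm is at most `l2_norm_clip`."""
    norm = math.sqrt(sum(x * x for x in record))
    if norm <= l2_norm_clip:
        return list(record)
    return [x * l2_norm_clip / norm for x in record]


def central_gaussian_sum(records, l2_norm_clip, stddev, rng=None):
    """Clip each record, sum them, and add N(0, stddev^2) noise once to each coordinate."""
    rng = rng or random.Random()
    total = [0.0] * len(records[0])
    for record in records:
        clipped = clip_l2(record, l2_norm_clip)
        total = [t + x for t, x in zip(total, clipped)]
    # Central noise: a trusted aggregator adds it to the sum, not to individual records.
    return [t + rng.gauss(0.0, stddev) for t in total]


noisy = central_gaussian_sum([[3.0, 4.0], [0.5, 0.0]], l2_norm_clip=1.0, stddev=0.5,
                             rng=random.Random(0))
```

Clipping bounds how much any single record can move the sum, which is what lets a fixed stddev translate into a privacy guarantee.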
build_distributed_discrete_gaussian_query
@classmethod
build_distributed_discrete_gaussian_query(l2_norm_bound: float, local_stddev: float, arity: int = 2)
Returns TreeRangeSumQuery with distributed discrete Gaussian noise.
Args | |
---|---|
l2_norm_bound | Each record should be clipped so that it has L2 norm at most l2_norm_bound. |
local_stddev | Scale/stddev of the local discrete Gaussian noise. |
arity | The branching factor of the tree (i.e. the number of children each internal node has). Defaults to 2. |
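In the distributed setting, each participant adds its own local noise before aggregation, so no single party needs to see the noise-free sum. The sketch below is an illustration under loud assumptions: it samples from a discrete Gaussian truncated to a finite support (an approximation of the exact distribution), expects records that are already integers, and omits the scaling, rounding, secure aggregation, and modular arithmetic a real deployment would need. All names are hypothetical.

```python
import math
import random


def sample_discrete_gaussian(sigma, rng, tail=10):
    """Draw an integer k with probability proportional to exp(-k^2 / (2 sigma^2)).

    Approximation: the support is truncated to [-ceil(tail * sigma), ceil(tail * sigma)].
    """
    if sigma == 0:
        return 0
    bound = math.ceil(tail * sigma)
    support = range(-bound, bound + 1)
    weights = [math.exp(-k * k / (2.0 * sigma * sigma)) for k in support]
    return rng.choices(support, weights=weights, k=1)[0]


def distributed_noisy_sum(int_records, local_stddev, rng=None):
    """Each participant perturbs its own integer record; the server only sums."""
    rng = rng or random.Random()
    total = [0] * len(int_records[0])
    for record in int_records:
        # Local noise is added on the participant's side, before aggregation.
        noisy = [x + sample_discrete_gaussian(local_stddev, rng) for x in record]
        total = [t + v for t, v in zip(total, noisy)]
    return total


def aggregate_noise_stddev(local_stddev, num_participants):
    """Independent variances add, so the summed noise has about sqrt(n) times the local stddev."""
    return local_stddev * math.sqrt(num_participants)
```

Because independent noise variances add, the aggregate noise from n participants has a standard deviation of roughly local_stddev * sqrt(n); for example, aggregate_noise_stddev(2.0, 4) returns 4.0.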
derive_metrics
derive_metrics(
global_state
)
Derives metric information from the current global state.
Any metrics returned should be derived only from privatized quantities.
Args | |
---|---|
global_state | The global state from which to derive metrics. |
Returns | |
---|---|
A collections.OrderedDict mapping string metric names to tensor values. |
derive_sample_params
derive_sample_params(
global_state
)
Implements tensorflow_privacy.DPQuery.derive_sample_params.
get_noised_result
get_noised_result(
sample_state, global_state
)
Implements tensorflow_privacy.DPQuery.get_noised_result.
This function reconstructs the tf.RaggedTensor from the flattened tree output by preprocess_record.
Args | |
---|---|
sample_state | A tf.Tensor for the flattened tree. |
global_state | The global state of the protocol. |
Returns | |
---|---|
A tf.RaggedTensor representing the tree. |
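The reconstruction step can be pictured as splitting the flat vector into one row per tree level. This sketch assumes a complete tree flattened in level order with the root first (so [10, 3, 7, 1, 2, 3, 4] for leaves 1 to 4 and arity 2); unflatten_tree is a hypothetical helper. With TensorFlow, the same row lengths could be passed to tf.RaggedTensor.from_row_lengths.

```python
def unflatten_tree(flat, arity=2):
    """Split a root-first, level-order flattened complete tree into its levels."""
    if not flat:
        raise ValueError("empty tree")
    levels = []
    start, width = 0, 1
    while start < len(flat):
        # Level d of a complete `arity`-ary tree holds arity**d nodes.
        levels.append(list(flat[start:start + width]))
        start += width
        width *= arity
    if len(levels[-1]) != width // arity:
        raise ValueError("length does not match a complete tree with this arity")
    return levels


levels = unflatten_tree([10, 3, 7, 1, 2, 3, 4], arity=2)
row_lengths = [len(level) for level in levels]  # [1, 2, 4]
```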
initial_global_state
initial_global_state()
Implements tensorflow_privacy.DPQuery.initial_global_state.
initial_sample_state
initial_sample_state(
template=None
)
Implements tensorflow_privacy.DPQuery.initial_sample_state.
merge_sample_states
merge_sample_states(
sample_state_1, sample_state_2
)
Implements tensorflow_privacy.DPQuery.merge_sample_states.
preprocess_record
preprocess_record(
params, record
)
Implements tensorflow_privacy.DPQuery.preprocess_record.
This method builds the tree, flattens it, and applies inner_query.preprocess_record to the flattened tree.
Args | |
---|---|
params | Hyperparameters for preprocessing the record. |
record | A histogram representing the leaf nodes of the tree. |
Returns | |
---|---|
A tf.Tensor representing the flattened version of the preprocessed tree. |
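A plain-Python sketch of the build-and-flatten step, before inner_query.preprocess_record is applied. It assumes the histogram is zero-padded to the next power of arity, that each parent is the sum of its arity children, and that levels are flattened root first; build_tree and flatten_tree are hypothetical helpers, not library functions.

```python
def build_tree(leaves, arity=2):
    """Build a complete `arity`-ary tree of sums over `leaves`, levels listed root first."""
    if arity < 2:
        raise ValueError("this sketch assumes arity >= 2")
    # Zero-pad the histogram up to the next power of `arity` (assumption).
    size = 1
    while size < len(leaves):
        size *= arity
    level = list(leaves) + [0] * (size - len(leaves))
    levels = [level]
    while len(level) > 1:
        # Each parent stores the sum of its `arity` children.
        level = [sum(level[i:i + arity]) for i in range(0, len(level), arity)]
        levels.append(level)
    return levels[::-1]


def flatten_tree(levels):
    """Concatenate the levels, root first, into a single flat list."""
    return [node for level in levels for node in level]


tree = build_tree([1, 2, 3, 4], arity=2)  # [[10], [3, 7], [1, 2, 3, 4]]
flat = flatten_tree(tree)  # [10, 3, 7, 1, 2, 3, 4]
```

Each leaf value contributes to exactly one node per level, which is worth keeping in mind when reasoning about how inner_query.preprocess_record treats the flattened vector.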