Se você quiser criar uma operação que não seja coberta pela biblioteca existente do TensorFlow, recomendamos que primeiro tente escrever a operação em Python como uma composição de operações ou funções Python existentes. Se isso não for possível, você pode criar uma operação C++ personalizada. Há vários motivos pelos quais você pode querer criar uma operação C++ personalizada:
- Não é fácil nem possível expressar sua operação como uma composição de operações existentes.
- Não é eficiente expressar sua operação como uma composição de primitivas existentes.
- Você deseja fundir manualmente uma composição de primitivos que um futuro compilador acharia difícil fundir.
Por exemplo, imagine que você deseja implementar algo como "agrupamento de mediana", semelhante ao operador "MaxPool", mas calculando medianas em janelas deslizantes em vez de valores máximos. Fazer isso usando uma composição de operações pode ser possível (por exemplo, usando ExtractImagePatches e TopK), mas pode não ser tão eficiente em termos de desempenho ou memória quanto uma operação nativa onde você pode fazer algo mais inteligente em uma única operação fundida. Como sempre, normalmente vale a pena primeiro tentar expressar o que você deseja usando a composição de operadores, optando apenas por adicionar uma nova operação se isso for difícil ou ineficiente.
Para incorporar sua operação personalizada, você precisará:
- Registre a nova operação em um arquivo C++. O registro de operação define uma interface (especificação) para a funcionalidade da operação, que é independente da implementação da operação. Por exemplo, o registro da operação define o nome da operação e as entradas e saídas da operação. Ele também define a função de forma usada para inferência de forma de tensor.
- Implemente a operação em C++. A implementação de uma operação é conhecida como kernel e é a implementação concreta da especificação que você registrou na Etapa 1. Pode haver vários kernels para diferentes tipos ou arquiteturas de entrada/saída (por exemplo, CPUs, GPUs).
- Crie um wrapper Python (opcional). Este wrapper é a API pública usada para criar a operação em Python. Um wrapper padrão é gerado a partir do registro operacional, que pode ser usado diretamente ou adicionado.
- Escreva uma função para calcular gradientes para a operação (opcional).
- Teste a operação. Geralmente fazemos isso em Python por conveniência, mas você também pode testar a operação em C++. Se você definir gradientes, poderá verificá-los com Python
tf.test.compute_gradient_error
. Vejarelu_op_test.py
como um exemplo que testa as funções diretas de operadores do tipo Relu e seus gradientes.
Pré-requisitos
- Alguma familiaridade com C++.
- Deve ter instalado o binário do TensorFlow ou deve ter baixado a fonte do TensorFlow e ser capaz de construí-lo.
Definir a interface operacional
Você define a interface de uma operação registrando-a no sistema TensorFlow. No registro, você especifica o nome da sua operação, suas entradas (tipos e nomes) e saídas (tipos e nomes), bem como docstrings e quaisquer atributos que a operação possa exigir.
Para ver como isso funciona, suponha que você queira criar uma operação que pegue um tensor de int32
s e produza uma cópia do tensor, com todos os elementos, exceto o primeiro, definidos como zero. Para fazer isso, crie um arquivo chamado zero_out.cc
. Em seguida, adicione uma chamada à macro REGISTER_OP
que define a interface da sua operação:
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
Esta operação ZeroOut
leva um tensor to_zero
de inteiros de 32 bits como entrada e gera um tensor zeroed
de inteiros de 32 bits. A operação também usa uma função de forma para garantir que o tensor de saída tenha a mesma forma que o tensor de entrada. Por exemplo, se a entrada for um tensor de forma [10, 20], então esta função de forma especifica que a forma de saída também é [10, 20].
Implemente o kernel para a operação
Depois de definir a interface, forneça uma ou mais implementações da operação. Para criar um desses kernels, crie uma classe que estenda OpKernel
e substitua o método Compute
. O método Compute
fornece um argumento context
do tipo OpKernelContext*
, a partir do qual você pode acessar coisas úteis como os tensores de entrada e saída.
Adicione seu kernel ao arquivo que você criou acima. O kernel pode ser parecido com isto:
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output_flat(i) = 0;
}
// Preserve the first input value if possible.
if (N > 0) output_flat(0) = input(0);
}
};
Depois de implementar seu kernel, você o registra no sistema TensorFlow. No registro, você especifica diferentes restrições sob as quais este kernel será executado. Por exemplo, você pode ter um kernel feito para CPUs e outro separado para GPUs.
Para fazer isso na operação ZeroOut
, adicione o seguinte a zero_out.cc
:
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
Kernels de CPU multithread
Para escrever um kernel de CPU multithread, a função Shard em work_sharder.h
pode ser usada. Esta função fragmenta uma função de computação entre os threads configurados para serem usados para threading intra-op (consulte intra_op_parallelism_threads em config.proto
).
Núcleos de GPU
Um kernel GPU é implementado em duas partes: o OpKernel e o kernel CUDA e seu código de inicialização.
Às vezes, a implementação do OpKernel é comum entre um kernel de CPU e GPU, como na inspeção de entradas e na alocação de saídas. Nesse caso, uma implementação sugerida é:
- Defina o OpKernel modelado no Dispositivo e o tipo primitivo do tensor.
- Para fazer o cálculo real da saída, a função Compute chama uma estrutura de functor modelada.
- A especialização desse functor para CPUDevice é definida no mesmo arquivo, mas a especialização para GPUDevice é definida em um arquivo .cu.cc, pois será compilado com o compilador CUDA.
Aqui está um exemplo de implementação.
// kernel_example.h
#ifndef KERNEL_EXAMPLE_H_
#define KERNEL_EXAMPLE_H_
#include <unsupported/Eigen/CXX11/Tensor>
template <typename Device, typename T>
struct ExampleFunctor {
void operator()(const Device& d, int size, const T* in, T* out);
};
#if GOOGLE_CUDA
// Partially specialize functor for GpuDevice.
template <typename T>
struct ExampleFunctor<Eigen::GpuDevice, T> {
void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out);
};
#endif
#endif KERNEL_EXAMPLE_H_
// kernel_example.cc
#include "kernel_example.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
REGISTER_OP("Example")
.Attr("T: numbertype")
.Input("input: T")
.Output("input_times_two: T")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
// CPU specialization of actual computation.
template <typename T>
struct ExampleFunctor<CPUDevice, T> {
void operator()(const CPUDevice& d, int size, const T* in, T* out) {
for (int i = 0; i < size; ++i) {
out[i] = 2 * in[i];
}
}
};
// OpKernel definition.
// template parameter <T> is the datatype of the tensors.
template <typename Device, typename T>
class ExampleOp : public OpKernel {
public:
explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
// Do the computation.
OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
errors::InvalidArgument("Too many elements in tensor"));
ExampleFunctor<Device, T>()(
context->eigen_device<Device>(),
static_cast<int>(input_tensor.NumElements()),
input_tensor.flat<T>().data(),
output_tensor->flat<T>().data());
}
};
// Register the CPU kernels.
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("Example").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ExampleOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(int32);
// Register the GPU kernels.
#ifdef GOOGLE_CUDA
#define REGISTER_GPU(T) \
/* Declare explicit instantiations in kernel_example.cu.cc. */ \
extern template class ExampleFunctor<GPUDevice, T>; \
REGISTER_KERNEL_BUILDER( \
Name("Example").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
ExampleOp<GPUDevice, T>);
REGISTER_GPU(float);
REGISTER_GPU(int32);
#endif // GOOGLE_CUDA
// kernel_example.cu.cc
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "kernel_example.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
// Define the CUDA kernel.
template <typename T>
__global__ void ExampleCudaKernel(const int size, const T* in, T* out) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
i += blockDim.x * gridDim.x) {
out[i] = 2 * __ldg(in + i);
}
}
// Define the GPU implementation that launches the CUDA kernel.
template <typename T>
void ExampleFunctor<GPUDevice, T>::operator()(
const GPUDevice& d, int size, const T* in, T* out) {
// Launch the cuda kernel.
//
// See core/util/gpu_kernel_helper.h for example of computing
// block count and thread_per_block count.
int block_count = 1024;
int thread_per_block = 20;
ExampleCudaKernel<T>
<<<block_count, thread_per_block, 0, d.stream()>>>(size, in, out);
}
// Explicitly instantiate functors for the types of OpKernels registered.
template struct ExampleFunctor<GPUDevice, float>;
template struct ExampleFunctor<GPUDevice, int32>;
#endif // GOOGLE_CUDA
Construa a biblioteca de operações
Compile a operação usando o compilador do sistema (instalação binária do TensorFlow)
Você deve ser capaz de compilar zero_out.cc
com um compilador C++
como g++
ou clang
disponível em seu sistema. O pacote PIP binário instala os arquivos de cabeçalho e a biblioteca necessária para compilar sua operação em locais específicos do sistema. No entanto, a biblioteca python do TensorFlow fornece a função get_include
para obter o diretório de cabeçalho, e o diretório get_lib
tem um objeto compartilhado para vincular. Aqui estão os resultados dessas funções em uma máquina Ubuntu.
$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
'/usr/local/lib/python3.6/site-packages/tensorflow/include'
>>> tf.sysconfig.get_lib()
'/usr/local/lib/python3.6/site-packages/tensorflow'
Supondo que você tenha g++
instalado, aqui está a sequência de comandos que você pode usar para compilar sua operação em uma biblioteca dinâmica.
TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++14 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2
No macOS, o sinalizador adicional "-undefined dynamic_lookup" é necessário ao criar o arquivo .so
.
Nota sobre a versão
gcc
>=5
: o gcc usa a nova ABI C++ desde a versão5
. O TensorFlow 2.8 e anteriores foram criados comgcc4
que usa a ABI mais antiga. Se você estiver usando essas versões do TensorFlow e tentando compilar sua biblioteca operacional comgcc>=5
, adicione-D_GLIBCXX_USE_CXX11_ABI=0
à linha de comando para tornar a biblioteca compatível com a ABI mais antiga. Os pacotes do TensorFlow 2.9+ são compatíveis com a ABI mais recente por padrão.
Compile a operação usando bazel (instalação de origem do TensorFlow)
Se você tiver fontes do TensorFlow instaladas, poderá usar o sistema de compilação do TensorFlow para compilar sua operação. Coloque um arquivo BUILD com a seguinte regra de compilação do Bazel no diretório tensorflow/core/user_ops
.
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
tf_custom_op_library(
name = "zero_out.so",
srcs = ["zero_out.cc"],
)
Execute o seguinte comando para construir zero_out.so
.
$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so
Para compilar a operação Example
, com o Kernel CUDA, você precisa usar o parâmetro gpu_srcs
de tf_custom_op_library
. Coloque um arquivo BUILD com a seguinte regra de compilação do Bazel em uma nova pasta dentro do diretório tensorflow/core/user_ops
(por exemplo, "example_gpu").
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
tf_custom_op_library(
# kernel_example.cc kernel_example.cu.cc kernel_example.h
name = "kernel_example.so",
srcs = ["kernel_example.h", "kernel_example.cc"],
gpu_srcs = ["kernel_example.cu.cc", "kernel_example.h"],
)
Execute o seguinte comando para construir kernel_example.so
.
$ bazel build --config opt //tensorflow/core/user_ops/example_gpu:kernel_example.so
Use a operação em Python
A API TensorFlow Python fornece a função tf.load_op_library
para carregar a biblioteca dinâmica e registrar a operação com a estrutura TensorFlow. load_op_library
retorna um módulo Python que contém os wrappers Python para a operação e o kernel. Assim, depois de criar a operação, você pode fazer o seguinte para executá-la em Python:
import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
print(zero_out_module.zero_out([[1, 2], [3, 4]]).numpy())
# Prints
array([[1, 0], [0, 0]], dtype=int32)
Lembre-se de que a função gerada receberá um nome snake_case (para cumprir PEP8 ). Portanto, se sua operação for chamada ZeroOut
nos arquivos C++, a função python será chamada zero_out
.
Para disponibilizar a operação como uma função regular import
-able de um módulo Python, talvez seja útil ter a chamada load_op_library
em um arquivo fonte Python da seguinte maneira:
import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
zero_out = zero_out_module.zero_out
Verifique se a operação funciona
Uma boa maneira de verificar se você implementou sua operação com sucesso é escrever um teste para ela. Crie o arquivo zero_out_op_test.py
com o conteúdo:
import tensorflow as tf
class ZeroOutTest(tf.test.TestCase):
def testZeroOut(self):
zero_out_module = tf.load_op_library('./zero_out.so')
with self.test_session():
result = zero_out_module.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
if __name__ == "__main__":
tf.test.main()
Em seguida, execute seu teste (supondo que você tenha o tensorflow instalado):
$ python zero_out_op_test.py
Crie recursos avançados em sua operação
Agora que você sabe como criar uma operação e uma implementação básicas (e um tanto restritas), veremos algumas das coisas mais complicadas que você normalmente precisará incorporar à sua operação. Isso inclui:
- Verificações condicionais e validação
- Registro de operação
- Suporte para GPU
- Implemente o gradiente em Python
- Funções de forma em C++
Verificações condicionais e validação
O exemplo acima assumiu que a operação se aplicava a um tensor de qualquer formato. E se fosse aplicado apenas a vetores? Isso significa adicionar uma verificação à implementação do OpKernel acima.
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
errors::InvalidArgument("ZeroOut expects a 1-D vector."));
// ...
}
Isso afirma que a entrada é um vetor e retorna tendo definido o status InvalidArgument
se não for. A macro OP_REQUIRES
leva três argumentos:
- O
context
, que pode ser um ponteiroOpKernelContext
ouOpKernelConstruction
(consultetensorflow/core/framework/op_kernel.h
), para seu métodoSetStatus()
. - A condição. Por exemplo, existem funções para validar a forma de um tensor em
tensorflow/core/framework/tensor_shape.h
- O erro em si, que é representado por um objeto
Status
, consultetensorflow/core/platform/status.h
. UmStatus
possui um tipo (frequentementeInvalidArgument
, mas veja a lista de tipos) e uma mensagem. Funções para construir um erro podem ser encontradas emtensorflow/core/platform/errors.h
.
Alternativamente, se você quiser testar se um objeto Status
retornado de alguma função é um erro e, em caso afirmativo, retorná-lo, use OP_REQUIRES_OK
. Ambas as macros retornam da função em caso de erro.
Registro de operação
Atributos
As operações podem ter atributos, cujos valores são definidos quando a operação é adicionada a um gráfico. Eles são usados para configurar a operação e seus valores podem ser acessados tanto na implementação do kernel quanto nos tipos de entradas e saídas no registro da operação. Prefira usar uma entrada em vez de um atributo quando possível, pois as entradas são mais flexíveis. Isso ocorre porque os attrs são constantes e devem ser definidos no momento da construção do gráfico. Por outro lado, as entradas são tensores cujos valores podem ser dinâmicos; isto é, as entradas podem mudar a cada etapa, ser definidas usando um feed, etc. Attrs são usados para coisas que não podem ser feitas com entradas: qualquer configuração que afete a assinatura (número ou tipo de entradas ou saídas) ou que possa ' Não mude passo a passo.
Você define um attr ao registrar a operação, especificando seu nome e tipo usando o método Attr
, que espera uma especificação no formato:
<name>: <attr-type-expr>
onde <name>
começa com uma letra e pode ser composto por caracteres alfanuméricos e sublinhados, e <attr-type-expr>
é uma expressão de tipo no formato descrito abaixo .
Por exemplo, se quiser que a operação ZeroOut
preserve um índice especificado pelo usuário, em vez de apenas o elemento 0, você pode registrar a operação da seguinte forma:
REGISTER_OP("ZeroOut")
.Attr("preserve_index: int")
.Input("to_zero: int32")
.Output("zeroed: int32");
(Observe que o conjunto de tipos de atributos é diferente do tf.DType
usado para entradas e saídas.)
Seu kernel pode então acessar esse atributo em seu construtor através do parâmetro context
:
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
// Get the index of the value to preserve
OP_REQUIRES_OK(context,
context->GetAttr("preserve_index", &preserve_index_));
// Check that preserve_index is positive
OP_REQUIRES(context, preserve_index_ >= 0,
errors::InvalidArgument("Need preserve_index >= 0, got ",
preserve_index_));
}
void Compute(OpKernelContext* context) override {
// ...
}
private:
int preserve_index_;
};
que pode então ser usado no método Compute
:
void Compute(OpKernelContext* context) override {
// ...
// We're using saved attr to validate potentially dynamic input
// So we check that preserve_index is in range
OP_REQUIRES(context, preserve_index_ < input.dimension(0),
errors::InvalidArgument("preserve_index out of range"));
// Set all the elements of the output tensor to 0
const int N = input.size();
for (int i = 0; i < N; i++) {
output_flat(i) = 0;
}
// Preserve the requested input value
output_flat(preserve_index_) = input(preserve_index_);
}
Tipos de atributos
Os seguintes tipos são suportados em um atributo:
-
string
: Qualquer sequência de bytes (não é necessário ser UTF8). -
int
: Um inteiro assinado. -
float
: Um número de ponto flutuante. -
bool
: Verdadeiro ou falso. -
type
: um dos valores (não-ref) deDataType
. -
shape
: UmTensorShapeProto
. -
list(<type>)
: Uma lista de<type>
, onde<type>
é um dos tipos acima. Observe quelist(list(<type>))
é inválido.
Consulte também: op_def_builder.cc:FinalizeAttr
para obter uma lista definitiva.
Valores padrão e restrições
Os atributos podem ter valores padrão e alguns tipos de atributos podem ter restrições. Para definir um attr com restrições, você pode usar os seguintes <attr-type-expr>
s:
{'<string1>', '<string2>'}
: o valor deve ser uma string que tenha o valor <string1>
ou <string2>
. O nome do tipo, string
, está implícito quando você usa essa sintaxe. Isso emula um enum:
REGISTER_OP("EnumExample")
.Attr("e: {'apple', 'orange'}");
{<type1>, <type2>}
: o valor é do tipo type
e deve ser um de <type1>
ou <type2>
, onde <type1>
e <type2>
são suportados tf.DType
. Você não especifica que o tipo do attr é type
. Isso está implícito quando você tem uma lista de tipos em {...}
. Por exemplo, neste caso o attr t
é um tipo que deve ser int32
, float
ou bool
:
REGISTER_OP("RestrictedTypeExample")
.Attr("t: {int32, float, bool}");
Existem atalhos para restrições de tipo comuns:
-
numbertype
:type
de tipo restrito aos tipos numéricos (não string e não bool). -
realnumbertype
: comonumbertype
sem tipos complexos. -
quantizedtype
: comonumbertype
, mas apenas os tipos de números quantizados.
As listas específicas de tipos permitidos por estes são definidas pelas funções (como NumberTypes()
) em tensorflow/core/framework/types.h
. Neste exemplo, o attr t
deve ser um dos tipos numéricos:
REGISTER_OP("NumberType")
.Attr("t: numbertype");
Para esta operação:
tf.number_type(t=tf.int32) # Valid
tf.number_type(t=tf.bool) # Invalid
As listas podem ser combinadas com outras listas e tipos únicos. A operação a seguir permite que attr t
seja qualquer um dos tipos numéricos ou do tipo bool:
REGISTER_OP("NumberOrBooleanType")
.Attr("t: {numbertype, bool}");
Para esta operação:
tf.number_or_boolean_type(t=tf.int32) # Valid
tf.number_or_boolean_type(t=tf.bool) # Valid
tf.number_or_boolean_type(t=tf.string) # Invalid
int >= <n>
: O valor deve ser um int cujo valor seja maior ou igual a <n>
, onde <n>
é um número natural. Por exemplo, o registro de operação a seguir especifica que attr a
deve ter um valor que seja pelo menos 2
:
REGISTER_OP("MinIntExample")
.Attr("a: int >= 2");
list(<type>) >= <n>
: Uma lista do tipo <type>
cujo comprimento é maior ou igual a <n>
. Por exemplo, o registro de operação a seguir especifica que attr a
é uma lista de tipos ( int32
ou float
) e que deve haver pelo menos 3 deles:
REGISTER_OP("TypeListExample")
.Attr("a: list({int32, float}) >= 3");
Para definir um valor padrão para um atributo (tornando-o opcional no código gerado), adicione = <default>
ao final, como em:
REGISTER_OP("AttrDefaultExample")
.Attr("i: int = 0");
Além disso, tanto uma restrição quanto um valor padrão podem ser especificados:
REGISTER_OP("AttrConstraintAndDefaultExample")
.Attr("i: int >= 1 = 1");
A sintaxe suportada do valor padrão é a que seria usada na representação proto da definição GraphDef resultante.
Aqui estão alguns exemplos de como especificar um padrão para todos os tipos:
REGISTER_OP("AttrDefaultExampleForAllTypes")
.Attr("s: string = 'foo'")
.Attr("i: int = 0")
.Attr("f: float = 1.0")
.Attr("b: bool = true")
.Attr("ty: type = DT_INT32")
.Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
.Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
.Attr("l_empty: list(int) = []")
.Attr("l_int: list(int) = [2, 3, 5, 7]");
Observe em particular que os valores do tipo type
usam tf.DType
.
Polimorfismo
Polimorfismo de tipo
Para operações que podem receber tipos diferentes como entrada ou produzir tipos de saída diferentes, você pode especificar um atributo em um tipo de entrada ou saída no registro da operação. Normalmente você registraria um OpKernel
para cada tipo suportado.
Por exemplo, se você quiser que a operação ZeroOut
funcione em float
s além de int32
s, seu registro de operação pode ser semelhante a:
REGISTER_OP("ZeroOut")
.Attr("T: {float, int32}")
.Input("to_zero: T")
.Output("zeroed: T");
Seu registro de operação agora especifica que o tipo de entrada deve ser float
, ou int32
, e que sua saída será do mesmo tipo, já que ambos possuem tipo T
.
Nomeação
Entradas, saídas e atributos geralmente devem receber nomes de Snake_case. A única exceção são os atributos que são usados como o tipo de entrada ou no tipo de saída. Esses atributos podem ser inferidos quando a operação é adicionada ao gráfico e, portanto, não aparecem na função da operação. Por exemplo, esta última definição de ZeroOut irá gerar uma função Python semelhante a:
def zero_out(to_zero, name=None):
"""...
Args:
to_zero: A `Tensor`. Must be one of the following types:
`float32`, `int32`.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `to_zero`.
"""
Se to_zero
receber um tensor int32
, então T
será automaticamente definido como int32
(bem, na verdade DT_INT32
). Esses atributos inferidos recebem nomes em maiúscula ou CamelCase.
Compare isso com uma operação que possui um tipo attr que determina o tipo de saída:
REGISTER_OP("StringToNumber")
.Input("string_tensor: string")
.Output("output: out_type")
.Attr("out_type: {float, int32} = DT_FLOAT");
.Doc(R"doc(
Converts each string in the input Tensor to the specified numeric type.
)doc");
Neste caso, o usuário deve especificar o tipo de saída, como no Python gerado:
def string_to_number(string_tensor, out_type=None, name=None):
"""Converts each string in the input Tensor to the specified numeric type.
Args:
string_tensor: A `Tensor` of type `string`.
out_type: An optional `tf.DType` from: `tf.float32, tf.int32`.
Defaults to `tf.float32`.
name: A name for the operation (optional).
Returns:
A `Tensor` of type `out_type`.
"""
Exemplo de polimorfismo de tipo
#include "tensorflow/core/framework/op_kernel.h"
class ZeroOutInt32Op : public OpKernel {
// as before
};
class ZeroOutFloatOp : public OpKernel {
public:
explicit ZeroOutFloatOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<float>();
// Create an output tensor
Tensor* output = NULL;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_tensor.shape(), &output));
auto output_flat = output->template flat<float>();
// Set all the elements of the output tensor to 0
const int N = input.size();
for (int i = 0; i < N; i++) {
output_flat(i) = 0;
}
// Preserve the first input value
if (N > 0) output_flat(0) = input(0);
}
};
// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<int32>("T"),
ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
ZeroOutFloatOp);
Para preservar a compatibilidade com versões anteriores , você deve especificar um valor padrão ao adicionar um attr a uma operação existente:
REGISTER_OP("ZeroOut")
.Attr("T: {float, int32} = DT_INT32")
.Input("to_zero: T")
.Output("zeroed: T")
Digamos que você queira adicionar mais tipos, digamos double
:
REGISTER_OP("ZeroOut")
.Attr("T: {float, double, int32}")
.Input("to_zero: T")
.Output("zeroed: T");
Em vez de escrever outro OpKernel
com código redundante como acima, muitas vezes você poderá usar um modelo C++. Você ainda terá um registro de kernel (chamada REGISTER_KERNEL_BUILDER
) por sobrecarga.
template <typename T>
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<T>();
// Create an output tensor
Tensor* output = NULL;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_tensor.shape(), &output));
auto output_flat = output->template flat<T>();
// Set all the elements of the output tensor to 0
const int N = input.size();
for (int i = 0; i < N; i++) {
output_flat(i) = 0;
}
// Preserve the first input value
if (N > 0) output_flat(0) = input(0);
}
};
// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<int32>("T"),
ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<double>("T"),
ZeroOutOp<double>);
Se você tiver mais de algumas sobrecargas, poderá colocar o registro em uma macro.
#include "tensorflow/core/framework/op_kernel.h"
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ZeroOutOp<type>)
REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
Dependendo da lista de tipos para os quais você está registrando o kernel, você poderá usar uma macro fornecida por tensorflow/core/framework/register_types.h
:
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
REGISTER_OP("ZeroOut")
.Attr("T: realnumbertype")
.Input("to_zero: T")
.Output("zeroed: T");
template <typename T>
class ZeroOutOp : public OpKernel { ... };
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ZeroOutOp<type>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
Listar entradas e saídas
Além de poder aceitar ou produzir diferentes tipos, as operações podem consumir ou produzir um número variável de tensores.
No próximo exemplo, o attr T
contém uma lista de tipos e é usado como o tipo de in
e out
. A entrada e a saída são listas de tensores desse tipo (e o número e os tipos de tensores na saída são iguais aos da entrada, pois ambos possuem o tipo T
).
REGISTER_OP("PolymorphicListExample")
.Attr("T: list(type)")
.Input("in: T")
.Output("out: T");
Você também pode colocar restrições sobre quais tipos podem ser especificados na lista. Neste próximo caso, a entrada é uma lista de tensores float
e double
. A operação aceita, por exemplo, tipos de entrada (float, double, float)
e nesse caso o tipo de saída também seria (float, double, float)
.
REGISTER_OP("ListTypeRestrictionExample")
.Attr("T: list({float, double})")
.Input("in: T")
.Output("out: T");
Se quiser que todos os tensores de uma lista sejam do mesmo tipo, você pode fazer algo como:
REGISTER_OP("IntListInputExample")
.Attr("N: int")
.Input("in: N * int32")
.Output("out: int32");
Isto aceita uma lista de tensores int32
e usa um int
attr N
para especificar o comprimento da lista.
Isso também pode ser feito de tipo polimórfico . No próximo exemplo, a entrada é uma lista de tensores (com comprimento "N"
) do mesmo tipo (mas não especificado) ( "T"
), e a saída é um único tensor de tipo correspondente:
REGISTER_OP("SameListInputExample")
.Attr("N: int")
.Attr("T: type")
.Input("in: N * T")
.Output("out: T");
Por padrão, as listas de tensores têm um comprimento mínimo de 1. Você pode alterar esse padrão usando uma restrição ">="
no attr correspondente . Neste próximo exemplo, a entrada é uma lista de pelo menos 2 tensores int32
:
REGISTER_OP("MinLengthIntListExample")
.Attr("N: int >= 2")
.Input("in: N * int32")
.Output("out: int32");
A mesma sintaxe funciona com atributos "list(type)"
:
REGISTER_OP("MinimumLengthPolymorphicListExample")
.Attr("T: list(type) >= 3")
.Input("in: T")
.Output("out: T");
Entradas e saídas
Para resumir o que foi dito acima, um registro operacional pode ter múltiplas entradas e saídas:
REGISTER_OP("MultipleInsAndOuts")
.Input("y: int32")
.Input("z: float")
.Output("a: string")
.Output("b: int32");
Cada especificação de entrada ou saída tem o formato:
<name>: <io-type-expr>
onde <name>
começa com uma letra e pode ser composto por caracteres alfanuméricos e sublinhados. <io-type-expr>
é uma das seguintes expressões de tipo:
<type>
, onde<type>
é um tipo de entrada suportado (por exemplo,float
,int32
,string
). Isso especifica um único tensor do tipo fornecido.Veja
tf.DType
.REGISTER_OP("BuiltInTypesExample") .Input("integers: int32") .Input("complex_numbers: complex64");
<attr-type>
, onde<attr-type>
é o nome de um Attr comtype
oulist(type)
(com uma possível restrição de tipo). Esta sintaxe permite operações polimórficas .REGISTER_OP("PolymorphicSingleInput") .Attr("T: type") .Input("in: T"); REGISTER_OP("RestrictedPolymorphicSingleInput") .Attr("T: {int32, int64}") .Input("in: T");
Referenciar um atributo do tipo
list(type)
permite aceitar uma sequência de tensores.REGISTER_OP("ArbitraryTensorSequenceExample") .Attr("T: list(type)") .Input("in: T") .Output("out: T"); REGISTER_OP("RestrictedTensorSequenceExample") .Attr("T: list({int32, int64})") .Input("in: T") .Output("out: T");
Observe que o número e os tipos de tensores na saída
out
são os mesmos que na entradain
, pois ambos são do tipoT
.Para uma sequência de tensores do mesmo tipo:
<number> * <type>
, onde<number>
é o nome de um Attr com tipoint
. O<type>
pode sertf.DType
ou o nome de um attr com tipotype
. Como exemplo do primeiro, esta operação aceita uma lista de tensoresint32
:REGISTER_OP("Int32SequenceExample") .Attr("NumTensors: int") .Input("in: NumTensors * int32")
Considerando que esta operação aceita uma lista de tensores de qualquer tipo, desde que sejam todos iguais:
REGISTER_OP("SameTypeSequenceExample") .Attr("NumTensors: int") .Attr("T: type") .Input("in: NumTensors * T")
Para uma referência a um tensor:
Ref(<type>)
, onde<type>
é um dos tipos anteriores.
Qualquer atributo usado no tipo de entrada será inferido. Por convenção, esses atributos inferidos usam nomes maiúsculos (como T
ou N
). Caso contrário, entradas, saídas e atributos terão nomes como parâmetros de função (por exemplo, num_outputs
). Para obter mais detalhes, consulte a seção anterior sobre nomenclatura .
Para obter mais detalhes, consulte tensorflow/core/framework/op_def_builder.h
.
Compatibilidade com versões anteriores
Vamos supor que você escreveu uma operação personalizada e agradável e a compartilhou com outras pessoas, para que você tenha clientes satisfeitos usando sua operação. No entanto, você gostaria de fazer alterações na operação de alguma forma.
Em geral, as alterações nas especificações existentes e com check-in devem ser compatíveis com versões anteriores: alterar a especificação de uma operação não deve quebrar os buffers de protocolo GraphDef
serializados anteriores construídos a partir de especificações mais antigas. Os detalhes da compatibilidade GraphDef
são descritos aqui .
Existem várias maneiras de preservar a compatibilidade com versões anteriores.
Quaisquer novos atributos adicionados a uma operação devem ter valores padrão definidos e, com esse valor padrão, a operação deve ter o comportamento original. Para alterar uma operação de não polimórfica para polimórfica, você deve fornecer um valor padrão ao novo tipo attr para preservar a assinatura original por padrão. Por exemplo, se sua operação foi:
REGISTER_OP("MyGeneralUnaryOp") .Input("in: float") .Output("out: float");
você pode torná-lo polimórfico de maneira compatível com versões anteriores usando:
REGISTER_OP("MyGeneralUnaryOp") .Input("in: T") .Output("out: T") .Attr("T: numerictype = DT_FLOAT");
Você pode tornar com segurança uma restrição em um atributo menos restritiva. Por exemplo, você pode alterar de
{int32, int64}
para{int32, int64, float}
outype
. Ou você pode mudar de{"apple", "orange"}
para{"apple", "banana", "orange"}
oustring
.Você pode alterar entradas/saídas únicas em entradas/saídas de lista, desde que o padrão do tipo de lista corresponda à assinatura antiga.
Você pode adicionar uma nova entrada/saída de lista, se o padrão for vazio.
Namespace todas as novas operações que você criar, prefixando os nomes das operações com algo exclusivo para o seu projeto. Isso evita que sua operação colida com quaisquer operações que possam ser incluídas em versões futuras do TensorFlow.
Planeje com antecedência! Tente antecipar usos futuros para a operação. Algumas alterações de assinatura não podem ser feitas de maneira compatível (por exemplo, transformar uma lista do mesmo tipo em uma lista de tipos variados).
A lista completa de alterações seguras e inseguras pode ser encontrada em tensorflow/core/framework/op_compatibility_test.cc
. Se você não conseguir fazer a alteração em uma operação compatível com versões anteriores, crie uma nova operação com um novo nome e a nova semântica.
Observe também que, embora essas alterações possam manter a compatibilidade GraphDef
, o código Python gerado pode mudar de uma forma que não é compatível com chamadores antigos. A API Python pode ser mantida compatível por meio de mudanças cuidadosas em um wrapper Python escrito à mão, mantendo a assinatura antiga, exceto possivelmente adicionando novos argumentos opcionais ao final. Geralmente, alterações incompatíveis só podem ser feitas quando o TensorFlow altera versões principais e devem estar em conformidade com a semântica de versão GraphDef
.
Suporte para GPU
Você pode implementar diferentes OpKernels e registrar um para CPU e outro para GPU, assim como você pode registrar kernels para diferentes tipos . Existem vários exemplos de kernels com suporte a GPU em tensorflow/core/kernels/
. Observe que alguns kernels têm uma versão de CPU em um arquivo .cc
, uma versão de GPU em um arquivo que termina em _gpu.cu.cc
e algum código compartilhado em comum em um arquivo .h
.
Por exemplo, o tf.pad
tem tudo, menos o kernel da GPU em tensorflow/core/kernels/pad_op.cc
. O kernel da GPU está em tensorflow/core/kernels/pad_op_gpu.cu.cc
e o código compartilhado é uma classe de modelo definida em tensorflow/core/kernels/pad_op.h
. Organizamos o código dessa maneira por dois motivos: permite compartilhar código comum entre as implementações de CPU e GPU e coloca a implementação de GPU em um arquivo separado para que possa ser compilada apenas pelo compilador de GPU.
Uma coisa a notar, mesmo quando a versão do kernel da GPU do pad
é usada, ele ainda precisa de sua entrada de "paddings"
na memória da CPU. Para marcar que as entradas ou saídas são mantidas na CPU, adicione uma chamada HostMemory()
ao registro do kernel, por exemplo:
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("Pad") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("paddings"), \
PadOp<GPUDevice, T>)
Compilando o kernel para o dispositivo GPU
Veja cuda_op_kernel.cu.cc para obter um exemplo que usa um kernel CUDA para implementar uma operação. A tf_custom_op_library
aceita um argumento gpu_srcs
no qual a lista de arquivos de origem contendo os kernels CUDA (arquivos *.cu.cc
) pode ser especificada. Para uso com uma instalação binária do TensorFlow, os kernels CUDA devem ser compilados com o compilador nvcc
da NVIDIA. Aqui está a sequência de comandos que você pode usar para compilar cuda_op_kernel.cu.cc e cuda_op_kernel.cc em uma única biblioteca carregável dinamicamente:
nvcc -std=c++14 -c -o cuda_op_kernel.cu.o cuda_op_kernel.cu.cc \
${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
g++ -std=c++14 -shared -o cuda_op_kernel.so cuda_op_kernel.cc \
cuda_op_kernel.cu.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]}
cuda_op_kernel.so
produzido acima pode ser carregado normalmente em Python, usando a função tf.load_op_library
.
Observe que se suas bibliotecas CUDA não estiverem instaladas em /usr/local/lib64
, você precisará especificar o caminho explicitamente no segundo comando (g++) acima. Por exemplo, adicione -L /usr/local/cuda-8.0/lib64/
se seu CUDA estiver instalado em /usr/local/cuda-8.0
.
Implemente o gradiente em Python
Dado um gráfico de operações, o TensorFlow usa diferenciação automática (backpropagation) para adicionar novas operações que representam gradientes em relação às operações existentes. Para fazer a diferenciação automática funcionar para novas operações, você deve registrar uma função de gradiente que calcula gradientes em relação às entradas das operações, dados gradientes em relação às saídas das operações.
Matematicamente, se uma operação computa \(y = f(x)\) a operação de gradiente registrada converte gradientes \(\partial L/ \partial y\) de perda \(L\) em relação a\(y\) em gradientes \(\partial L/ \partial x\) em relação a \(x\) através da regra da cadeia:
\[\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.\]
No caso de ZeroOut
, apenas uma entrada na entrada afeta a saída, portanto, o gradiente em relação à entrada é um tensor esparso "one hot". Isso é expresso da seguinte forma:
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
"""The gradients for `zero_out`.
Args:
op: The `zero_out` `Operation` that we are differentiating, which we can use
to find the inputs and outputs of the original op.
grad: Gradient with respect to the output of the `zero_out` op.
Returns:
Gradients with respect to the input of `zero_out`.
"""
to_zero = op.inputs[0]
shape = array_ops.shape(to_zero)
index = array_ops.zeros_like(shape)
first_grad = array_ops.reshape(grad, [-1])[0]
to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
return [to_zero_grad] # List of one Tensor, since we have one input
Detalhes sobre o registro de funções gradientes com tf.RegisterGradient
:
Para uma operação com uma saída, a função gradiente pegará um
tf.Operation
,op
e umtf.Tensor
grad
e construirá novas operações a partir dos tensoresop.inputs[i]
,op.outputs[i]
egrad
. Informações sobre quaisquer atributos podem ser encontradas emtf.Operation.get_attr
.Se a operação tiver múltiplas saídas, a função gradiente usará
op
egrads
, ondegrads
é uma lista de gradientes em relação a cada saída. O resultado da função gradiente deve ser uma lista de objetosTensor
representando os gradientes em relação a cada entrada.Se não houver um gradiente bem definido para alguma entrada, como para entradas inteiras usadas como índices, o gradiente retornado correspondente deverá ser
None
. Por exemplo, para uma operação que usa um tensor de ponto flutuantex
e um índice inteiroi
, a função gradientereturn [x_grad, None]
.Se não houver nenhum gradiente significativo para a operação, muitas vezes você não precisará registrar nenhum gradiente e, desde que o gradiente da operação nunca seja necessário, você ficará bem. Em alguns casos, uma operação não possui gradiente bem definido, mas pode estar envolvida no cálculo do gradiente. Aqui você pode usar
ops.NotDifferentiable
para propagar zeros automaticamente para trás.
Observe que no momento em que a função gradiente é chamada, apenas o gráfico de fluxo de dados de operações está disponível, e não os dados do tensor em si. Assim, todos os cálculos devem ser realizados usando outras operações de tensorflow, para serem executadas em tempo de execução do gráfico.
Adicione dicas de tipo ao registrar o gradiente personalizado para um tipo de operação para tornar o código mais legível, depurável, mais fácil de manter e mais robusto por meio da validação de dados. Por exemplo, ao usar um op
como parâmetro em uma função, especifique que a função gradiente usará tf.Operation
como tipo de parâmetro.
Funções de forma em C++
A API TensorFlow possui um recurso chamado "inferência de forma" que fornece informações sobre as formas dos tensores sem a necessidade de executar o gráfico. A inferência de forma é suportada por "funções de forma" que são registradas para cada tipo de operação na declaração C++ REGISTER_OP
e desempenham duas funções: afirmar que as formas das entradas são compatíveis durante a construção do gráfico e especificar as formas das saídas.
Funções de forma são definidas como operações na classe shape_inference::InferenceContext
. Por exemplo, na função de forma para ZeroOut:
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
c->set_output(0, c->input(0));
declara que o formato da primeira saída deve ser definido como o formato da primeira entrada. Se a saída for selecionada por seu índice como no exemplo acima, o segundo parâmetro de set_output
deverá ser um objeto ShapeHandle
. Você pode criar um objeto ShapeHandle
vazio por seu construtor padrão. O objeto ShapeHandle
para uma entrada com índice idx
pode ser obtido por c->input(idx)
.
Há uma série de funções de forma comuns que se aplicam a muitas operações, como shape_inference::UnchangedShape
que pode ser encontrada em common_shape_fns.h e usada da seguinte forma:
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn(::tensorflow::shape_inference::UnchangedShape);
Uma função de forma também pode restringir a forma de uma entrada. Para a versão do ZeroOut
com restrição de forma vetorial , a função de forma seria a seguinte:
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
::tensorflow::shape_inference::ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
c->set_output(0, input);
return Status::OK();
});
A chamada WithRank
valida que a forma de entrada c->input(0)
tem uma forma com exatamente uma dimensão (ou se a forma de entrada for desconhecida, a forma de saída será um vetor com uma dimensão desconhecida).
Se sua operação for polimórfica com múltiplas entradas , você pode usar membros de InferenceContext
para determinar o número de formas a serem verificadas e Merge
para validar se as formas são todas compatíveis (como alternativa, acesse atributos que indicam os comprimentos, com InferenceContext::GetAttr
, que fornece acesso aos atributos da operação).
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
::tensorflow::shape_inference::ShapeHandle input;
::tensorflow::shape_inference::ShapeHandle output;
for (size_t i = 0; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &input));
TF_RETURN_IF_ERROR(c->Merge(output, input, &output));
}
c->set_output(0, output);
return Status::OK();
});
Como a inferência de forma é um recurso opcional e as formas dos tensores podem variar dinamicamente, as funções de forma devem ser robustas para informações de forma incompletas para qualquer uma das entradas. O método Merge
em InferenceContext
permite que o chamador afirme que duas formas são iguais, mesmo que uma ou ambas não tenham informações completas. As funções de forma são definidas para todas as operações principais do TensorFlow e fornecem muitos exemplos de uso diferentes.
A classe InferenceContext
possui diversas funções que podem ser usadas para definir manipulações de funções de forma. Por exemplo, você pode validar se uma dimensão específica tem um valor muito específico usando InferenceContext::Dim
e InferenceContext::WithValue
; você pode especificar que uma dimensão de saída é a soma/produto de duas dimensões de entrada usando InferenceContext::Add
e InferenceContext::Multiply
. Consulte a classe InferenceContext
para todas as diversas manipulações de formas que você pode especificar. O exemplo a seguir define o formato da primeira saída como (n, 3), onde a primeira entrada tem formato (n, ...)
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->Matrix(c->Dim(c->input(0), 0), 3));
return Status::OK();
});
Se você tiver uma função de forma complicada, considere adicionar um teste para validar se várias combinações de formas de entrada produzem as combinações de formas de saída esperadas. Você pode ver exemplos de como escrever esses testes em alguns de nossos principais testes de operações . (A sintaxe de INFER_OK
e INFER_ERROR
é um pouco enigmática, mas tente ser compacto ao representar as especificações de formato de entrada e saída em testes. Por enquanto, consulte os comentários ao redor desses testes para ter uma noção da especificação da string de formato).
Crie um pacote pip para sua operação personalizada
Para construir um pacote pip
para sua operação, consulte o exemplo tensorflow/custom-op . Este guia mostra como criar operações personalizadas a partir do pacote pip do TensorFlow em vez de criar o TensorFlow a partir da origem.