أداء أفضل مع وظيفة tf

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثبتحميل دفتر

في TensorFlow 2 ، يتم تشغيل التنفيذ الحثيث افتراضيًا. واجهة المستخدم سهلة الاستخدام ومرنة (تشغيل العمليات لمرة واحدة أسهل بكثير وأسرع) ، ولكن هذا يمكن أن يأتي على حساب الأداء وقابلية النشر.

يمكنك استخدام tf.function لعمل رسوم بيانية من برامجك. إنها أداة تحويل تقوم بإنشاء رسوم بيانية لتدفق البيانات مستقلة عن Python من كود Python الخاص بك. سيساعدك هذا في إنشاء نماذج عالية الأداء ومحمولة ، ويلزم استخدام SavedModel .

سيساعدك هذا الدليل على تصور كيفية عمل tf.function تحت الغطاء ، حتى تتمكن من استخدامها بفعالية.

النقاط الرئيسية والتوصيات هي:

  • تصحيح الأخطاء في الوضع الحثيث ، ثم تزيينها باستخدام @tf.function .
  • لا تعتمد على تأثيرات Python الجانبية مثل تحور الكائن أو إلحاق القائمة.
  • تعمل tf.function بشكل أفضل مع عمليات TensorFlow ؛ يتم تحويل استدعاءات NumPy و Python إلى ثوابت.

يثبت

import tensorflow as tf

حدد وظيفة مساعدة لتوضيح أنواع الأخطاء التي قد تواجهها:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

الأساسيات

إستعمال

Function التي تحددها (على سبيل المثال من خلال تطبيق decorator @tf.function ) تشبه تمامًا عملية TensorFlow الأساسية: يمكنك تنفيذها بفارغ الصبر ؛ يمكنك حساب التدرجات. وما إلى ذلك وهلم جرا.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

يمكنك استخدام Function داخل Function أخرى.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

يمكن أن تكون Function أسرع من التعليمات البرمجية المتحمسة ، خاصة بالنسبة للرسوم البيانية التي تحتوي على العديد من العمليات الصغيرة. ولكن بالنسبة للرسوم البيانية التي تحتوي على عدد قليل من العمليات باهظة الثمن (مثل التلافيف) ، فقد لا ترى الكثير من التسريع.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions

اقتفاء أثر

يشرح هذا القسم كيفية Function تحت الغطاء ، بما في ذلك تفاصيل التنفيذ التي قد تتغير في المستقبل . ومع ذلك ، بمجرد فهمك لماذا ومتى يحدث التتبع ، يصبح استخدام tf.function بشكل فعال أسهل بكثير!

ما هو "التعقب"؟

تقوم إحدى Function بتشغيل برنامجك في رسم بياني TensorFlow . ومع ذلك ، لا يمكن tf.Graph أن يمثل كل الأشياء التي تكتبها في برنامج TensorFlow الشغوف. على سبيل المثال ، تدعم Python تعدد الأشكال ، لكن tf.Graph تتطلب مدخلاتها أن يكون لها نوع وأبعاد بيانات محددة. أو يمكنك أداء مهام جانبية مثل قراءة وسيطات سطر الأوامر ، أو إثارة خطأ ، أو العمل باستخدام كائن Python أكثر تعقيدًا ؛ لا يمكن تشغيل أي من هذه الأشياء في رسم tf.Graph .

تعمل Function على سد هذه الفجوة عن طريق فصل الكود الخاص بك على مرحلتين:

1) في المرحلة الأولى ، يشار إليها باسم " التتبع " ، تقوم Function بإنشاء tf.Graph جديد. يعمل كود Python بشكل طبيعي ، ولكن يتم تأجيل جميع عمليات TensorFlow (مثل إضافة اثنين من Tensors ): يتم التقاطها بواسطة tf.Graph ولا يتم تشغيلها.

2) في المرحلة الثانية ، يتم تشغيل رسم tf.Graph يحتوي على كل ما تم تأجيله في المرحلة الأولى. هذه المرحلة أسرع بكثير من مرحلة التعقب.

اعتمادًا على مدخلاتها ، لن تعمل Function دائمًا في المرحلة الأولى عندما يتم استدعاؤها. راجع "قواعد التتبع" أدناه للحصول على فكرة أفضل عن كيفية اتخاذ هذا القرار. تخطي المرحلة الأولى وتنفيذ المرحلة الثانية فقط هو ما يمنحك أداء TensorFlow العالي.

عندما تقرر Function التتبع ، تتبع مرحلة التتبع مباشرة المرحلة الثانية ، لذا فإن استدعاء Function يؤدي إلى إنشاء tf.Graph . سترى لاحقًا كيف يمكنك تشغيل مرحلة التتبع فقط باستخدام get_concrete_function .

عند تمرير وسيطات من أنواع مختلفة إلى Function ، يتم تشغيل كلا المرحلتين:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

لاحظ أنه إذا قمت باستدعاء Function بشكل متكرر بنفس نوع الوسيطة ، فسيتخطى TensorFlow مرحلة التتبع ويعيد استخدام الرسم البياني الذي تم تتبعه مسبقًا ، حيث سيكون الرسم البياني الذي تم إنشاؤه متطابقًا.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

يمكنك استخدام pretty_printed_concrete_signatures() لمشاهدة جميع الآثار المتاحة:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

حتى الآن ، رأيت أن tf.function تُنشئ طبقة إرسال ديناميكية مخزنة مؤقتًا عبر منطق تتبع الرسم البياني في TensorFlow. لتكون أكثر تحديدًا حول المصطلحات:

  • tf.Graph هو التمثيل الخام غير الملائم للغة والمحمول لحساب TensorFlow.
  • تقوم وظيفة ConcreteFunction بلف رسم tf.Graph .
  • تقوم إحدى Function بإدارة ذاكرة التخزين المؤقت للوظائف ConcreteFunction وتختار العنصر المناسب لإدخالاتك.
  • تلف دالة tf.function دالة بايثون ، وتعيد كائن Function .
  • ينشئ التتبع رسمًا tf.Graph . ويلفه في دالة ConcreteFunction ، تُعرف أيضًا باسم التتبع.

قواعد التعقب

تحدد Function ما إذا كان سيتم إعادة استخدام دالة ConcreteFunction التي تم تتبعها عن طريق حساب مفتاح ذاكرة التخزين المؤقت من args و kwargs للإدخال. مفتاح ذاكرة التخزين المؤقت هو مفتاح يحدد وظيفة ConcreteFunction بناءً على args and kwargs لاستدعاء Function ، وفقًا للقواعد التالية (والتي قد تتغير):

  • المفتاح الذي تم إنشاؤه لـ tf.Tensor هو شكله ونوعه.
  • المفتاح الذي تم إنشاؤه tf.Variable هو معرف متغير فريد.
  • المفتاح الذي يتم إنشاؤه لبايثون بدائي (مثل int ، float ، str ) هو قيمته.
  • المفتاح الذي تم إنشاؤه من أجل dict s و list s و tuple s و namedtuple s و attr s هو المجموعة المسطحة لمفاتيح الأوراق (انظر nest.flatten ). (نتيجة لهذا التسطيح ، فإن استدعاء دالة محددة بهيكل متداخل مختلف عن تلك المستخدمة أثناء التتبع سيؤدي إلى خطأ في النوع).
  • بالنسبة لجميع أنواع Python الأخرى ، يكون المفتاح فريدًا للكائن. بهذه الطريقة يتم تتبع وظيفة أو طريقة بشكل مستقل لكل حالة يتم استدعاؤها معها.

السيطرة على الاسترداد

يساعد الاسترداد ، وهو عندما تقوم وظيفتك بإنشاء أكثر من تتبع ، على ضمان أن Function يقوم بإنشاء الرسوم البيانية الصحيحة لكل مجموعة من المدخلات. ومع ذلك ، فإن البحث عن المفقودين عملية مكلفة! إذا أعادت وظيفتك تتبع رسم بياني جديد لكل مكالمة ، فستجد أن شفرتك يتم تنفيذها بشكل أبطأ مما لو لم تستخدم Function tf.function .

للتحكم في سلوك التتبع ، يمكنك استخدام الأساليب التالية:

  • حدد input_signature في tf.function للحد من التتبع.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
  • حدد بُعدًا [بلا] في tf.TensorSpec للسماح بالمرونة في إعادة استخدام التتبع.

    نظرًا لأن TensorFlow يطابق الموترات بناءً على شكلها ، فإن استخدام بُعد None كحرف بدل سيسمح لـ Function s بإعادة استخدام الآثار لمدخلات متغيرة الحجم. يمكن أن يحدث إدخال متغير الحجم إذا كان لديك تسلسلات ذات أطوال مختلفة ، أو صور بأحجام مختلفة لكل دفعة (انظر دروس Transformer و Deep Dream على سبيل المثال).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • حجج بايثون إلى Tensors لتقليل الاسترداد.

    في كثير من الأحيان ، يتم استخدام وسيطات Python للتحكم في المعلمات الفائقة وإنشاءات الرسم البياني - على سبيل المثال ، num_layers=10 أو training=True أو nonlinearity='relu' . لذا ، إذا تغيرت حجة بايثون ، فمن المنطقي أن تضطر إلى إعادة تتبع الرسم البياني.

    ومع ذلك ، من الممكن ألا يتم استخدام وسيطة بايثون للتحكم في إنشاء الرسم البياني. في هذه الحالات ، يمكن للتغيير في قيمة Python أن يؤدي إلى ارتداد غير ضروري. خذ ، على سبيل المثال ، حلقة التدريب هذه ، والتي سيقوم AutoGraph بفكها ديناميكيًا. على الرغم من الآثار المتعددة ، فإن الرسم البياني الذي تم إنشاؤه هو في الواقع متطابق ، لذا فإن التصحيح غير ضروري.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

إذا كنت بحاجة إلى فرض الاسترداد ، فأنشئ Function جديدة. كائنات Function المنفصلة مضمونة لعدم مشاركة الآثار.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

الحصول على وظائف محددة

في كل مرة يتم فيها تتبع دالة ، يتم إنشاء وظيفة ملموسة جديدة. يمكنك الحصول على دالة ملموسة مباشرة باستخدام get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

تعرض طباعة ConcreteFunction ملخصًا لوسائط الإدخال (مع الأنواع) ونوع الإخراج الخاص بها.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

يمكنك أيضًا استرداد توقيع الوظيفة الملموسة مباشرةً.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

سيؤدي استخدام تتبع ملموس مع أنواع غير متوافقة إلى حدوث خطأ

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

قد تلاحظ أن حجج بايثون يتم التعامل معها بشكل خاص في توقيع إدخال الوظيفة الملموسة. قبل TensorFlow 2.3 ، تمت إزالة حجج Python ببساطة من توقيع الوظيفة الملموسة. بدءًا من TensorFlow 2.3 ، تظل وسيطات Python في التوقيع ، ولكنها مقيدة بأخذ القيمة المحددة أثناء التتبع.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
l10n-placeholder37 l10n-placeholder38l10n-placeholder35 l10n-placeholder36
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

الحصول على الرسوم البيانية

كل وظيفة ملموسة عبارة عن غلاف قابل للاستدعاء حول رسم tf.Graph . على الرغم من أن استرداد كائن tf.Graph الفعلي ليس شيئًا ستحتاج إلى القيام به عادةً ، إلا أنه يمكنك الحصول عليه بسهولة من أي وظيفة ملموسة.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

تصحيح

بشكل عام ، يعد تصحيح الأخطاء في وضع التشغيل أسهل من داخل tf.function . يجب عليك التأكد من أن التعليمات البرمجية الخاصة بك تنفذ خالية من الأخطاء في الوضع الحثيث قبل التزيين tf.function . للمساعدة في عملية التصحيح ، يمكنك استدعاء tf.config.run_functions_eagerly(True) لتعطيل tf.function وإعادة تمكينها بشكل عام.

عند تعقب المشكلات التي تظهر فقط داخل tf.function ، إليك بعض النصائح:

  • لا يتم تنفيذ مكالمات print Python القديمة البسيطة إلا أثناء التتبع ، مما يساعدك على تعقب (إعادة) تتبع وظيفتك.
  • سيتم تنفيذ مكالمات tf.print في كل مرة ، ويمكن أن تساعدك على تعقب القيم الوسيطة أثناء التنفيذ.
  • tf.debugging.enable_check_numerics هي طريقة سهلة لتعقب مكان إنشاء NaNs و Inf.
  • يمكن أن يساعدك pdb ( مصحح أخطاء Python ) في فهم ما يحدث أثناء التتبع. (تحذير: سوف ينقلك pdb إلى كود المصدر المحول AutoGraph.)

تحويلات AutoGraph

AutoGraph هي مكتبة تعمل بشكل افتراضي في tf.function ، وتحول مجموعة فرعية من كود Python المتهور إلى عمليات TensorFlow المتوافقة مع الرسم البياني. يتضمن ذلك التحكم في التدفق مثل if ، for ، while .

تستمر عمليات TensorFlow مثل tf.cond و tf. tf.while_loop في العمل ، ولكن غالبًا ما يكون التحكم في التدفق أسهل في الكتابة والفهم عند كتابته بلغة Python.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
      dtype=float32)>

إذا كنت مهتمًا ، يمكنك فحص رمز التوقيع الذي يولده.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

الشرطية

سوف يقوم AutoGraph بتحويل بعض عبارات if <condition> إلى مكالمات tf.cond المكافئة. يتم إجراء هذا الاستبدال إذا كان <condition> هو موتر. وبخلاف ذلك ، يتم تنفيذ عبارة if كشرط لبايثون.

يتم تنفيذ شرط Python أثناء التتبع ، لذلك سيتم إضافة فرع واحد بالضبط من الشرط إلى الرسم البياني. بدون AutoGraph ، لن يتمكن هذا الرسم البياني المتعقب من أخذ الفرع البديل إذا كان هناك تدفق تحكم يعتمد على البيانات.

يتتبع tf.cond ويضيف فرعي الشرط إلى الرسم البياني ، مع تحديد فرع ديناميكيًا في وقت التنفيذ. يمكن أن يكون للبحث عن المفقودين آثار جانبية غير مقصودة ؛ تحقق من تأثيرات تتبع الرسم البياني التلقائي لمزيد من المعلومات.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

راجع الوثائق المرجعية للحصول على قيود إضافية على عبارات if المحولة تلقائيًا.

الحلقات

سيقوم AutoGraph بتحويل بعض عبارات for and while إلى عمليات تكرار TensorFlow المكافئة ، مثل tf. tf.while_loop . إذا لم يتم تحويلها ، يتم تنفيذ حلقة for أو while كحلقة Python.

يتم إجراء هذا الاستبدال في الحالات التالية:

  • for x in y : إذا كانت y عبارة عن موتر ، فحول إلى tf.while_loop . في الحالة الخاصة حيث y هي tf.data.Dataset ، يتم إنشاء مجموعة من tf.data.Dataset ops.
  • while <condition> : إذا كان <condition> هو Tensor ، فحول إلى tf. tf.while_loop .

يتم تنفيذ حلقة Python أثناء التتبع ، مما يضيف عمليات إضافية إلى tf.Graph لكل تكرار للحلقة.

تتعقب حلقة TensorFlow جسم الحلقة ، وتختار ديناميكيًا عدد التكرارات التي سيتم تشغيلها في وقت التنفيذ. يظهر جسم الحلقة مرة واحدة فقط في الرسم البياني الذي tf.Graph .

راجع الوثائق المرجعية للحصول على قيود إضافية على كشوفات الرسم البياني AutoGraph المحولة for and while .

التكرار على بيانات بايثون

تتمثل إحدى المشاكل الشائعة في إجراء حلقة حول بيانات Python / tf.function داخل دالة tf. سيتم تنفيذ هذه الحلقة أثناء عملية التتبع ، بإضافة نسخة من النموذج الخاص بك إلى tf.Graph لكل تكرار للحلقة.

إذا كنت ترغب في لف حلقة التدريب بأكملها في tf.function ، فإن الطريقة الأكثر أمانًا للقيام بذلك هي لف بياناتك على هيئة tf.data.Dataset بحيث يقوم AutoGraph بفتح حلقة التدريب ديناميكيًا.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

عند تغليف بيانات Python / NumPy في مجموعة بيانات ، انتبه إلى tf.data.Dataset.from_generator مقابل tf.data.Dataset.from_tensors . سيحتفظ الأول بالبيانات في Python ويجلبها عبر tf.py_function والتي يمكن أن يكون لها آثار على الأداء ، في حين أن الأخيرة ستجمع نسخة من البيانات tf.constant() واحدة كبيرة في الرسم البياني ، والتي يمكن أن يكون لها آثار على الذاكرة.

تعد قراءة البيانات من الملفات عبر TFRecordDataset و CsvDataset وما إلى ذلك الطريقة الأكثر فاعلية لاستهلاك البيانات ، حيث يمكن لـ TensorFlow نفسها إدارة التحميل غير المتزامن والجلب المسبق للبيانات ، دون الحاجة إلى إشراك Python. لمعرفة المزيد ، راجع tf.data : إنشاء دليل خطوط أنابيب الإدخال TensorFlow .

تراكم القيم في حلقة

النمط الشائع هو تجميع القيم الوسيطة من حلقة. عادة ، يتم تحقيق ذلك عن طريق إلحاق قائمة Python أو إضافة إدخالات إلى قاموس Python. ومع ذلك ، نظرًا لأن هذه آثار جانبية لـ Python ، فلن تعمل كما هو متوقع في حلقة غير متحكم فيها ديناميكيًا. استخدم tf.TensorArray لتجميع النتائج من حلقة غير مرتبطة ديناميكيًا.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
        [0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
        [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

       [[0.04946005, 0.69127274, 0.56848884, 0.22406638],
        [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
        [0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

محددات

تحتوي Function TensorFlow على بعض القيود حسب التصميم التي يجب أن تكون على دراية بها عند تحويل دالة Python إلى Function .

تنفيذ الآثار الجانبية للبايثون

يمكن أن تتصرف التأثيرات الجانبية ، مثل الطباعة ، والإلحاق بالقوائم ، وتحوير الكرة الأرضية ، بشكل غير متوقع داخل إحدى Function ، وأحيانًا يتم تنفيذها مرتين أو لا تنفذ كلها. تحدث فقط في المرة الأولى التي تقوم فيها باستدعاء Function مع مجموعة من المدخلات. بعد ذلك ، يتم إعادة تنفيذ tf.Graph ، دون تنفيذ كود بايثون.

القاعدة العامة هي تجنب الاعتماد على تأثيرات Python الجانبية في منطقك واستخدامهم فقط لتصحيح آثارك. بخلاف ذلك ، تعد واجهات برمجة تطبيقات TensorFlow مثل tf.data و tf.print و tf.summary و tf.Variable.assign و tf.TensorArray هي أفضل طريقة لضمان تنفيذ التعليمات البرمجية الخاصة بك من خلال وقت تشغيل TensorFlow مع كل مكالمة.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

إذا كنت ترغب في تنفيذ كود Python أثناء كل استدعاء Function ، tf.py_function هو فتحة خروج. عيب tf.py_function هو أنها ليست محمولة أو ذات أداء خاص ، ولا يمكن حفظها مع SavedModel ، ولا تعمل بشكل جيد في الإعدادات الموزعة (متعددة GPU ، TPU). أيضًا ، نظرًا لأنه يجب tf.py_function في الرسم البياني ، فإنه يلقي جميع المدخلات / المخرجات إلى الموترات.

تغيير متغيرات بايثون العالمية والحرة

يعد تغيير متغيرات Python العامة والحرة أحد الآثار الجانبية لـ Python ، لذلك يحدث فقط أثناء التتبع.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

في بعض الأحيان يصعب ملاحظة السلوكيات غير المتوقعة. في المثال أدناه ، يهدف counter إلى حماية زيادة المتغير. ومع ذلك ، نظرًا لأنه عدد صحيح من نوع python وليس كائن TensorFlow ، يتم التقاط قيمته أثناء التتبع الأول. عند استخدام دالة tf ، سيتم تسجيل assign_add tf.function قيد أو شرط في الرسم البياني الأساسي. لذلك ستزيد v بمقدار 1 ، في كل مرة يتم استدعاء tf.function . هذه المشكلة شائعة بين المستخدمين الذين يحاولون ترحيل رمز Tensorflow في وضع Grpah إلى Tensorflow 2 باستخدام tf.function ، عند استخدام التأثيرات الجانبية للبيثون ( counter في المثال) لتحديد العمليات التي سيتم تشغيلها ( assign_add في المثال ). عادة ، لا يدرك المستخدمون ذلك إلا بعد رؤية نتائج رقمية مشبوهة ، أو أداء أقل بكثير من المتوقع (على سبيل المثال ، إذا كانت العملية الخاضعة للحراسة مكلفة للغاية).

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

الحل البديل لتحقيق السلوك المتوقع هو استخدام tf.init_scope لرفع العمليات خارج الرسم البياني للدالة. هذا يضمن أن الزيادة المتغيرة تتم مرة واحدة فقط أثناء وقت التتبع. وتجدر الإشارة إلى أن init_scope لها آثار جانبية أخرى بما في ذلك تدفق التحكم الواضح وشريط التدرج. في بعض الأحيان ، قد يصبح استخدام init_scope معقدًا جدًا بحيث لا يمكن إدارته بواقعية.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

باختصار ، كقاعدة عامة ، يجب أن تتجنب تحوير كائنات الثعبان مثل الأعداد الصحيحة أو الحاويات مثل القوائم التي تعيش خارج Function . بدلاً من ذلك ، استخدم الحجج وكائنات TF. على سبيل المثال ، يحتوي قسم "تجميع القيم في حلقة" على مثال واحد لكيفية تنفيذ عمليات تشبه القائمة.

يمكنك ، في بعض الحالات ، التقاط الحالة ومعالجتها إذا كانت tf.Variable . هذه هي الطريقة التي يتم بها تحديث أوزان نماذج Keras باستدعاءات متكررة لنفس وظيفة ConcreteFunction .

استخدام مكررات ومولدات بايثون

تعتمد العديد من ميزات Python ، مثل المولدات والمكررات ، على وقت تشغيل Python لتتبع الحالة. بشكل عام ، بينما تعمل هذه التركيبات كما هو متوقع في الوضع الحثيث ، فهي أمثلة على الآثار الجانبية لبايثون وبالتالي تحدث فقط أثناء التتبع.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

تمامًا مثل كيفية امتلاك TensorFlow لـ tf.TensorArray المتخصصة لبنى القائمة ، فإنه يحتوي على tf.data.Iterator متخصص في إنشاءات التكرار. راجع قسم تحويلات AutoGraph للحصول على نظرة عامة. أيضًا ، يمكن أن تساعد واجهة برمجة تطبيقات tf.data في تنفيذ أنماط المولد:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

يجب أن تكون جميع مخرجات دالة tf قيمًا مرجعة

باستثناء tf.Variable s ، يجب أن تعيد دالة tf جميع مخرجاتها. محاولة الوصول مباشرة إلى أي موتر من دالة دون المرور بقيم الإرجاع يؤدي إلى حدوث "تسريبات".

على سبيل المثال ، الوظيفة الموجودة أسفل "تسرب" الموتر a عبر Python global x :

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

هذا صحيح حتى إذا تم إرجاع القيمة المسربة أيضًا:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
>>>  File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main
>>>  File /usr/lib/python3.7/runpy.py, line 85, in _run_code
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once
>>>  File /usr/lib/python3.7/asyncio/events.py, line 88, in _run
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code
>>>  File /tmp/ipykernel_26244/566849597.py, line 7, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler
>>>  File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

عادةً ما تحدث مثل هذه التسريبات عند استخدام عبارات Python أو هياكل البيانات. بالإضافة إلى تسريب الموترات التي يتعذر الوصول إليها ، من المحتمل أيضًا أن تكون هذه العبارات خاطئة لأنها تعتبر آثارًا جانبية لـ Python ، ولا يتم ضمان تنفيذها عند كل استدعاء دالة.

تتضمن الطرق الشائعة لتسريب الموترات المحلية أيضًا تحوير مجموعة Python خارجية ، أو كائن:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

لا يتم دعم وظائف tf العودية

لا يتم دعم Function التكرارية ويمكن أن تسبب حلقات لا نهائية. فمثلا،

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

حتى إذا بدت Function العودية تعمل ، فسيتم تتبع وظيفة بيثون عدة مرات وقد يكون لها تأثير ضمني على الأداء. فمثلا،

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

مشاكل معروفة

إذا لم يتم تقييم Function بشكل صحيح ، فقد يتم تفسير الخطأ من خلال هذه المشكلات المعروفة التي تم التخطيط لإصلاحها في المستقبل.

اعتمادًا على متغيرات Python العامة والحرة

تقوم Function بإنشاء وظيفة ConcreteFunction جديدة عند استدعائها بقيمة جديدة من وسيطة Python. ومع ذلك ، فإنه لا يفعل ذلك لإغلاق Python أو globals أو nonlocals لتلك Function . إذا تغيرت قيمتها بين المكالمات إلى Function ، ستستمر Function في استخدام القيم التي كانت لديها عندما تم تتبعها. هذا يختلف عن كيفية عمل وظائف Python العادية.

لهذا السبب ، يجب عليك اتباع أسلوب برمجة وظيفي يستخدم الوسائط بدلاً من إغلاق الأسماء الخارجية.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

هناك طريقة أخرى لتحديث قيمة عامة ، وهي جعلها متغير tf.Variable واستخدام طريقة Variable.assign بدلاً من ذلك.

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

بالاعتماد على كائنات بايثون

التوصية لتمرير كائنات Python كوسائط إلى tf.function لديها عدد من المشكلات المعروفة ، والتي من المتوقع إصلاحها في المستقبل. بشكل عام ، يمكنك الاعتماد على التتبع المتسق إذا كنت تستخدم بنية Python بدائية أو بنية متوافقة مع tf.nest كوسيطة أو تمرير مثيل مختلف من كائن إلى Function . ومع ذلك ، لن تقوم Function بإنشاء تتبع جديد عندما تمرر نفس الكائن وتغير سماته فقط .

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

استخدام نفس Function لتقييم المثيل المحدث للنموذج سيكون عربات التي تجرها الدواب لأن النموذج المحدث له نفس مفتاح التخزين المؤقت مثل النموذج الأصلي.

لهذا السبب ، يوصى بكتابة Function لتجنب الاعتماد على سمات الكائن القابلة للتغيير أو إنشاء كائنات جديدة.

إذا لم يكن ذلك ممكنًا ، فإن أحد الحلول هو إنشاء Function جديدة في كل مرة تقوم فيها بتعديل الكائن لفرض الاستعادة:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

نظرًا لأن الاسترداد يمكن أن يكون مكلفًا ، يمكنك استخدام tf.Variable s كسمات كائن ، والتي يمكن تغييرها (ولكن لا يتم تغييرها ، بحذر!) للحصول على تأثير مماثل دون الحاجة إلى التصحيح.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

إنشاء متغيرات tf

تدعم Function فقط tf.Variable التي تم إنشاؤها مرة واحدة في المكالمة الأولى ، وإعادة استخدامها عبر استدعاءات الوظائف اللاحقة. سينشئ مقتطف الشفرة أدناه tf.Variable جديدًا في كل استدعاء دالة ، مما ينتج عنه استثناء ValueError .

مثال:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

النمط الشائع المستخدم للتغلب على هذا القيد هو البدء بقيمة Python None ، ثم إنشاء tf.Variable إذا كانت القيمة لا شيء:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

باستخدام مع محسنات Keras متعددة

قد تواجه ValueError: tf.function only supports singleton tf.Variables created on the first call. عند استخدام أكثر من مُحسِّن tf.function مع وظيفة tf. يحدث هذا الخطأ لأن المحسّنين ينشئون داخليًا tf.Variables عند تطبيق التدرجات لأول مرة.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
        self._create_all_weights(var_list)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
        _ = self.iterations
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
        aggregation=aggregation)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
        shape=variable_shape if variable_shape else None)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

إذا كنت بحاجة إلى تغيير المُحسِّن أثناء التدريب ، فإن الحل البديل هو إنشاء Function جديدة لكل مُحسِّن ، واستدعاء الوظيفة ConcreteFunction مباشرةً.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

تستخدم مع نماذج Keras متعددة

قد تواجه أيضًا ValueError: tf.function only supports singleton tf.Variables created on the first call. عند تمرير مثيلات نموذج مختلفة إلى نفس Function .

يحدث هذا الخطأ لأن نماذج Keras (التي لم يتم تحديد شكل إدخالها ) وطبقات Keras تنشئ tf.Variables s عندما يتم استدعاؤها لأول مرة. ربما تحاول تهيئة هذه المتغيرات داخل Function تم استدعاؤها بالفعل. لتجنب هذا الخطأ ، حاول استدعاء model.build(input_shape) لتهيئة كل الأوزان قبل تدريب النموذج.

قراءة متعمقة

للتعرف على كيفية تصدير Function وتحميلها ، راجع دليل SavedModel . لمعرفة المزيد حول تحسينات الرسم البياني التي يتم إجراؤها بعد التتبع ، راجع دليل Grappler . لمعرفة كيفية تحسين خط أنابيب البيانات الخاص بك وملف تعريف النموذج الخاص بك ، راجع دليل منشئ ملفات التعريف .