عرض على TensorFlow.org | تشغيل في Google Colab | عرض المصدر على جيثب | تحميل دفتر |
في TensorFlow 2 ، يتم تشغيل التنفيذ الحثيث افتراضيًا. واجهة المستخدم سهلة الاستخدام ومرنة (تشغيل العمليات لمرة واحدة أسهل بكثير وأسرع) ، ولكن هذا يمكن أن يأتي على حساب الأداء وقابلية النشر.
يمكنك استخدام tf.function
لعمل رسوم بيانية من برامجك. إنها أداة تحويل تقوم بإنشاء رسوم بيانية لتدفق البيانات مستقلة عن Python من كود Python الخاص بك. سيساعدك هذا في إنشاء نماذج عالية الأداء ومحمولة ، ويلزم استخدام SavedModel
.
سيساعدك هذا الدليل على تصور كيفية عمل tf.function
تحت الغطاء ، حتى تتمكن من استخدامها بفعالية.
النقاط الرئيسية والتوصيات هي:
- تصحيح الأخطاء في الوضع الحثيث ، ثم تزيينها باستخدام
@tf.function
. - لا تعتمد على تأثيرات Python الجانبية مثل تحور الكائن أو إلحاق القائمة.
- تعمل
tf.function
بشكل أفضل مع عمليات TensorFlow ؛ يتم تحويل استدعاءات NumPy و Python إلى ثوابت.
يثبت
import tensorflow as tf
حدد وظيفة مساعدة لتوضيح أنواع الأخطاء التي قد تواجهها:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
try:
yield
except error_class as e:
print('Caught expected exception \n {}:'.format(error_class))
traceback.print_exc(limit=2)
except Exception as e:
raise e
else:
raise Exception('Expected {} to be raised but no error was raised!'.format(
error_class))
الأساسيات
إستعمال
Function
التي تحددها (على سبيل المثال من خلال تطبيق decorator @tf.function
) تشبه تمامًا عملية TensorFlow الأساسية: يمكنك تنفيذها بفارغ الصبر ؛ يمكنك حساب التدرجات. وما إلى ذلك وهلم جرا.
@tf.function # The decorator converts `add` into a `Function`.
def add(a, b):
return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy= array([[2., 2.], [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
يمكنك استخدام Function
داخل Function
أخرى.
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy= array([[3., 3.], [3., 3.], [3., 3.]], dtype=float32)>
يمكن أن تكون Function
أسرع من التعليمات البرمجية المتحمسة ، خاصة بالنسبة للرسوم البيانية التي تحتوي على العديد من العمليات الصغيرة. ولكن بالنسبة للرسوم البيانية التي تحتوي على عدد قليل من العمليات باهظة الثمن (مثل التلافيف) ، فقد لا ترى الكثير من التسريع.
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735 Function conv: 0.005791576000774512 Note how there's not much difference in performance for convolutions
اقتفاء أثر
يشرح هذا القسم كيفية Function
تحت الغطاء ، بما في ذلك تفاصيل التنفيذ التي قد تتغير في المستقبل . ومع ذلك ، بمجرد فهمك لماذا ومتى يحدث التتبع ، يصبح استخدام tf.function
بشكل فعال أسهل بكثير!
ما هو "التعقب"؟
تقوم إحدى Function
بتشغيل برنامجك في رسم بياني TensorFlow . ومع ذلك ، لا يمكن tf.Graph
أن يمثل كل الأشياء التي تكتبها في برنامج TensorFlow الشغوف. على سبيل المثال ، تدعم Python تعدد الأشكال ، لكن tf.Graph
تتطلب مدخلاتها أن يكون لها نوع وأبعاد بيانات محددة. أو يمكنك أداء مهام جانبية مثل قراءة وسيطات سطر الأوامر ، أو إثارة خطأ ، أو العمل باستخدام كائن Python أكثر تعقيدًا ؛ لا يمكن تشغيل أي من هذه الأشياء في رسم tf.Graph
.
تعمل Function
على سد هذه الفجوة عن طريق فصل الكود الخاص بك على مرحلتين:
1) في المرحلة الأولى ، يشار إليها باسم " التتبع " ، تقوم Function
بإنشاء tf.Graph
جديد. يعمل كود Python بشكل طبيعي ، ولكن يتم تأجيل جميع عمليات TensorFlow (مثل إضافة اثنين من Tensors ): يتم التقاطها بواسطة tf.Graph
ولا يتم تشغيلها.
2) في المرحلة الثانية ، يتم تشغيل رسم tf.Graph
يحتوي على كل ما تم تأجيله في المرحلة الأولى. هذه المرحلة أسرع بكثير من مرحلة التعقب.
اعتمادًا على مدخلاتها ، لن تعمل Function
دائمًا في المرحلة الأولى عندما يتم استدعاؤها. راجع "قواعد التتبع" أدناه للحصول على فكرة أفضل عن كيفية اتخاذ هذا القرار. تخطي المرحلة الأولى وتنفيذ المرحلة الثانية فقط هو ما يمنحك أداء TensorFlow العالي.
عندما تقرر Function
التتبع ، تتبع مرحلة التتبع مباشرة المرحلة الثانية ، لذا فإن استدعاء Function
يؤدي إلى إنشاء tf.Graph
. سترى لاحقًا كيف يمكنك تشغيل مرحلة التتبع فقط باستخدام get_concrete_function
.
عند تمرير وسيطات من أنواع مختلفة إلى Function
، يتم تشغيل كلا المرحلتين:
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32) Tracing with Tensor("a:0", shape=(), dtype=float32) tf.Tensor(2.2, shape=(), dtype=float32) Tracing with Tensor("a:0", shape=(), dtype=string) tf.Tensor(b'aa', shape=(), dtype=string)
لاحظ أنه إذا قمت باستدعاء Function
بشكل متكرر بنفس نوع الوسيطة ، فسيتخطى TensorFlow مرحلة التتبع ويعيد استخدام الرسم البياني الذي تم تتبعه مسبقًا ، حيث سيكون الرسم البياني الذي تم إنشاؤه متطابقًا.
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
يمكنك استخدام pretty_printed_concrete_signatures()
لمشاهدة جميع الآثار المتاحة:
print(double.pretty_printed_concrete_signatures())
double(a) Args: a: int32 Tensor, shape=() Returns: int32 Tensor, shape=() double(a) Args: a: float32 Tensor, shape=() Returns: float32 Tensor, shape=() double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
حتى الآن ، رأيت أن tf.function
تُنشئ طبقة إرسال ديناميكية مخزنة مؤقتًا عبر منطق تتبع الرسم البياني في TensorFlow. لتكون أكثر تحديدًا حول المصطلحات:
-
tf.Graph
هو التمثيل الخام غير الملائم للغة والمحمول لحساب TensorFlow. - تقوم وظيفة
ConcreteFunction
بلف رسمtf.Graph
. - تقوم إحدى
Function
بإدارة ذاكرة التخزين المؤقت للوظائفConcreteFunction
وتختار العنصر المناسب لإدخالاتك. - تلف دالة
tf.function
دالة بايثون ، وتعيد كائنFunction
. - ينشئ التتبع رسمًا
tf.Graph
. ويلفه في دالةConcreteFunction
، تُعرف أيضًا باسم التتبع.
قواعد التعقب
تحدد Function
ما إذا كان سيتم إعادة استخدام دالة ConcreteFunction
التي تم تتبعها عن طريق حساب مفتاح ذاكرة التخزين المؤقت من args و kwargs للإدخال. مفتاح ذاكرة التخزين المؤقت هو مفتاح يحدد وظيفة ConcreteFunction
بناءً على args and kwargs لاستدعاء Function
، وفقًا للقواعد التالية (والتي قد تتغير):
- المفتاح الذي تم إنشاؤه لـ
tf.Tensor
هو شكله ونوعه. - المفتاح الذي تم إنشاؤه
tf.Variable
هو معرف متغير فريد. - المفتاح الذي يتم إنشاؤه لبايثون بدائي (مثل
int
،float
،str
) هو قيمته. - المفتاح الذي تم إنشاؤه من أجل
dict
s وlist
s وtuple
s وnamedtuple
s وattr
s هو المجموعة المسطحة لمفاتيح الأوراق (انظرnest.flatten
). (نتيجة لهذا التسطيح ، فإن استدعاء دالة محددة بهيكل متداخل مختلف عن تلك المستخدمة أثناء التتبع سيؤدي إلى خطأ في النوع). - بالنسبة لجميع أنواع Python الأخرى ، يكون المفتاح فريدًا للكائن. بهذه الطريقة يتم تتبع وظيفة أو طريقة بشكل مستقل لكل حالة يتم استدعاؤها معها.
السيطرة على الاسترداد
يساعد الاسترداد ، وهو عندما تقوم وظيفتك بإنشاء أكثر من تتبع ، على ضمان أن Function
يقوم بإنشاء الرسوم البيانية الصحيحة لكل مجموعة من المدخلات. ومع ذلك ، فإن البحث عن المفقودين عملية مكلفة! إذا أعادت وظيفتك تتبع رسم بياني جديد لكل مكالمة ، فستجد أن شفرتك يتم تنفيذها بشكل أبطأ مما لو لم تستخدم Function
tf.function
.
للتحكم في سلوك التتبع ، يمكنك استخدام الأساليب التالية:
- حدد
input_signature
فيtf.function
للحد من التتبع.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([4 1], shape=(2,), dtype=int32) Caught expected exception <class 'ValueError'>: Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module> next_collatz(tf.constant([[1, 2], [3, 4]])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor( [[1 2] [3 4]], shape=(2, 2), dtype=int32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)). Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module> next_collatz(tf.constant([1.0, 2.0])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor([1. 2.], shape=(2,), dtype=float32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
حدد بُعدًا [بلا] في
tf.TensorSpec
للسماح بالمرونة في إعادة استخدام التتبع.نظرًا لأن TensorFlow يطابق الموترات بناءً على شكلها ، فإن استخدام بُعد
None
كحرف بدل سيسمح لـFunction
s بإعادة استخدام الآثار لمدخلات متغيرة الحجم. يمكن أن يحدث إدخال متغير الحجم إذا كان لديك تسلسلات ذات أطوال مختلفة ، أو صور بأحجام مختلفة لكل دفعة (انظر دروس Transformer و Deep Dream على سبيل المثال).
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
print('Tracing with', x)
return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
حجج بايثون إلى Tensors لتقليل الاسترداد.
في كثير من الأحيان ، يتم استخدام وسيطات Python للتحكم في المعلمات الفائقة وإنشاءات الرسم البياني - على سبيل المثال ،
num_layers=10
أوtraining=True
أوnonlinearity='relu'
. لذا ، إذا تغيرت حجة بايثون ، فمن المنطقي أن تضطر إلى إعادة تتبع الرسم البياني.ومع ذلك ، من الممكن ألا يتم استخدام وسيطة بايثون للتحكم في إنشاء الرسم البياني. في هذه الحالات ، يمكن للتغيير في قيمة Python أن يؤدي إلى ارتداد غير ضروري. خذ ، على سبيل المثال ، حلقة التدريب هذه ، والتي سيقوم AutoGraph بفكها ديناميكيًا. على الرغم من الآثار المتعددة ، فإن الرسم البياني الذي تم إنشاؤه هو في الواقع متطابق ، لذا فإن التصحيح غير ضروري.
def train_one_step():
pass
@tf.function
def train(num_steps):
print("Tracing with num_steps = ", num_steps)
tf.print("Executing with num_steps = ", num_steps)
for _ in tf.range(num_steps):
train_one_step()
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments. Tracing with num_steps = 10 Executing with num_steps = 10 Tracing with num_steps = 20 Executing with num_steps = 20 Traces are reused for Tensor arguments. Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32) Executing with num_steps = 10 Executing with num_steps = 20
إذا كنت بحاجة إلى فرض الاسترداد ، فأنشئ Function
جديدة. كائنات Function
المنفصلة مضمونة لعدم مشاركة الآثار.
def f():
print('Tracing!')
tf.print('Executing')
tf.function(f)()
tf.function(f)()
Tracing! Executing Tracing! Executing
الحصول على وظائف محددة
في كل مرة يتم فيها تتبع دالة ، يتم إنشاء وظيفة ملموسة جديدة. يمكنك الحصول على دالة ملموسة مباشرة باستخدام get_concrete_function
.
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)
تعرض طباعة ConcreteFunction
ملخصًا لوسائط الإدخال (مع الأنواع) ونوع الإخراج الخاص بها.
print(double_strings)
ConcreteFunction double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
يمكنك أيضًا استرداد توقيع الوظيفة الملموسة مباشرةً.
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {}) Tensor("Identity:0", shape=(), dtype=string)
سيؤدي استخدام تتبع ملموس مع أنواع غير متوافقة إلى حدوث خطأ
with assert_raises(tf.errors.InvalidArgumentError):
double_strings(tf.constant(1))
Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module> double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]
قد تلاحظ أن حجج بايثون يتم التعامل معها بشكل خاص في توقيع إدخال الوظيفة الملموسة. قبل TensorFlow 2.3 ، تمت إزالة حجج Python ببساطة من توقيع الوظيفة الملموسة. بدءًا من TensorFlow 2.3 ، تظل وسيطات Python في التوقيع ، ولكنها مقيدة بأخذ القيمة المحددة أثناء التتبع.
@tf.function
def pow(a, b):
return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2) Args: a: float32 Tensor, shape=<unknown> Returns: float32 Tensor, shape=<unknown>l10n-placeholder37 l10n-placeholder38l10n-placeholder35 l10n-placeholder36
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
square(tf.constant(10.0), b=3)
Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl cancellation_manager) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature raise TypeError(f"{self._flat_signature_summary()} got unexpected " TypeError: pow(a) got unexpected keyword arguments: b. During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module> square(tf.constant(10.0), b=3) TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.
الحصول على الرسوم البيانية
كل وظيفة ملموسة عبارة عن غلاف قابل للاستدعاء حول رسم tf.Graph
. على الرغم من أن استرداد كائن tf.Graph
الفعلي ليس شيئًا ستحتاج إلى القيام به عادةً ، إلا أنه يمكنك الحصول عليه بسهولة من أي وظيفة ملموسة.
graph = double_strings.graph
for node in graph.as_graph_def().node:
print(f'{node.input} -> {node.name}')
[] -> a ['a', 'a'] -> add ['add'] -> Identity
تصحيح
بشكل عام ، يعد تصحيح الأخطاء في وضع التشغيل أسهل من داخل tf.function
. يجب عليك التأكد من أن التعليمات البرمجية الخاصة بك تنفذ خالية من الأخطاء في الوضع الحثيث قبل التزيين tf.function
. للمساعدة في عملية التصحيح ، يمكنك استدعاء tf.config.run_functions_eagerly(True)
لتعطيل tf.function
وإعادة تمكينها بشكل عام.
عند تعقب المشكلات التي تظهر فقط داخل tf.function
، إليك بعض النصائح:
- لا يتم تنفيذ مكالمات
print
Python القديمة البسيطة إلا أثناء التتبع ، مما يساعدك على تعقب (إعادة) تتبع وظيفتك. - سيتم تنفيذ مكالمات
tf.print
في كل مرة ، ويمكن أن تساعدك على تعقب القيم الوسيطة أثناء التنفيذ. -
tf.debugging.enable_check_numerics
هي طريقة سهلة لتعقب مكان إنشاء NaNs و Inf. - يمكن أن يساعدك
pdb
( مصحح أخطاء Python ) في فهم ما يحدث أثناء التتبع. (تحذير: سوف ينقلكpdb
إلى كود المصدر المحول AutoGraph.)
تحويلات AutoGraph
AutoGraph هي مكتبة تعمل بشكل افتراضي في tf.function
، وتحول مجموعة فرعية من كود Python المتهور إلى عمليات TensorFlow المتوافقة مع الرسم البياني. يتضمن ذلك التحكم في التدفق مثل if
، for
، while
.
تستمر عمليات TensorFlow مثل tf.cond
و tf. tf.while_loop
في العمل ، ولكن غالبًا ما يكون التحكم في التدفق أسهل في الكتابة والفهم عند كتابته بلغة Python.
# A simple loop
@tf.function
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753] [0.582645297 0.613145649 0.619306684 0.319202513 0.182036072] [0.524585426 0.546337605 0.550645113 0.308785647 0.18005164] [0.481231302 0.497770309 0.501003504 0.299331933 0.178130865] [0.447229207 0.460361809 0.462906033 0.290701121 0.176270396] [0.419618756 0.430379033 0.432449728 0.282779962 0.174467146] [0.396609187 0.405638 0.407366514 0.275476 0.172718227] [0.377043903 0.384762734 0.386234313 0.268712848 0.17102097] [0.360137492 0.366836458 0.368109286 0.262426734 0.169372901] [0.345335096 0.351221472 0.352336824 0.256563932 0.167771652] [0.332231969 0.337458342 0.338446289 0.251078814 0.166215062] [0.320524871 0.325206399 0.326089561 0.24593246 0.164701089] [0.309981436 0.314206958 0.31500268 0.241091311 0.163227797] [0.300420195 0.304259449 0.304981351 0.236526251 0.161793426] [0.291697085 0.295205742 0.295864582 0.232211992 0.160396278] [0.283696055 0.286919087 0.287523568 0.228126258 0.159034774] [0.276322395 0.279296666 0.27985391 0.224249557 0.157707423] [0.269497961 0.272254 0.272769839 0.220564634 0.15641281] [0.263157606 0.265720904 0.266200244 0.21705614 0.155149609] [0.257246554 0.259638608 0.260085613 0.213710397 0.153916568] [0.251718313 0.25395745 0.254375577 0.210515186 0.152712509] [0.246533215 0.248635098 0.249027327 0.207459539 0.151536316] [0.241657034 0.243635193 0.244004101 0.204533577 0.15038693] [0.237060249 0.238926381 0.239274174 0.201728329 0.149263337] [0.232717097 0.234481394 0.234810054 0.199035719 0.148164615] [0.228605017 0.230276451 0.230587661 0.196448416 0.147089839] [0.224704206 0.226290658 0.22658591 0.193959698 0.14603813] [0.220997125 0.222505584 0.222786173 0.191563457 0.145008713] <tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077], dtype=float32)>
إذا كنت مهتمًا ، يمكنك فحص رمز التوقيع الذي يولده.
print(tf.autograph.to_code(f.python_function))
def tf__f(x): with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope: do_return = False retval_ = ag__.UndefinedReturnValue() def get_state(): return (x,) def set_state(vars_): nonlocal x (x,) = vars_ def loop_body(): nonlocal x ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope) x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope) def loop_test(): return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1) ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {}) try: do_return = True retval_ = ag__.ld(x) except: do_return = False raise return fscope.ret(retval_, do_return)
الشرطية
سوف يقوم AutoGraph بتحويل بعض عبارات if <condition>
إلى مكالمات tf.cond
المكافئة. يتم إجراء هذا الاستبدال إذا كان <condition>
هو موتر. وبخلاف ذلك ، يتم تنفيذ عبارة if
كشرط لبايثون.
يتم تنفيذ شرط Python أثناء التتبع ، لذلك سيتم إضافة فرع واحد بالضبط من الشرط إلى الرسم البياني. بدون AutoGraph ، لن يتمكن هذا الرسم البياني المتعقب من أخذ الفرع البديل إذا كان هناك تدفق تحكم يعتمد على البيانات.
يتتبع tf.cond
ويضيف فرعي الشرط إلى الرسم البياني ، مع تحديد فرع ديناميكيًا في وقت التنفيذ. يمكن أن يكون للبحث عن المفقودين آثار جانبية غير مقصودة ؛ تحقق من تأثيرات تتبع الرسم البياني التلقائي لمزيد من المعلومات.
@tf.function
def fizzbuzz(n):
for i in tf.range(1, n + 1):
print('Tracing for loop')
if i % 15 == 0:
print('Tracing fizzbuzz branch')
tf.print('fizzbuzz')
elif i % 3 == 0:
print('Tracing fizz branch')
tf.print('fizz')
elif i % 5 == 0:
print('Tracing buzz branch')
tf.print('buzz')
else:
print('Tracing default branch')
tf.print(i)
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop Tracing fizzbuzz branch Tracing fizz branch Tracing buzz branch Tracing default branch 1 2 fizz 4 buzz 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz 16 17 fizz 19 buzz
راجع الوثائق المرجعية للحصول على قيود إضافية على عبارات if المحولة تلقائيًا.
الحلقات
سيقوم AutoGraph بتحويل بعض عبارات for
and while
إلى عمليات تكرار TensorFlow المكافئة ، مثل tf. tf.while_loop
. إذا لم يتم تحويلها ، يتم تنفيذ حلقة for
أو while
كحلقة Python.
يتم إجراء هذا الاستبدال في الحالات التالية:
-
for x in y
: إذا كانتy
عبارة عن موتر ، فحول إلىtf.while_loop
. في الحالة الخاصة حيثy
هيtf.data.Dataset
، يتم إنشاء مجموعة منtf.data.Dataset
ops. -
while <condition>
: إذا كان<condition>
هو Tensor ، فحول إلى tf.tf.while_loop
.
يتم تنفيذ حلقة Python أثناء التتبع ، مما يضيف عمليات إضافية إلى tf.Graph
لكل تكرار للحلقة.
تتعقب حلقة TensorFlow جسم الحلقة ، وتختار ديناميكيًا عدد التكرارات التي سيتم تشغيلها في وقت التنفيذ. يظهر جسم الحلقة مرة واحدة فقط في الرسم البياني الذي tf.Graph
.
راجع الوثائق المرجعية للحصول على قيود إضافية على كشوفات الرسم البياني AutoGraph المحولة for
and while
.
التكرار على بيانات بايثون
تتمثل إحدى المشاكل الشائعة في إجراء حلقة حول بيانات Python / tf.function
داخل دالة tf. سيتم تنفيذ هذه الحلقة أثناء عملية التتبع ، بإضافة نسخة من النموذج الخاص بك إلى tf.Graph
لكل تكرار للحلقة.
إذا كنت ترغب في لف حلقة التدريب بأكملها في tf.function
، فإن الطريقة الأكثر أمانًا للقيام بذلك هي لف بياناتك على هيئة tf.data.Dataset
بحيث يقوم AutoGraph بفتح حلقة التدريب ديناميكيًا.
def measure_graph_size(f, *args):
g = f.get_concrete_function(*args).graph
print("{}({}) contains {} nodes in its graph".format(
f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
loss = tf.constant(0)
for x, y in dataset:
loss += tf.abs(y - x) # Some dummy computation.
return loss
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
عند تغليف بيانات Python / NumPy في مجموعة بيانات ، انتبه إلى tf.data.Dataset.from_generator
مقابل tf.data.Dataset.from_tensors
. سيحتفظ الأول بالبيانات في Python ويجلبها عبر tf.py_function
والتي يمكن أن يكون لها آثار على الأداء ، في حين أن الأخيرة ستجمع نسخة من البيانات tf.constant()
واحدة كبيرة في الرسم البياني ، والتي يمكن أن يكون لها آثار على الذاكرة.
تعد قراءة البيانات من الملفات عبر TFRecordDataset
و CsvDataset
وما إلى ذلك الطريقة الأكثر فاعلية لاستهلاك البيانات ، حيث يمكن لـ TensorFlow نفسها إدارة التحميل غير المتزامن والجلب المسبق للبيانات ، دون الحاجة إلى إشراك Python. لمعرفة المزيد ، راجع tf.data
: إنشاء دليل خطوط أنابيب الإدخال TensorFlow .
تراكم القيم في حلقة
النمط الشائع هو تجميع القيم الوسيطة من حلقة. عادة ، يتم تحقيق ذلك عن طريق إلحاق قائمة Python أو إضافة إدخالات إلى قاموس Python. ومع ذلك ، نظرًا لأن هذه آثار جانبية لـ Python ، فلن تعمل كما هو متوقع في حلقة غير متحكم فيها ديناميكيًا. استخدم tf.TensorArray
لتجميع النتائج من حلقة غير مرتبطة ديناميكيًا.
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
return inp + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
max_seq_len = input_data.shape[0]
states = tf.TensorArray(tf.float32, size=max_seq_len)
state = initial_state
for i in tf.range(max_seq_len):
state = rnn_step(input_data[i], state)
states = states.write(i, state)
return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
tf.random.uniform([batch_size, seq_len, feature_size]),
tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy= array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216], [0.44997275, 1.9107027 , 1.0716251 , 0.717237 ], [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]], [[0.04946005, 0.69127274, 0.56848884, 0.22406638], [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ], [0.9178308 , 1.320889 , 0.989761 , 2.0120025 ]]], dtype=float32)>
محددات
تحتوي Function
TensorFlow على بعض القيود حسب التصميم التي يجب أن تكون على دراية بها عند تحويل دالة Python إلى Function
.
تنفيذ الآثار الجانبية للبايثون
يمكن أن تتصرف التأثيرات الجانبية ، مثل الطباعة ، والإلحاق بالقوائم ، وتحوير الكرة الأرضية ، بشكل غير متوقع داخل إحدى Function
، وأحيانًا يتم تنفيذها مرتين أو لا تنفذ كلها. تحدث فقط في المرة الأولى التي تقوم فيها باستدعاء Function
مع مجموعة من المدخلات. بعد ذلك ، يتم إعادة تنفيذ tf.Graph
، دون تنفيذ كود بايثون.
القاعدة العامة هي تجنب الاعتماد على تأثيرات Python الجانبية في منطقك واستخدامهم فقط لتصحيح آثارك. بخلاف ذلك ، تعد واجهات برمجة تطبيقات TensorFlow مثل tf.data
و tf.print
و tf.summary
و tf.Variable.assign
و tf.TensorArray
هي أفضل طريقة لضمان تنفيذ التعليمات البرمجية الخاصة بك من خلال وقت تشغيل TensorFlow مع كل مكالمة.
@tf.function
def f(x):
print("Traced with", x)
tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1 Executed with 1 Executed with 1 Traced with 2 Executed with 2
إذا كنت ترغب في تنفيذ كود Python أثناء كل استدعاء Function
، tf.py_function
هو فتحة خروج. عيب tf.py_function
هو أنها ليست محمولة أو ذات أداء خاص ، ولا يمكن حفظها مع SavedModel ، ولا تعمل بشكل جيد في الإعدادات الموزعة (متعددة GPU ، TPU). أيضًا ، نظرًا لأنه يجب tf.py_function
في الرسم البياني ، فإنه يلقي جميع المدخلات / المخرجات إلى الموترات.
تغيير متغيرات بايثون العالمية والحرة
يعد تغيير متغيرات Python العامة والحرة أحد الآثار الجانبية لـ Python ، لذلك يحدث فقط أثناء التتبع.
external_list = []
@tf.function
def side_effect(x):
print('Python side effect')
external_list.append(x)
side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect
في بعض الأحيان يصعب ملاحظة السلوكيات غير المتوقعة. في المثال أدناه ، يهدف counter
إلى حماية زيادة المتغير. ومع ذلك ، نظرًا لأنه عدد صحيح من نوع python وليس كائن TensorFlow ، يتم التقاط قيمته أثناء التتبع الأول. عند استخدام دالة tf ، سيتم تسجيل assign_add
tf.function
قيد أو شرط في الرسم البياني الأساسي. لذلك ستزيد v
بمقدار 1 ، في كل مرة يتم استدعاء tf.function
. هذه المشكلة شائعة بين المستخدمين الذين يحاولون ترحيل رمز Tensorflow في وضع Grpah إلى Tensorflow 2 باستخدام tf.function
، عند استخدام التأثيرات الجانبية للبيثون ( counter
في المثال) لتحديد العمليات التي سيتم تشغيلها ( assign_add
في المثال ). عادة ، لا يدرك المستخدمون ذلك إلا بعد رؤية نتائج رقمية مشبوهة ، أو أداء أقل بكثير من المتوقع (على سبيل المثال ، إذا كانت العملية الخاضعة للحراسة مكلفة للغاية).
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# A python side-effect
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 2, 3
1 2 3
الحل البديل لتحقيق السلوك المتوقع هو استخدام tf.init_scope
لرفع العمليات خارج الرسم البياني للدالة. هذا يضمن أن الزيادة المتغيرة تتم مرة واحدة فقط أثناء وقت التتبع. وتجدر الإشارة إلى أن init_scope
لها آثار جانبية أخرى بما في ذلك تدفق التحكم الواضح وشريط التدرج. في بعض الأحيان ، قد يصبح استخدام init_scope
معقدًا جدًا بحيث لا يمكن إدارته بواقعية.
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# Lifts ops out of function-building graphs
with tf.init_scope():
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 1, 1
1 1 1
باختصار ، كقاعدة عامة ، يجب أن تتجنب تحوير كائنات الثعبان مثل الأعداد الصحيحة أو الحاويات مثل القوائم التي تعيش خارج Function
. بدلاً من ذلك ، استخدم الحجج وكائنات TF. على سبيل المثال ، يحتوي قسم "تجميع القيم في حلقة" على مثال واحد لكيفية تنفيذ عمليات تشبه القائمة.
يمكنك ، في بعض الحالات ، التقاط الحالة ومعالجتها إذا كانت tf.Variable
. هذه هي الطريقة التي يتم بها تحديث أوزان نماذج Keras باستدعاءات متكررة لنفس وظيفة ConcreteFunction
.
استخدام مكررات ومولدات بايثون
تعتمد العديد من ميزات Python ، مثل المولدات والمكررات ، على وقت تشغيل Python لتتبع الحالة. بشكل عام ، بينما تعمل هذه التركيبات كما هو متوقع في الوضع الحثيث ، فهي أمثلة على الآثار الجانبية لبايثون وبالتالي تحدث فقط أثناء التتبع.
@tf.function
def buggy_consume_next(iterator):
tf.print("Value:", next(iterator))
iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1 Value: 1 Value: 1
تمامًا مثل كيفية امتلاك TensorFlow لـ tf.TensorArray
المتخصصة لبنى القائمة ، فإنه يحتوي على tf.data.Iterator
متخصص في إنشاءات التكرار. راجع قسم تحويلات AutoGraph للحصول على نظرة عامة. أيضًا ، يمكن أن تساعد واجهة برمجة تطبيقات tf.data
في تنفيذ أنماط المولد:
@tf.function
def good_consume_next(iterator):
# This is ok, iterator is a tf.data.Iterator
tf.print("Value:", next(iterator))
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1 Value: 2 Value: 3
يجب أن تكون جميع مخرجات دالة tf قيمًا مرجعة
باستثناء tf.Variable
s ، يجب أن تعيد دالة tf جميع مخرجاتها. محاولة الوصول مباشرة إلى أي موتر من دالة دون المرور بقيم الإرجاع يؤدي إلى حدوث "تسريبات".
على سبيل المثال ، الوظيفة الموجودة أسفل "تسرب" الموتر a
عبر Python global x
:
x = None
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return a + 2
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
3 'Tensor' object has no attribute 'numpy'
هذا صحيح حتى إذا تم إرجاع القيمة المسربة أيضًا:
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
@tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b
with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
2 'Tensor' object has no attribute 'numpy' Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/566849597.py", line 21, in <module> captures_leaked_tensor(tf.constant(2)) TypeError: Originated from a graph execution error. The graph execution error is detected at a node built at (most recent call last): >>> File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main >>> File /usr/lib/python3.7/runpy.py, line 85, in _run_code >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module> >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start >>> File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever >>> File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once >>> File /usr/lib/python3.7/asyncio/events.py, line 88, in _run >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code >>> File /tmp/ipykernel_26244/566849597.py, line 7, in <module> >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__ >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler >>> File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2 >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__ Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
عادةً ما تحدث مثل هذه التسريبات عند استخدام عبارات Python أو هياكل البيانات. بالإضافة إلى تسريب الموترات التي يتعذر الوصول إليها ، من المحتمل أيضًا أن تكون هذه العبارات خاطئة لأنها تعتبر آثارًا جانبية لـ Python ، ولا يتم ضمان تنفيذها عند كل استدعاء دالة.
تتضمن الطرق الشائعة لتسريب الموترات المحلية أيضًا تحوير مجموعة Python خارجية ، أو كائن:
class MyClass:
def __init__(self):
self.field = None
external_list = []
external_object = MyClass()
def leaky_function():
a = tf.constant(1)
external_list.append(a) # Bad - leaks tensor
external_object.field = a # Bad - leaks tensor
لا يتم دعم وظائف tf العودية
لا يتم دعم Function
التكرارية ويمكن أن تسبب حلقات لا نهائية. فمثلا،
@tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1
with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Caught expected exception <class 'Exception'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module> recursive_fn(tf.constant(5)) # Bad - maximum recursion error. tensorflow.python.autograph.impl.api.StagingError: in user code: File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn * if n > 0: File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__ return _abc_instancecheck(cls, instance) RecursionError: maximum recursion depth exceeded while calling a Python object
حتى إذا بدت Function
العودية تعمل ، فسيتم تتبع وظيفة بيثون عدة مرات وقد يكون لها تأثير ضمني على الأداء. فمثلا،
@tf.function
def recursive_fn(n):
if n > 0:
print('tracing')
return recursive_fn(n - 1)
else:
return 1
recursive_fn(5) # Warning - multiple tracings
tracing tracing tracing tracing tracing <tf.Tensor: shape=(), dtype=int32, numpy=1>
مشاكل معروفة
إذا لم يتم تقييم Function
بشكل صحيح ، فقد يتم تفسير الخطأ من خلال هذه المشكلات المعروفة التي تم التخطيط لإصلاحها في المستقبل.
اعتمادًا على متغيرات Python العامة والحرة
تقوم Function
بإنشاء وظيفة ConcreteFunction
جديدة عند استدعائها بقيمة جديدة من وسيطة Python. ومع ذلك ، فإنه لا يفعل ذلك لإغلاق Python أو globals أو nonlocals لتلك Function
. إذا تغيرت قيمتها بين المكالمات إلى Function
، ستستمر Function
في استخدام القيم التي كانت لديها عندما تم تتبعها. هذا يختلف عن كيفية عمل وظائف Python العادية.
لهذا السبب ، يجب عليك اتباع أسلوب برمجة وظيفي يستخدم الوسائط بدلاً من إغلاق الأسماء الخارجية.
@tf.function
def buggy_add():
return 1 + foo
@tf.function
def recommended_add(foo):
return 1 + foo
foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add()) # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100! Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(101, shape=(), dtype=int32)
هناك طريقة أخرى لتحديث قيمة عامة ، وهي جعلها متغير tf.Variable
واستخدام طريقة Variable.assign
بدلاً من ذلك.
@tf.function
def variable_add():
return 1 + foo
foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100! Variable: tf.Tensor(101, shape=(), dtype=int32)
بالاعتماد على كائنات بايثون
التوصية لتمرير كائنات Python كوسائط إلى tf.function
لديها عدد من المشكلات المعروفة ، والتي من المتوقع إصلاحها في المستقبل. بشكل عام ، يمكنك الاعتماد على التتبع المتسق إذا كنت تستخدم بنية Python بدائية أو بنية متوافقة مع tf.nest
كوسيطة أو تمرير مثيل مختلف من كائن إلى Function
. ومع ذلك ، لن تقوم Function
بإنشاء تتبع جديد عندما تمرر نفس الكائن وتغير سماته فقط .
class SimpleModel(tf.Module):
def __init__(self):
# These values are *not* tf.Variables.
self.bias = 0.
self.weight = 2.
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias! tf.Tensor(20.0, shape=(), dtype=float32)
استخدام نفس Function
لتقييم المثيل المحدث للنموذج سيكون عربات التي تجرها الدواب لأن النموذج المحدث له نفس مفتاح التخزين المؤقت مثل النموذج الأصلي.
لهذا السبب ، يوصى بكتابة Function
لتجنب الاعتماد على سمات الكائن القابلة للتغيير أو إنشاء كائنات جديدة.
إذا لم يكن ذلك ممكنًا ، فإن أحد الحلول هو إنشاء Function
جديدة في كل مرة تقوم فيها بتعديل الكائن لفرض الاستعادة:
def evaluate(model, x):
return model.weight * x + model.bias
new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
نظرًا لأن الاسترداد يمكن أن يكون مكلفًا ، يمكنك استخدام tf.Variable
s كسمات كائن ، والتي يمكن تغييرها (ولكن لا يتم تغييرها ، بحذر!) للحصول على تأثير مماثل دون الحاجة إلى التصحيح.
class BetterModel:
def __init__(self):
self.bias = tf.Variable(0.)
self.weight = tf.Variable(2.)
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
إنشاء متغيرات tf
تدعم Function
فقط tf.Variable
التي تم إنشاؤها مرة واحدة في المكالمة الأولى ، وإعادة استخدامها عبر استدعاءات الوظائف اللاحقة. سينشئ مقتطف الشفرة أدناه tf.Variable
جديدًا في كل استدعاء دالة ، مما ينتج عنه استثناء ValueError
.
مثال:
@tf.function
def f(x):
v = tf.Variable(1.0)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module> f(1.0) ValueError: in user code: File "/tmp/ipykernel_26244/3018268426.py", line 3, in f * v = tf.Variable(1.0) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
النمط الشائع المستخدم للتغلب على هذا القيد هو البدء بقيمة Python None ، ثم إنشاء tf.Variable
إذا كانت القيمة لا شيء:
class Count(tf.Module):
def __init__(self):
self.count = None
@tf.function
def __call__(self):
if self.count is None:
self.count = tf.Variable(0)
return self.count.assign_add(1)
c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32)
باستخدام مع محسنات Keras متعددة
قد تواجه ValueError: tf.function only supports singleton tf.Variables created on the first call.
عند استخدام أكثر من مُحسِّن tf.function
مع وظيفة tf. يحدث هذا الخطأ لأن المحسّنين ينشئون داخليًا tf.Variables
عند تطبيق التدرجات لأول مرة.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
train_step(w, x, y, opt2)
Calling `train_step` with different optimizer... Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module> train_step(w, x, y, opt2) ValueError: in user code: File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step * optimizer.apply_gradients(zip(gradients, [w])) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients ** self._create_all_weights(var_list) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights _ = self.iterations File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__ return super(OptimizerV2, self).__getattribute__(name) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight aggregation=aggregation) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable shape=variable_shape if variable_shape else None) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
إذا كنت بحاجة إلى تغيير المُحسِّن أثناء التدريب ، فإن الحل البديل هو إنشاء Function
جديدة لكل مُحسِّن ، واستدعاء الوظيفة ConcreteFunction
مباشرةً.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
# Not a tf.function.
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
if i % 2 == 0:
train_step_1(w, x, y) # `opt1` is not used as a parameter.
else:
train_step_2(w, x, y) # `opt2` is not used as a parameter.
تستخدم مع نماذج Keras متعددة
قد تواجه أيضًا ValueError: tf.function only supports singleton tf.Variables created on the first call.
عند تمرير مثيلات نموذج مختلفة إلى نفس Function
.
يحدث هذا الخطأ لأن نماذج Keras (التي لم يتم تحديد شكل إدخالها ) وطبقات Keras تنشئ tf.Variables
s عندما يتم استدعاؤها لأول مرة. ربما تحاول تهيئة هذه المتغيرات داخل Function
تم استدعاؤها بالفعل. لتجنب هذا الخطأ ، حاول استدعاء model.build(input_shape)
لتهيئة كل الأوزان قبل تدريب النموذج.
قراءة متعمقة
للتعرف على كيفية تصدير Function
وتحميلها ، راجع دليل SavedModel . لمعرفة المزيد حول تحسينات الرسم البياني التي يتم إجراؤها بعد التتبع ، راجع دليل Grappler . لمعرفة كيفية تحسين خط أنابيب البيانات الخاص بك وملف تعريف النموذج الخاص بك ، راجع دليل منشئ ملفات التعريف .