ביצועים טובים יותר עם פונקצית tf

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHubהורד מחברת

ב-TensorFlow 2, ביצוע להוט מופעל כברירת מחדל. ממשק המשתמש אינטואיטיבי וגמיש (הפעלת פעולות חד פעמיות היא הרבה יותר קלה ומהירה), אבל זה יכול לבוא על חשבון הביצועים והפריסה.

אתה יכול להשתמש ב- tf.function כדי ליצור גרפים מהתוכניות שלך. זהו כלי טרנספורמציה שיוצר גרפי זרימת נתונים עצמאיים לפייתון מתוך קוד הפייתון שלך. זה יעזור לך ליצור מודלים ביצועיים וניידים, ויש צורך להשתמש ב- SavedModel .

מדריך זה יעזור לך להמשיג כיצד tf.function פועל מתחת למכסה המנוע, כך שתוכל להשתמש בו ביעילות.

הטייק אווי וההמלצות העיקריות הן:

  • בצע ניפוי באגים במצב להוט ולאחר מכן קשט ב- @tf.function .
  • אל תסתמך על תופעות לוואי של Python כמו מוטציה של אובייקט או הוספת רשימה.
  • tf.function עובד הכי טוב עם TensorFlow ops; קריאות NumPy ו- ​​Python מומרות לקבועים.

להכין

import tensorflow as tf

הגדר פונקציית עוזר כדי להדגים את סוגי השגיאות שאתה עלול להיתקל בהן:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

יסודות

נוֹהָג

Function שאתה מגדיר (למשל על ידי יישום ה- @tf.function decorator) היא בדיוק כמו פעולת TensorFlow ליבה: אתה יכול לבצע אותה בשקיקה; אתה יכול לחשב מעברי צבע; וכולי.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

אתה יכול להשתמש Function בתוך Function אחרות.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function יכולות להיות מהירות יותר מקוד להוט, במיוחד עבור גרפים עם הרבה פעולות קטנות. אבל עבור גרפים עם כמה פעולות יקרות (כמו פיתולים), ייתכן שלא תראה מהירות רבה.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions

מַעֲקָב

סעיף זה חושף כיצד Function פועלת מתחת למכסה המנוע, כולל פרטי יישום שעשויים להשתנות בעתיד . עם זאת, ברגע שאתה מבין מדוע ומתי מתרחש מעקב, הרבה יותר קל להשתמש ב- tf.function ביעילות!

מה זה "מעקב"?

Function מפעילה את התוכנית שלך בתרשים TensorFlow . עם זאת, tf.Graph לא יכול לייצג את כל הדברים שאתה כותב בתוכנית TensorFlow נלהבת. לדוגמה, Python תומך בפולימורפיזם, אך tf.Graph דורש שהקלטים שלו יהיו בעלי סוג וממד נתונים מוגדרים. או שאתה יכול לבצע משימות צד כמו קריאת ארגומנטים של שורת הפקודה, העלאת שגיאה או עבודה עם אובייקט Python מורכב יותר; אף אחד מהדברים האלה לא יכול לפעול ב- tf.Graph .

Function מגשרת על הפער הזה על ידי הפרדת הקוד שלך בשני שלבים:

1) בשלב הראשון, המכונה " מעקב ", Function יוצרת tf.Graph חדש. קוד Python פועל כרגיל, אבל כל פעולות TensorFlow (כמו הוספת שני Tensors) נדחות : הן נקלטות על ידי ה- tf.Graph ולא פועלות.

2) בשלב השני tf.Graph המכיל את כל מה שנדחה בשלב הראשון. שלב זה מהיר בהרבה משלב האיתור.

בהתאם לכניסות שלו, Function לא תמיד תפעיל את השלב הראשון כאשר היא נקראת. ראה "כללי מעקב" להלן כדי לקבל תחושה טובה יותר כיצד הוא מקבל את הקביעה הזו. דילוג על השלב הראשון וביצוע השלב השני בלבד הוא מה שנותן לך את הביצועים הגבוהים של TensorFlow.

כאשר Function אכן מחליטה להתחקות, שלב המעקב מלווה מיד את השלב השני, כך tf.Graph Function בהמשך תראה איך אתה יכול להפעיל רק את שלב המעקב עם get_concrete_function .

כאשר אתה מעביר ארגומנטים מסוגים שונים Function , שני השלבים מופעלים:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

שים לב שאם אתה קורא שוב ושוב Function עם אותו סוג ארגומנט, TensorFlow ידלג על שלב המעקב ויעשה שימוש חוזר בגרף שהתחקה בעבר, מכיוון שהגרף שנוצר יהיה זהה.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

אתה יכול להשתמש pretty_printed_concrete_signatures() כדי לראות את כל העקבות הזמינות:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

עד כה, ראית ש- tf.function יוצר שכבת שיגור דינאמית במטמון מעל לוגיקת מעקב הגרפים של TensorFlow. כדי להיות יותר ספציפי לגבי הטרמינולוגיה:

  • tf.Graph הוא הייצוג הגולמי, האגנוסטטי והנייד של חישוב TensorFlow.
  • ConcreteFunction עוטפת tf.Graph .
  • Function מנהלת מטמון של ConcreteFunction ובוחרת את המתאים עבור הקלט שלך.
  • tf.function עוטפת פונקציית Python, ומחזירה אובייקט Function .
  • מעקב יוצר tf.Graph ועוטף אותו ב- ConcreteFunction , הידוע גם בתור trace.

כללי איתור

Function קובעת אם לעשות שימוש חוזר ב- ConcreteFunction שהתחקה על-ידי חישוב מפתח מטמון מה-args וה-kwargs של קלט. מפתח מטמון הוא מפתח המזהה ConcreteFunction בהתבסס על הקלט args ו-kwargs של קריאת Function , בהתאם לכללים הבאים (שעשויים להשתנות):

  • המפתח שנוצר עבור tf.Tensor הוא הצורה וה-dtype שלו.
  • המפתח שנוצר עבור tf.Variable הוא מזהה משתנה ייחודי.
  • המפתח שנוצר עבור פרימיטיבי של Python (כמו int , float , str ) הוא הערך שלו.
  • המפתח שנוצר עבור dict s, list s, tuple s, namedtuple s ו- attr s הוא הטפול המשוטח של מפתחות העלים (ראה nest.flatten ). (כתוצאה מהשטחה זו, קריאה לפונקציה בטון בעלת מבנה קינון שונה מזה המשמש במהלך המעקב תגרום ל- TypeError).
  • עבור כל סוגי הפייתון האחרים המפתח הוא ייחודי לאובייקט. בדרך זו עוקבים אחר פונקציה או שיטה באופן עצמאי עבור כל מופע שאיתו הוא נקרא.

שליטה בחזרה

מעקב מחדש, כלומר כאשר Function שלך יוצרת יותר מעקיבה אחת, עוזרת להבטיח ש-TensorFlow יוצר גרפים נכונים עבור כל סט של קלט. עם זאת, איתור הוא פעולה יקרה! אם Function שלך חוזרת על גרף חדש עבור כל שיחה, תגלה שהקוד שלך פועל לאט יותר מאשר אם לא השתמשת ב- tf.function .

כדי לשלוט בהתנהגות המעקב, תוכל להשתמש בטכניקות הבאות:

  • ציין input_signature ב- tf.function כדי להגביל את המעקב.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
  • ציין ממד [ללא] ב- tf.TensorSpec כדי לאפשר גמישות בשימוש חוזר.

    מכיוון ש-TensorFlow מתאים טנזורים על סמך צורתם, שימוש בממד None כתו כללי יאפשר ל- Function s לעשות שימוש חוזר בעקבות עבור קלט בגדלים משתנים. קלט בגדלים משתנה יכול להתרחש אם יש לך רצפים באורך שונה, או תמונות בגדלים שונים עבור כל אצווה (ראה את המדריכים של Transformer ו- Deep Dream למשל).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • העבר ארגומנטים של Python לטנסורים כדי להפחית את החזרה.

    לעתים קרובות, ארגומנטים של Python משמשים לשליטה בפרמטרים של היפר ובנייני גרפים - לדוגמה, num_layers=10 או training=True או nonlinearity='relu' . לכן, אם הארגומנט של Python משתנה, הגיוני שתצטרך לחזור על הגרף.

    עם זאת, ייתכן שלא נעשה שימוש בארגומנט Python כדי לשלוט בבניית הגרפים. במקרים אלה, שינוי בערך Python יכול להפעיל מעקב מיותר. קחו, למשל, את לולאת האימון הזו, ש-AutoGraph תפתח באופן דינמי. למרות העקבות המרובות, הגרף שנוצר הוא למעשה זהה, כך שהחזרה לאחור היא מיותרת.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

אם אתה צריך לאחור בכוח, צור Function חדשה. מובטח שאובייקטי Function נפרדים לא ישתפו עקבות.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

השגת פונקציות קונקרטיות

בכל פעם שמתחקים אחר פונקציה, נוצרת פונקציה קונקרטית חדשה. אתה יכול לקבל ישירות פונקציה קונקרטית, באמצעות get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

הדפסת ConcreteFunction מציגה סיכום של ארגומנטים הקלט שלו (עם סוגים) וסוג הפלט שלו.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

אתה יכול גם לאחזר ישירות את החתימה של פונקציה קונקרטית.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

שימוש בעקבות בטון עם סוגים לא תואמים יגרום לשגיאה

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

ייתכן שתבחין כי ארגומנטים של Python מקבלים טיפול מיוחד בחתימת הקלט של פונקציה קונקרטית. לפני TensorFlow 2.3, ארגומנטים של Python פשוט הוסרו מהחתימה של הפונקציה הקונקרטית. החל מ- TensorFlow 2.3, ארגומנטים של Python נשארים בחתימה, אך הם מוגבלים לקחת את הערך שנקבע במהלך המעקב.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

השגת גרפים

כל פונקציה קונקרטית היא עטיפה ניתנת להתקשרות סביב tf.Graph . למרות שאחזור אובייקט tf.Graph בפועל אינו משהו שבדרך כלל תצטרך לעשות, אתה יכול להשיג אותו בקלות מכל פונקציה קונקרטית.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

איתור באגים

באופן כללי, קוד ניפוי באגים קל יותר במצב להוט מאשר בתוך tf.function . עליך לוודא שהקוד שלך מופעל ללא שגיאות במצב להוט לפני שתעטר עם tf.function . כדי לסייע בתהליך איתור הבאגים, אתה יכול לקרוא tf.config.run_functions_eagerly(True) כדי להשבית ולהפעיל מחדש את tf.function באופן גלובלי.

בעת מעקב אחר בעיות המופיעות רק בתוך tf.function , הנה כמה טיפים:

  • קריאות print רגילות של Python מבוצעות רק במהלך המעקב, ועוזרות לך לאתר מתי הפונקציה שלך תתחקה (מחדש).
  • קריאות tf.print יבוצעו בכל פעם, ויכולות לעזור לך לאתר ערכי ביניים במהלך הביצוע.
  • tf.debugging.enable_check_numerics היא דרך קלה לאתר היכן נוצרים NaNs ו-Inf.
  • pdb ( מאתר הבאגים של Python ) יכול לעזור לך להבין מה קורה במהלך המעקב. (אזהרה: pdb יכניס אותך לקוד המקור שעבר טרנספורמציה של AutoGraph.)

טרנספורמציות של גרף אוטומטי

AutoGraph היא ספרייה שפועלת כברירת מחדל ב- tf.function , והופכת תת-קבוצה של קוד Python להוט ל-TensorFlow אופציות תואמות גרפים. זה כולל זרימת בקרה כמו if , for , while .

אופציות של TensorFlow כמו tf.cond ו- tf.while_loop ממשיכות לעבוד, אבל לרוב קל יותר לכתוב ולהבין את זרימת השליטה כשהיא כתובה ב-Python.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
      dtype=float32)>

אם אתה סקרן אתה יכול לבדוק את חתימת הקוד שנוצרת.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

תנאים

AutoGraph ימיר כמה הצהרות if <condition> לקריאות המקבילות tf.cond . החלפה זו מתבצעת אם <condition> הוא Tensor. אחרת, המשפט if מבוצע כ-Python מותנה.

תנאי Python מבצע במהלך המעקב, כך שדווקא ענף אחד של התנאי יתווסף לגרף. ללא AutoGraph, הגרף המתחקה הזה לא יוכל לקחת את הענף החלופי אם יש זרימת בקרה תלוית נתונים.

tf.cond עוקב ומוסיף את שני הענפים של התנאי לגרף, תוך בחירה דינמית של ענף בזמן הביצוע. למעקב יכולות להיות תופעות לוואי לא מכוונות; בדוק את אפקטי המעקב של AutoGraph למידע נוסף.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

עיין בתיעוד ההתייחסות להגבלות נוספות על הצהרות אם מומרות אוטומטית.

לולאות

AutoGraph ימיר חלק מהצהרות for ו- while ל-TensorFlow לולאות אופציות המקבילות, כמו tf.while_loop . אם לא הומר, לולאת for או while מבוצעת כלולאת Python.

החלפה זו מתבצעת במצבים הבאים:

  • for x in y : אם y הוא טנסור, המר ל- tf.while_loop . במקרה המיוחד שבו y הוא tf.data.Dataset , נוצר שילוב של tf.data.Dataset ops.
  • while <condition> : אם <condition> הוא Tensor, המר ל- tf.while_loop .

לולאת Python מופעלת במהלך המעקב, ומוסיפה אופציות נוספות ל- tf.Graph עבור כל איטרציה של הלולאה.

לולאת TensorFlow עוקבת אחר גוף הלולאה, ובוחרת באופן דינמי כמה איטרציות לרוץ בזמן הביצוע. גוף הלולאה מופיע רק פעם אחת ב- tf.Graph שנוצר.

עיין בתיעוד ההפניה להגבלות נוספות על הצהרות AutoGraph שהומרו for ו- while .

לולאה מעל נתוני Python

מלכודת נפוצה היא לולאה מעל נתוני Python/NumPy בתוך tf.function . לולאה זו תבוצע במהלך תהליך המעקב, ותוסיף עותק של המודל שלך ל- tf.Graph עבור כל איטרציה של הלולאה.

אם ברצונך לעטוף את כל לולאת האימון ב- tf.function , הדרך הבטוחה ביותר לעשות זאת היא לעטוף את הנתונים שלך כ- tf.data.Dataset כך ש-AutoGraph יגולל באופן דינמי את לולאת האימון.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

בעת גלישת נתוני Python/NumPy במערך נתונים, שים לב ל- tf.data.Dataset.from_generator לעומת tf.data.Dataset.from_tensors . הראשון ישמור את הנתונים ב-Python ויביא אותם דרך tf.py_function שיכולות להיות לו השלכות ביצועים, בעוד שהאחרון יצרור עותק של הנתונים tf.constant() אחד גדול בגרף, שיכול להיות בעל השלכות זיכרון.

קריאת נתונים מקבצים דרך TFRecordDataset , CsvDataset וכו' היא הדרך היעילה ביותר לצרוך נתונים, שכן אז TensorFlow בעצמה יכולה לנהל את הטעינה הא-סינכרונית ושליפה מראש של הנתונים, ללא צורך לערב את Python. למידע נוסף, עיין tf.data : בניית צינורות קלט של TensorFlow .

צבירת ערכים בלולאה

דפוס נפוץ הוא צבירת ערכי ביניים מלולאה. בדרך כלל, זה מושג על ידי הוספה לרשימת Python או הוספת ערכים למילון Python. עם זאת, מכיוון שמדובר בתופעות לוואי של Python, הן לא יפעלו כצפוי בלולאה מגולגלת דינמית. השתמש tf.TensorArray כדי לצבור תוצאות מלולאה שנפתחה באופן דינמי.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
        [0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
        [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

       [[0.04946005, 0.69127274, 0.56848884, 0.22406638],
        [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
        [0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

מגבלות

לפונקציית Function יש כמה מגבלות בתכנון שכדאי להיות מודע להן בעת ​​המרת Function Python לפונקציה.

ביצוע תופעות לוואי של Python

תופעות לוואי, כמו הדפסה, הוספה לרשימות ושינויים גלובליים, יכולות להתנהג בצורה בלתי צפויה בתוך Function , ולפעמים מופעלות פעמיים או לא את כולם. הם קורים רק בפעם הראשונה שאתה קורא Function עם קבוצה של כניסות. לאחר מכן, ה- tf.Graph מבוצע מחדש, מבלי להפעיל את קוד Python.

כלל האצבע הכללי הוא להימנע מהסתמכות על תופעות לוואי של Python בלוגיקה שלך ולהשתמש בהן רק כדי לנפות באגים שלך. אחרת, ממשקי API של TensorFlow כמו tf.data , tf.print , tf.summary , tf.Variable.assign ו- tf.TensorArray הם הדרך הטובה ביותר להבטיח שהקוד שלך יבוצע על ידי זמן הריצה של TensorFlow עם כל קריאה.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

אם תרצה להפעיל קוד של Python במהלך כל הפעלת Function , tf.py_function הוא פתח יציאה. החיסרון של tf.py_function הוא שהוא אינו נייד או בעל ביצועים מיוחדים, לא ניתן לשמור עם SavedModel, ואינו עובד היטב בהגדרות מבוזרות (רב-GPU, TPU). כמו כן, מכיוון tf.py_function צריך להיות מחובר לגרף, הוא מטיל את כל הכניסות/יציאות לטנזורים.

שינוי משתנים גלובליים וחופשיים של Python

שינוי של משתנים גלובליים וחופשיים של Python נחשב כתופעת לוואי של Python, כך שזה קורה רק במהלך המעקב.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

לפעמים קשה מאוד לשים לב להתנהגויות בלתי צפויות. בדוגמה שלהלן, counter נועד להגן על תוספת של משתנה. אולם מכיוון שהוא מספר שלם של פיתון ולא אובייקט TensorFlow, הערך שלו נקלט במהלך המעקב הראשון. כאשר נעשה שימוש assign_add tf.function ללא תנאי בגרף הבסיסי. לכן v יגדל ב-1, בכל פעם tf.function .. בעיה זו נפוצה בקרב משתמשים המנסים להעביר את קוד ה-Tensorflow במצב Grpah ל-Tensorflow 2 באמצעות tf.function , כאשר משתמשים בתופעות לוואי של פיתון ( counter בדוגמה) כדי לקבוע אילו פעולות להפעיל ( assign_add בדוגמה ). בדרך כלל, משתמשים מבינים זאת רק לאחר שהם רואים תוצאות מספריות חשודות, או ביצועים נמוכים משמעותית מהצפוי (למשל אם הפעולה המוגנת היא יקרה מאוד).

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

פתרון עוקף להשגת ההתנהגות הצפויה הוא שימוש ב- tf.init_scope כדי להרים את הפעולות מחוץ לגרף הפונקציות. זה מבטיח שהתוספת המשתנה תתבצע פעם אחת בלבד במהלך זמן המעקב. יש לשים לב ל- init_scope יש תופעות לוואי אחרות כולל זרימת בקרה מנוקת וסרט שיפוע. לפעמים השימוש ב- init_scope יכול להיות מורכב מכדי לנהל אותו בצורה מציאותית.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

לסיכום, ככלל אצבע, עליך להימנע משינוי של אובייקטי פיתון כמו מספרים שלמים או מיכלים כמו רשימות שחיות מחוץ Function . במקום זאת, השתמש בארגומנטים ובאובייקטי TF. לדוגמה, בסעיף "צבירת ערכים בלולאה" יש דוגמה אחת לאופן שבו ניתן ליישם פעולות דמויי רשימה.

אתה יכול, במקרים מסוימים, ללכוד ולתפעל מצב אם זה tf.Variable . כך מתעדכנים המשקלים של דגמי Keras בקריאות חוזרות לאותה ConcreteFunction .

שימוש באיטרטורים ומחוללים של Python

תכונות רבות של Python, כגון גנרטורים ואיטרטורים, מסתמכים על זמן הריצה של Python כדי לעקוב אחר המצב. באופן כללי, בעוד שהמבנים הללו פועלים כצפוי במצב להוט, הם מהווים דוגמאות לתופעות לוואי של Python ולכן קורות רק במהלך המעקב.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

בדיוק כמו ל-TensorFlow יש tf.TensorArray מיוחד לבניית רשימה, יש לו tf.data.Iterator מיוחד לבניית איטרציה. עיין בסעיף על טרנספורמציות אוטוגרף לסקירה כללית. כמו כן, ממשק ה-API של tf.data יכול לסייע ביישום דפוסי מחולל:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

כל הפלטים של פונקציה tf. חייבים להיות ערכי החזרה

למעט tf.Variable s, פונקציה tf. חייבת להחזיר את כל הפלטים שלה. ניסיון לגשת ישירות לכל טנסור מפונקציה מבלי לעבור על ערכי החזרה גורם ל"דליפות".

לדוגמה, הפונקציה למטה "מדליפה" את הטנזור a דרך ה-Python global x :

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

זה נכון גם אם הערך הדלף מוחזר גם:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
>>>  File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main
>>>  File /usr/lib/python3.7/runpy.py, line 85, in _run_code
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once
>>>  File /usr/lib/python3.7/asyncio/events.py, line 88, in _run
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code
>>>  File /tmp/ipykernel_26244/566849597.py, line 7, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler
>>>  File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

בדרך כלל, דליפות כאלה מתרחשות כאשר אתה משתמש בהצהרות Python או במבני נתונים. בנוסף להדלפת טנסורים בלתי נגישים, סביר להניח שהצהרות כאלה שגויות מכיוון שהן נחשבות כתופעות לוואי של Python, ואינן מובטחות שיבוצעו בכל קריאת פונקציה.

דרכים נפוצות להדלפת טנסורים מקומיים כוללות גם מוטציה של אוסף פייתון חיצוני, או אובייקט:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

פונקציות tf. רקורסיביות אינן נתמכות

Function רקורסיביות אינן נתמכות ועלולות לגרום ללולאות אינסופיות. לדוגמה,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

גם אם נראה שפונקציה רקורסיבית עובדת, Function הפיתון תתחקה מספר פעמים ויכולה להיות לה השלכה על ביצועים. לדוגמה,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

בעיות ידועות

אם Function שלך אינה מוערכת כראוי, ייתכן שהשגיאה מוסברת על ידי בעיות ידועות אלו המתוכננות לתקן בעתיד.

תלוי במשתנים גלובליים וחופשיים של Python

Function יוצרת ConcreteFunction חדשה כאשר היא נקראת עם ערך חדש של ארגומנט Python. עם זאת, זה לא עושה את זה עבור סגירת Python, גלובלים או לא-מקומיים של Function זו. אם הערך שלהם משתנה בין קריאות ל- Function , Function עדיין תשתמש בערכים שהיו להם בזמן המעקב אחריה. זה שונה מהאופן שבו פועלות פונקציות Python רגילות.

מסיבה זו, עליך לעקוב אחר סגנון תכנות פונקציונלי המשתמש בארגומנטים במקום לסגור על שמות חיצוניים.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

דרך נוספת לעדכן ערך גלובלי, היא להפוך אותו ל- tf.Variable ולהשתמש במקום זאת בשיטת Variable.assign .

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

תלוי באובייקטים של Python

להמלצה להעביר אובייקטים של Python כארגומנטים לתוך tf.function יש מספר בעיות ידועות, שצפויות להתוקן בעתיד. באופן כללי, אתה יכול להסתמך על מעקב עקבי אם אתה משתמש במבנה פרימיטיבי של Python או tf.nest כארגומנט או מעביר במופע אחר של אובייקט לתוך Function . עם זאת, Function לא תיצור מעקב חדש כאשר תעביר את אותו אובייקט ורק תשנה את התכונות שלו .

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

השימוש באותה Function כדי להעריך את המופע המעודכן של המודל יהיה באגי מכיוון שלדגם המעודכן יש מפתח מטמון זהה לדגם המקורי.

מסיבה זו, מומלץ לכתוב את Function כדי להימנע מהתלות בתכונות אובייקט הניתנות לשינוי או ליצור אובייקטים חדשים.

אם זה לא אפשרי, דרך אחת לעקיפת הבעיה היא ליצור Function חדשות בכל פעם שאתה משנה את האובייקט שלך כדי לאחור בכוח:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

מכיוון שהחזרה יכולה להיות יקרה , אתה יכול להשתמש ב- tf.Variable s כתכונות אובייקט, אותן ניתן לבצע מוטציה (אך לא לשנות, זהירות!) לקבלת אפקט דומה ללא צורך בחזרה.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

יצירת tf.Variables

Function תומכת רק ב-singleton tf.Variable s שנוצרו פעם אחת בקריאה הראשונה, ונעשה בהם שימוש חוזר בקריאות פונקציה עוקבות. קטע הקוד שלמטה ייצור tf.Variable חדש בכל קריאת פונקציה, מה שמביא לחריג ValueError .

דוגמא:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

דפוס נפוץ המשמש לעקוף מגבלה זו הוא להתחיל בערך Python None, ולאחר מכן ליצור באופן מותנה את ה- tf.Variable אם הערך הוא None:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

שימוש עם מספר אופטימיזציית Keras

אתה עלול להיתקל ValueError: tf.function only supports singleton tf.Variables created on the first call. כאשר משתמשים ביותר ממייעל Keras אחד עם tf.function . שגיאה זו מתרחשת מכיוון שמייעלים יוצרים באופן פנימי tf.Variables כאשר הם מיישמים מעברי צבע בפעם הראשונה.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
        self._create_all_weights(var_list)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
        _ = self.iterations
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
        aggregation=aggregation)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
        shape=variable_shape if variable_shape else None)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

אם אתה צריך לשנות את האופטימיזציה במהלך האימון, הדרך לעקיפת הבעיה היא ליצור Function חדשה עבור כל אופטימיזציה, להתקשר ישירות ל- ConcreteFunction .

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

שימוש עם מספר דגמי Keras

אתה עלול גם להיתקל ValueError: tf.function only supports singleton tf.Variables created on the first call. כאשר מעבירים מופעי מודל שונים לאותה Function .

שגיאה זו מתרחשת מכיוון שמודלים של Keras ( שאין להם צורת קלט מוגדרת ) ושכבות Keras יוצרות tf.Variables כאשר הם נקראים לראשונה. ייתכן שאתה מנסה לאתחל את המשתנים האלה בתוך Function , שכבר נקראה. כדי להימנע משגיאה זו, נסה לקרוא model.build(input_shape) כדי לאתחל את כל המשקולות לפני אימון המודל.

לקריאה נוספת

כדי ללמוד כיצד לייצא ולטעון Function , עיין במדריך SavedModel . למידע נוסף על אופטימיזציות של גרפים שמתבצעות לאחר מעקב, עיין במדריך גרפלר . כדי ללמוד כיצד לבצע אופטימיזציה של צינור הנתונים שלך ולעשות פרופיל של המודל שלך, עיין במדריך Profiler .