Xem trên TensorFlow.org | Chạy trong Google Colab | Xem nguồn trên GitHub | Tải xuống sổ ghi chép |
Trong TensorFlow 2, thực thi háo hức được bật theo mặc định. Giao diện người dùng trực quan và linh hoạt (chạy các hoạt động một lần dễ dàng hơn và nhanh hơn nhiều), nhưng điều này có thể phải trả giá bằng hiệu suất và khả năng triển khai.
Bạn có thể sử dụng tf.function
để tạo đồ thị từ các chương trình của mình. Nó là một công cụ chuyển đổi tạo đồ thị luồng dữ liệu độc lập với Python từ mã Python của bạn. Điều này sẽ giúp bạn tạo các mô hình hiệu quả và di động, đồng thời bắt buộc phải sử dụng SavedModel
.
Hướng dẫn này sẽ giúp bạn hình dung cách hoạt động của tf.function
, vì vậy bạn có thể sử dụng nó một cách hiệu quả.
Các điểm rút ra và khuyến nghị chính là:
- Gỡ lỗi ở chế độ háo hức, sau đó trang trí với
@tf.function
. function. - Đừng dựa vào các tác dụng phụ của Python như đột biến đối tượng hoặc nối thêm danh sách.
-
tf.function
hoạt động tốt nhất với hoạt động TensorFlow; Các cuộc gọi NumPy và Python được chuyển đổi thành hằng số.
Thành lập
import tensorflow as tf
Xác định một chức năng trợ giúp để chứng minh các loại lỗi bạn có thể gặp phải:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
try:
yield
except error_class as e:
print('Caught expected exception \n {}:'.format(error_class))
traceback.print_exc(limit=2)
except Exception as e:
raise e
else:
raise Exception('Expected {} to be raised but no error was raised!'.format(
error_class))
Khái niệm cơ bản
Cách sử dụng
Một Function
bạn xác định (ví dụ bằng cách áp dụng trình trang trí @tf.function
năng) cũng giống như một hoạt động TensorFlow cốt lõi: Bạn có thể thực thi nó một cách hăng hái; bạn có thể tính toán độ dốc; và như thế.
@tf.function # The decorator converts `add` into a `Function`.
def add(a, b):
return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy= array([[2., 2.], [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
Bạn có thể sử dụng các Function
bên trong các Function
khác.
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy= array([[3., 3.], [3., 3.], [3., 3.]], dtype=float32)>
Function
s có thể nhanh hơn mã háo hức, đặc biệt đối với đồ thị có nhiều hoạt động nhỏ. Nhưng đối với các đồ thị có một vài hoạt động đắt tiền (như chập), bạn có thể không thấy tốc độ tăng lên nhiều.
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735 Function conv: 0.005791576000774512 Note how there's not much difference in performance for convolutions
Truy tìm
Phần này trình bày cách Function
hoạt động ngầm, bao gồm các chi tiết triển khai có thể thay đổi trong tương lai . Tuy nhiên, một khi bạn hiểu tại sao và khi nào thì việc truy tìm xảy ra, việc sử dụng tf.function
một cách hiệu quả sẽ dễ dàng hơn nhiều!
"Truy tìm" là gì?
Một Function
chạy chương trình của bạn trong một Đồ thị TensorFlow . Tuy nhiên, một tf.Graph
không thể đại diện cho tất cả những thứ bạn viết trong một chương trình TensorFlow háo hức. Ví dụ: Python hỗ trợ tính đa hình, nhưng tf.Graph
yêu cầu đầu vào của nó phải có kiểu dữ liệu và thứ nguyên được chỉ định. Hoặc bạn có thể thực hiện các tác vụ phụ như đọc các đối số dòng lệnh, nêu lỗi hoặc làm việc với một đối tượng Python phức tạp hơn; không có thứ nào trong số này có thể chạy trong tf.Graph
.
Function
thu hẹp khoảng cách này bằng cách tách mã của bạn thành hai giai đoạn:
1) Trong giai đoạn đầu tiên, được gọi là " truy tìm ", Function
tạo một tf.Graph
mới. Mã Python chạy bình thường, nhưng tất cả các hoạt động TensorFlow (như thêm hai Tensor) đều bị hoãn lại : chúng bị tf.Graph
nắm bắt và không chạy.
2) Trong giai đoạn thứ hai, một tf.Graph
chứa mọi thứ đã bị trì hoãn trong giai đoạn đầu tiên được chạy. Giai đoạn này nhanh hơn nhiều so với giai đoạn truy tìm.
Tùy thuộc vào các đầu vào của nó, Function
sẽ không phải lúc nào cũng chạy ở giai đoạn đầu tiên khi nó được gọi. Xem "Quy tắc theo dõi" bên dưới để hiểu rõ hơn về cách nó đưa ra quyết định đó. Bỏ qua giai đoạn đầu tiên và chỉ thực hiện giai đoạn thứ hai là những gì mang lại cho bạn hiệu suất cao của TensorFlow.
Khi Function
quyết định theo dõi, giai đoạn theo dõi ngay lập tức được theo sau bởi giai đoạn thứ hai, vì vậy việc gọi Function
vừa tạo và chạy tf.Graph
. Sau đó, bạn sẽ thấy cách bạn có thể chạy chỉ giai đoạn theo dõi với get_concrete_function
.
Khi bạn chuyển các đối số của các kiểu khác nhau vào một Function
, cả hai giai đoạn đều được chạy:
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32) Tracing with Tensor("a:0", shape=(), dtype=float32) tf.Tensor(2.2, shape=(), dtype=float32) Tracing with Tensor("a:0", shape=(), dtype=string) tf.Tensor(b'aa', shape=(), dtype=string)
Lưu ý rằng nếu bạn gọi nhiều lần một Function
có cùng loại đối số, TensorFlow sẽ bỏ qua giai đoạn theo dõi và sử dụng lại một biểu đồ đã theo dõi trước đó, vì biểu đồ được tạo sẽ giống hệt nhau.
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
Bạn có thể sử dụng pretty_printed_concrete_signatures()
để xem tất cả các dấu vết có sẵn:
print(double.pretty_printed_concrete_signatures())
double(a) Args: a: int32 Tensor, shape=() Returns: int32 Tensor, shape=() double(a) Args: a: float32 Tensor, shape=() Returns: float32 Tensor, shape=() double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
Cho đến nay, bạn đã thấy rằng tf.function
tạo ra một lớp điều phối động, được lưu trong bộ nhớ cache trên logic theo dõi đồ thị của TensorFlow. Để cụ thể hơn về thuật ngữ:
- Một
tf.Graph
là bản đại diện thô, không có ngôn ngữ, di động của một phép tính TensorFlow. - Một
ConcreteFunction
kết thúc mộttf.Graph
. -
Function
quản lý một bộ nhớ cache của cácConcreteFunction
và chọn một bộ nhớ đệm phù hợp cho đầu vào của bạn. -
tf.function
kết thúc một hàm Python, trả về một đối tượngFunction
. - Tracing tạo ra một
tf.Graph
và bao bọc nó trong mộtConcreteFunction
, còn được gọi là một dấu vết.
Quy tắc truy tìm
Một Function
xác định xem có sử dụng lại ConcreteFunction
được theo dõi hay không bằng cách tính toán khóa bộ nhớ cache từ các args và kwargs của đầu vào. Khóa bộ nhớ cache là khóa xác định ConcreteFunction
dựa trên các args và kwargs đầu vào của lệnh gọi Function
, theo các quy tắc sau (có thể thay đổi):
- Chìa khóa được tạo ra cho
tf.Tensor
là hình dạng và kiểu của nó. - Khóa được tạo cho
tf.Variable
là một id biến duy nhất. - Khóa được tạo cho một nguyên thủy Python (như
int
,float
,str
) là giá trị của nó. - Khóa được tạo cho các
dict
lồng nhau,list
s,tuple
s,namedtuple
s vàattr
s là bộ khóa lá được làm phẳng (xemnest.flatten
). (Kết quả của việc làm phẳng này, việc gọi một hàm bê tông có cấu trúc lồng khác với cấu trúc được sử dụng trong quá trình truy tìm sẽ dẫn đến Lỗi loại). - Đối với tất cả các kiểu Python khác, khóa là duy nhất cho đối tượng. Bằng cách này, một hàm hoặc phương thức được truy tìm độc lập cho mỗi trường hợp mà nó được gọi.
Kiểm soát việc kiểm tra lại
Retracing, đó là khi Function
của bạn tạo nhiều hơn một dấu vết, giúp đảm bảo rằng TensorFlow tạo ra các đồ thị chính xác cho từng nhóm đầu vào. Tuy nhiên, truy tìm là một hoạt động tốn kém! Nếu Function
của bạn truy xuất lại một biểu đồ mới cho mỗi lần gọi, bạn sẽ thấy rằng mã của bạn thực thi chậm hơn so với khi bạn không sử dụng tf.function
. function.
Để kiểm soát hành vi theo dõi, bạn có thể sử dụng các kỹ thuật sau:
- Chỉ định
input_signature
trongtf.function
để hạn chế việc theo dõi.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([4 1], shape=(2,), dtype=int32) Caught expected exception <class 'ValueError'>: Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module> next_collatz(tf.constant([[1, 2], [3, 4]])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor( [[1 2] [3 4]], shape=(2, 2), dtype=int32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)). Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module> next_collatz(tf.constant([1.0, 2.0])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor([1. 2.], shape=(2,), dtype=float32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Chỉ định thứ nguyên [Không có] trong
tf.TensorSpec
để cho phép sử dụng lại dấu vết một cách linh hoạt.Vì TensorFlow so khớp các tensor dựa trên hình dạng của chúng, việc sử dụng thứ nguyên
None
làm ký tự đại diện sẽ cho phép cácFunction
sử dụng lại dấu vết cho đầu vào có kích thước thay đổi. Đầu vào có kích thước khác nhau có thể xảy ra nếu bạn có các chuỗi có độ dài khác nhau hoặc hình ảnh có kích thước khác nhau cho mỗi lô (Ví dụ: Xem hướng dẫn về Transformer và Deep Dream ).
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
print('Tracing with', x)
return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
Truyền các đối số Python đến Tensors để giảm việc rút lại.
Thông thường, các đối số trong Python được sử dụng để kiểm soát các siêu tham số và cấu trúc đồ thị - ví dụ:
num_layers=10
hoặctraining=True
hoặcnonlinearity='relu'
. Vì vậy, nếu đối số Python thay đổi, có nghĩa là bạn phải truy xuất lại biểu đồ.Tuy nhiên, có thể một đối số Python không được sử dụng để kiểm soát việc xây dựng đồ thị. Trong những trường hợp này, một sự thay đổi trong giá trị Python có thể kích hoạt việc kiểm tra lại không cần thiết. Lấy ví dụ, vòng lặp đào tạo này, AutoGraph sẽ tự động hủy cuộn. Mặc dù có nhiều dấu vết, nhưng biểu đồ được tạo thực sự giống hệt nhau, vì vậy việc kiểm tra lại là không cần thiết.
def train_one_step():
pass
@tf.function
def train(num_steps):
print("Tracing with num_steps = ", num_steps)
tf.print("Executing with num_steps = ", num_steps)
for _ in tf.range(num_steps):
train_one_step()
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments. Tracing with num_steps = 10 Executing with num_steps = 10 Tracing with num_steps = 20 Executing with num_steps = 20 Traces are reused for Tensor arguments. Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32) Executing with num_steps = 10 Executing with num_steps = 20
Nếu bạn cần buộc phải rút lại, hãy tạo một Function
mới. Các đối tượng Function
riêng biệt được đảm bảo không chia sẻ dấu vết.
def f():
print('Tracing!')
tf.print('Executing')
tf.function(f)()
tf.function(f)()
Tracing! Executing Tracing! Executing
Có được các chức năng cụ thể
Mỗi khi một chức năng được truy tìm, một chức năng cụ thể mới được tạo ra. Bạn có thể lấy trực tiếp một hàm cụ thể bằng cách sử dụng get_concrete_function
.
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)
Việc in một ConcreteFunction
hiển thị một bản tóm tắt các đối số đầu vào của nó (với các kiểu) và kiểu đầu ra của nó.
print(double_strings)
ConcreteFunction double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
Bạn cũng có thể truy xuất trực tiếp chữ ký của một hàm cụ thể.
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {}) Tensor("Identity:0", shape=(), dtype=string)
Sử dụng dấu vết cụ thể với các loại không tương thích sẽ gây ra lỗi
with assert_raises(tf.errors.InvalidArgumentError):
double_strings(tf.constant(1))
Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module> double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]
Bạn có thể nhận thấy rằng các đối số Python được xử lý đặc biệt trong chữ ký đầu vào của một hàm cụ thể. Trước TensorFlow 2.3, các đối số trong Python chỉ đơn giản là bị xóa khỏi chữ ký của hàm cụ thể. Bắt đầu với TensorFlow 2.3, các đối số Python vẫn còn trong chữ ký, nhưng bị hạn chế lấy giá trị được đặt trong quá trình truy tìm.
@tf.function
def pow(a, b):
return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2) Args: a: float32 Tensor, shape=<unknown> Returns: float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
square(tf.constant(10.0), b=3)
Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl cancellation_manager) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature raise TypeError(f"{self._flat_signature_summary()} got unexpected " TypeError: pow(a) got unexpected keyword arguments: b. During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module> square(tf.constant(10.0), b=3) TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.
Lấy đồ thị
Mỗi hàm cụ thể là một trình bao bọc có thể gọi xung quanh một tf.Graph
. Mặc dù truy xuất đối tượng tf.Graph
thực tế không phải là điều bạn thường cần làm, nhưng bạn có thể lấy nó dễ dàng từ bất kỳ hàm cụ thể nào.
graph = double_strings.graph
for node in graph.as_graph_def().node:
print(f'{node.input} -> {node.name}')
[] -> a ['a', 'a'] -> add ['add'] -> Identity
Gỡ lỗi
Nói chung, mã gỡ lỗi trong chế độ háo hức dễ dàng hơn bên trong tf.function
. Chức năng. Bạn nên đảm bảo rằng mã của bạn thực thi không có lỗi ở chế độ háo hức trước khi trang trí bằng tf.function
. function. Để hỗ trợ quá trình gỡ lỗi, bạn có thể gọi tf.config.run_functions_eagerly(True)
để tắt và kích tf.function
. Chức năng trên toàn cầu.
Khi theo dõi các vấn đề chỉ xuất hiện trong tf.function
. function, đây là một số mẹo:
- Các lệnh gọi
print
Python cũ đơn thuần chỉ thực thi trong quá trình truy tìm, giúp bạn theo dõi khi hàm của bạn được truy tìm (lại). - Các lệnh gọi
tf.print
sẽ thực hiện mọi lúc và có thể giúp bạn theo dõi các giá trị trung gian trong quá trình thực thi. -
tf.debugging.enable_check_numerics
là một cách dễ dàng để theo dõi nơi tạo NaN và Inf. -
pdb
( trình gỡ lỗi Python ) có thể giúp bạn hiểu những gì đang xảy ra trong quá trình truy tìm. (Lưu ý:pdb
sẽ đưa bạn vào mã nguồn được chuyển đổi AutoGraph.)
Các phép biến đổi AutoGraph
AutoGraph là một thư viện được bật theo mặc định trong tf.function
. function và chuyển đổi một tập hợp con của mã háo hức Python thành các hoạt động TensorFlow tương thích với đồ thị. Điều này bao gồm luồng điều khiển như if
, for
, while
.
Các hoạt động TensorFlow như tf.cond
và tf.while_loop
tiếp tục hoạt động, nhưng luồng điều khiển thường dễ viết và dễ hiểu hơn khi được viết bằng Python.
# A simple loop
@tf.function
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753] [0.582645297 0.613145649 0.619306684 0.319202513 0.182036072] [0.524585426 0.546337605 0.550645113 0.308785647 0.18005164] [0.481231302 0.497770309 0.501003504 0.299331933 0.178130865] [0.447229207 0.460361809 0.462906033 0.290701121 0.176270396] [0.419618756 0.430379033 0.432449728 0.282779962 0.174467146] [0.396609187 0.405638 0.407366514 0.275476 0.172718227] [0.377043903 0.384762734 0.386234313 0.268712848 0.17102097] [0.360137492 0.366836458 0.368109286 0.262426734 0.169372901] [0.345335096 0.351221472 0.352336824 0.256563932 0.167771652] [0.332231969 0.337458342 0.338446289 0.251078814 0.166215062] [0.320524871 0.325206399 0.326089561 0.24593246 0.164701089] [0.309981436 0.314206958 0.31500268 0.241091311 0.163227797] [0.300420195 0.304259449 0.304981351 0.236526251 0.161793426] [0.291697085 0.295205742 0.295864582 0.232211992 0.160396278] [0.283696055 0.286919087 0.287523568 0.228126258 0.159034774] [0.276322395 0.279296666 0.27985391 0.224249557 0.157707423] [0.269497961 0.272254 0.272769839 0.220564634 0.15641281] [0.263157606 0.265720904 0.266200244 0.21705614 0.155149609] [0.257246554 0.259638608 0.260085613 0.213710397 0.153916568] [0.251718313 0.25395745 0.254375577 0.210515186 0.152712509] [0.246533215 0.248635098 0.249027327 0.207459539 0.151536316] [0.241657034 0.243635193 0.244004101 0.204533577 0.15038693] [0.237060249 0.238926381 0.239274174 0.201728329 0.149263337] [0.232717097 0.234481394 0.234810054 0.199035719 0.148164615] [0.228605017 0.230276451 0.230587661 0.196448416 0.147089839] [0.224704206 0.226290658 0.22658591 0.193959698 0.14603813] [0.220997125 0.222505584 0.222786173 0.191563457 0.145008713] <tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077], dtype=float32)>
Nếu bạn tò mò, bạn có thể kiểm tra mã tạo ra chữ ký.
print(tf.autograph.to_code(f.python_function))
def tf__f(x): with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope: do_return = False retval_ = ag__.UndefinedReturnValue() def get_state(): return (x,) def set_state(vars_): nonlocal x (x,) = vars_ def loop_body(): nonlocal x ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope) x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope) def loop_test(): return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1) ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {}) try: do_return = True retval_ = ag__.ld(x) except: do_return = False raise return fscope.ret(retval_, do_return)
Điều kiện
AutoGraph sẽ chuyển đổi một số câu lệnh if <condition>
thành các lệnh gọi tf.cond
tương đương. Sự thay thế này được thực hiện nếu <condition>
là một Tensor. Ngược lại, câu lệnh if
được thực thi dưới dạng điều kiện Python.
Một điều kiện Python thực thi trong quá trình theo dõi, vì vậy chính xác một nhánh của điều kiện sẽ được thêm vào biểu đồ. Nếu không có AutoGraph, biểu đồ theo dõi này sẽ không thể lấy nhánh thay thế nếu có luồng điều khiển phụ thuộc vào dữ liệu.
tf.cond
theo dõi và thêm cả hai nhánh của điều kiện vào biểu đồ, chọn động một nhánh tại thời điểm thực thi. Truy tìm có thể có tác dụng phụ ngoài ý muốn; kiểm tra các hiệu ứng theo dõi AutoGraph để biết thêm thông tin.
@tf.function
def fizzbuzz(n):
for i in tf.range(1, n + 1):
print('Tracing for loop')
if i % 15 == 0:
print('Tracing fizzbuzz branch')
tf.print('fizzbuzz')
elif i % 3 == 0:
print('Tracing fizz branch')
tf.print('fizz')
elif i % 5 == 0:
print('Tracing buzz branch')
tf.print('buzz')
else:
print('Tracing default branch')
tf.print(i)
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop Tracing fizzbuzz branch Tracing fizz branch Tracing buzz branch Tracing default branch 1 2 fizz 4 buzz 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz 16 17 fizz 19 buzz
Xem tài liệu tham khảo để biết các hạn chế bổ sung đối với các câu lệnh if được chuyển đổi tự động.
Vòng lặp
AutoGraph sẽ chuyển đổi một số câu lệnh for
và while
thành các hoạt động lặp TensorFlow tương đương, như tf. tf.while_loop
. Nếu không được chuyển đổi, vòng lặp for
hoặc while
được thực thi như một vòng lặp Python.
Sự thay thế này được thực hiện trong các trường hợp sau:
-
for x in y
: nếuy
là Tensor, hãy chuyển đổi thành tf.tf.while_loop
. Trong trường hợp đặc biệt khiy
làtf.data.Dataset
, kết hợp các hoạttf.data.Dataset
được tạo ra. -
while <condition>
: nếu<condition>
là Tensor, hãy chuyển đổi thành tf.tf.while_loop
.
Một vòng lặp Python thực thi trong quá trình truy tìm, thêm các hoạt động bổ sung vào tf.Graph
cho mỗi lần lặp lại của vòng lặp.
Một vòng lặp TensorFlow theo dõi phần nội dung của vòng lặp và tự động chọn số lần lặp để chạy tại thời điểm thực thi. Phần thân của vòng lặp chỉ xuất hiện một lần trong tf.Graph
tạo.
Xem tài liệu tham khảo để biết thêm các hạn chế đối với các câu lệnh for
và while
được chuyển đổi tự động.
Vòng qua dữ liệu Python
Một lỗi phổ biến là lặp qua dữ liệu Python / NumPy trong một tf.function
. Vòng lặp này sẽ thực thi trong quá trình truy tìm, thêm một bản sao mô hình của bạn vào tf.Graph
cho mỗi lần lặp lại của vòng lặp.
Nếu bạn muốn gói toàn bộ vòng huấn luyện trong tf.function
. function, cách an toàn nhất để thực hiện việc này là bọc dữ liệu của bạn dưới dạng tf.data.Dataset
để AutoGraph sẽ tự động bỏ cuộn vòng huấn luyện.
def measure_graph_size(f, *args):
g = f.get_concrete_function(*args).graph
print("{}({}) contains {} nodes in its graph".format(
f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
loss = tf.constant(0)
for x, y in dataset:
loss += tf.abs(y - x) # Some dummy computation.
return loss
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
Khi gói dữ liệu Python / NumPy trong Dataset, hãy lưu ý đến tf.data.Dataset.from_generator
so với tf.data.Dataset.from_tensors
. Cái trước sẽ giữ dữ liệu bằng Python và tìm nạp nó qua tf.py_function
năng có thể có ý nghĩa về hiệu suất, trong khi cái sau sẽ gói một bản sao dữ liệu dưới dạng một nút tf.constant()
lớn trong biểu đồ, có thể có ý nghĩa về bộ nhớ.
Đọc dữ liệu từ các tệp thông qua TFRecordDataset
, CsvDataset
, v.v. là cách hiệu quả nhất để tiêu thụ dữ liệu, vì bản thân TensorFlow có thể quản lý tải không đồng bộ và tìm nạp trước dữ liệu mà không cần phải liên quan đến Python. Để tìm hiểu thêm, hãy xem tf.data
: Hướng dẫn xây dựng đường ống đầu vào TensorFlow .
Tích lũy các giá trị trong một vòng lặp
Một mô hình phổ biến là tích lũy các giá trị trung gian từ một vòng lặp. Thông thường, điều này được thực hiện bằng cách thêm vào danh sách Python hoặc thêm các mục nhập vào từ điển Python. Tuy nhiên, vì đây là các tác dụng phụ của Python, chúng sẽ không hoạt động như mong đợi trong một vòng lặp được mở động. Sử dụng tf.TensorArray
để tích lũy kết quả từ một vòng lặp không được cuộn động.
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
return inp + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
max_seq_len = input_data.shape[0]
states = tf.TensorArray(tf.float32, size=max_seq_len)
state = initial_state
for i in tf.range(max_seq_len):
state = rnn_step(input_data[i], state)
states = states.write(i, state)
return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
tf.random.uniform([batch_size, seq_len, feature_size]),
tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy= array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216], [0.44997275, 1.9107027 , 1.0716251 , 0.717237 ], [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]], [[0.04946005, 0.69127274, 0.56848884, 0.22406638], [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ], [0.9178308 , 1.320889 , 0.989761 , 2.0120025 ]]], dtype=float32)>
Hạn chế
Function
TensorFlow có một vài hạn chế theo thiết kế mà bạn nên biết khi chuyển đổi một hàm Python thành một Function
.
Thực thi các tác dụng phụ của Python
Các tác dụng phụ, như in, thêm vào danh sách và thay đổi hình cầu, có thể hoạt động không mong muốn bên trong một Function
, đôi khi thực thi hai lần hoặc không phải tất cả. Chúng chỉ xảy ra lần đầu tiên bạn gọi một Function
với một tập hợp các đầu vào. Sau đó, tf.Graph
được truy tìm được thực thi lại mà không thực thi mã Python.
Nguyên tắc chung là tránh dựa vào các tác dụng phụ của Python trong logic của bạn và chỉ sử dụng chúng để gỡ lỗi các dấu vết của bạn. Mặt khác, các API TensorFlow như tf.data
, tf.print
, tf.summary
, tf.Variable.assign
và tf.TensorArray
là cách tốt nhất để đảm bảo mã của bạn sẽ được thực thi bởi thời gian chạy TensorFlow với mỗi lần gọi.
@tf.function
def f(x):
print("Traced with", x)
tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1 Executed with 1 Executed with 1 Traced with 2 Executed with 2
Nếu bạn muốn thực thi mã Python trong mỗi lần gọi một Function
, tf.py_function
Chức năng là một lối thoát. Hạn chế của tf.py_function
là nó không di động hoặc đặc biệt hiệu quả, không thể lưu bằng SavedModel và không hoạt động tốt trong các thiết lập phân tán (đa GPU, TPU). Ngoài ra, vì tf.py_function
phải được kết nối với biểu đồ, nó chuyển tất cả các đầu vào / đầu ra thành tensor.
Thay đổi các biến toàn cầu và miễn phí của Python
Việc thay đổi các biến toàn cầu và miễn phí của Python được coi là một tác dụng phụ của Python, vì vậy nó chỉ xảy ra trong quá trình truy tìm.
external_list = []
@tf.function
def side_effect(x):
print('Python side effect')
external_list.append(x)
side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect
Đôi khi những hành vi bất ngờ rất khó nhận thấy. Trong ví dụ dưới đây, bộ counter
được dùng để bảo vệ số gia của một biến. Tuy nhiên, vì nó là một số nguyên python và không phải là một đối tượng TensorFlow, nên giá trị của nó được ghi lại trong lần theo dõi đầu tiên. Khi hàm tf.function
được sử dụng, assign_add
sẽ được ghi lại vô điều kiện trong biểu đồ bên dưới. Do đó v
sẽ tăng 1, mỗi khi hàm tf.function
được gọi. Sự cố này thường xảy ra ở những người dùng cố gắng di chuyển mã Tensorflow ở chế độ Grpah của họ sang Tensorflow 2 bằng cách sử dụng trình trang trí tf.function
, khi các tác dụng phụ của python (bộ counter
trong ví dụ) được sử dụng để xác định các hoạt động sẽ chạy (trong ví dụ là assign_add
). Thông thường, người dùng chỉ nhận ra điều này sau khi nhìn thấy kết quả số đáng ngờ hoặc hiệu suất thấp hơn đáng kể so với mong đợi (ví dụ: nếu hoạt động được bảo vệ rất tốn kém).
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# A python side-effect
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 2, 3
1 2 3
Một giải pháp để đạt được hành vi mong đợi là sử dụng tf.init_scope
để nâng các hoạt động bên ngoài đồ thị hàm số. Điều này đảm bảo rằng việc tăng biến chỉ được thực hiện một lần trong thời gian truy tìm. Cần lưu ý rằng init_scope
có các tác dụng phụ khác bao gồm luồng điều khiển bị xóa và băng gradient. Đôi khi việc sử dụng init_scope
có thể trở nên quá phức tạp để quản lý một cách thực tế.
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# Lifts ops out of function-building graphs
with tf.init_scope():
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 1, 1
1 1 1
Tóm lại, theo nguyên tắc chung, bạn nên tránh làm thay đổi các đối tượng python như số nguyên hoặc vùng chứa như danh sách nằm bên ngoài Function
. Thay vào đó, hãy sử dụng các đối số và đối tượng TF. Ví dụ: phần "Tích lũy các giá trị trong một vòng lặp" có một ví dụ về cách các hoạt động giống danh sách có thể được triển khai.
Trong một số trường hợp, bạn có thể nắm bắt và thao tác trạng thái nếu nó là tf.Variable
. Đây là cách các trọng số của các mô hình Keras được cập nhật với các lệnh gọi lặp lại đến cùng một ConcreteFunction
.
Sử dụng trình tạo và trình tạo Python
Nhiều tính năng của Python, chẳng hạn như trình tạo và trình vòng lặp, dựa vào thời gian chạy Python để theo dõi trạng thái. Nói chung, trong khi các cấu trúc này hoạt động như mong đợi ở chế độ háo hức, chúng là ví dụ về các tác dụng phụ của Python và do đó chỉ xảy ra trong quá trình truy tìm.
@tf.function
def buggy_consume_next(iterator):
tf.print("Value:", next(iterator))
iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1 Value: 1 Value: 1
Cũng giống như cách TensorFlow có tf.TensorArray
chuyên biệt cho các cấu trúc danh sách, nó có tf.data.Iterator
chuyên biệt cho các cấu trúc lặp. Xem phần về các phép biến đổi AutoGraph để biết tổng quan. Ngoài ra, API tf.data
có thể giúp triển khai các mẫu trình tạo:
@tf.function
def good_consume_next(iterator):
# This is ok, iterator is a tf.data.Iterator
tf.print("Value:", next(iterator))
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1 Value: 2 Value: 3
Tất cả các đầu ra của hàm tf. phải là giá trị trả về
Ngoại trừ tf.Variable
s, một hàm tf. phải trả về tất cả các đầu ra của nó. Việc cố gắng truy cập trực tiếp vào bất kỳ tensor nào từ một hàm mà không đi qua các giá trị trả về gây ra "rò rỉ".
Ví dụ: hàm bên dưới "rò rỉ" tensor a
qua Python global x
:
x = None
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return a + 2
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
3 'Tensor' object has no attribute 'numpy'
Điều này đúng ngay cả khi giá trị bị rò rỉ cũng được trả về:
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
@tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b
with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
2 'Tensor' object has no attribute 'numpy' Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/566849597.py", line 21, in <module> captures_leaked_tensor(tf.constant(2)) TypeError: Originated from a graph execution error. The graph execution error is detected at a node built at (most recent call last): >>> File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main >>> File /usr/lib/python3.7/runpy.py, line 85, in _run_code >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module> >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start >>> File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever >>> File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once >>> File /usr/lib/python3.7/asyncio/events.py, line 88, in _run >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code >>> File /tmp/ipykernel_26244/566849597.py, line 7, in <module> >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__ >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler >>> File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2 >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__ Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
Thông thường, những rò rỉ như thế này xảy ra khi bạn sử dụng các câu lệnh hoặc cấu trúc dữ liệu Python. Ngoài việc rò rỉ các tensor không truy cập được, các câu lệnh như vậy cũng có khả năng sai vì chúng được tính là tác dụng phụ của Python và không được đảm bảo thực thi ở mọi lệnh gọi hàm.
Các cách phổ biến để làm rò rỉ các tensors cục bộ cũng bao gồm việc thay đổi một bộ sưu tập Python bên ngoài hoặc một đối tượng:
class MyClass:
def __init__(self):
self.field = None
external_list = []
external_object = MyClass()
def leaky_function():
a = tf.constant(1)
external_list.append(a) # Bad - leaks tensor
external_object.field = a # Bad - leaks tensor
Các hàm tf.funcive đệ quy không được hỗ trợ
Function
đệ quy không được hỗ trợ và có thể gây ra vòng lặp vô hạn. Ví dụ,
@tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1
with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Caught expected exception <class 'Exception'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module> recursive_fn(tf.constant(5)) # Bad - maximum recursion error. tensorflow.python.autograph.impl.api.StagingError: in user code: File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn * if n > 0: File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__ return _abc_instancecheck(cls, instance) RecursionError: maximum recursion depth exceeded while calling a Python object
Ngay cả khi một Function
đệ quy có vẻ hoạt động, hàm python sẽ được theo dõi nhiều lần và có thể có hàm ý về hiệu suất. Ví dụ,
@tf.function
def recursive_fn(n):
if n > 0:
print('tracing')
return recursive_fn(n - 1)
else:
return 1
recursive_fn(5) # Warning - multiple tracings
tracing tracing tracing tracing tracing <tf.Tensor: shape=(), dtype=int32, numpy=1>
Các vấn đề đã biết
Nếu Function
của bạn đánh giá không chính xác, lỗi có thể được giải thích bởi các vấn đề đã biết này được lên kế hoạch khắc phục trong tương lai.
Tùy thuộc vào các biến toàn cầu và miễn phí của Python
Function
tạo một ConcreteFunction
mới khi được gọi với giá trị mới của một đối số Python. Tuy nhiên, nó không làm điều đó đối với bao đóng Python, toàn cầu hoặc không định vị của Function
đó. Nếu giá trị của chúng thay đổi giữa các lần gọi đến Function
, thì Function
sẽ vẫn sử dụng các giá trị mà chúng đã có khi nó được truy tìm. Điều này khác với cách các hàm Python thông thường hoạt động.
Vì lý do đó, bạn nên làm theo phong cách lập trình hàm sử dụng các đối số thay vì đóng trên các tên bên ngoài.
@tf.function
def buggy_add():
return 1 + foo
@tf.function
def recommended_add(foo):
return 1 + foo
foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add()) # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100! Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(101, shape=(), dtype=int32)
Một cách khác để cập nhật giá trị toàn cục là đặt nó thành tf.Variable
và thay vào đó sử dụng phương thức Variable.assign
.
@tf.function
def variable_add():
return 1 + foo
foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100! Variable: tf.Tensor(101, shape=(), dtype=int32)
Tùy thuộc vào các đối tượng Python
Khuyến nghị chuyển các đối tượng Python dưới dạng đối số vào tf.function
có một số vấn đề đã biết, dự kiến sẽ được khắc phục trong tương lai. Nói chung, bạn có thể dựa vào khả năng theo dõi nhất quán nếu bạn sử dụng cấu trúc nguyên thủy Python hoặc cấu trúc tương thích tf.nest
làm đối số hoặc chuyển trong một phiên bản khác của đối tượng vào một Function
. Tuy nhiên, Function
sẽ không tạo một dấu vết mới khi bạn truyền cùng một đối tượng và chỉ thay đổi các thuộc tính của nó .
class SimpleModel(tf.Module):
def __init__(self):
# These values are *not* tf.Variables.
self.bias = 0.
self.weight = 2.
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias! tf.Tensor(20.0, shape=(), dtype=float32)
Việc sử dụng cùng một Function
để đánh giá phiên bản cập nhật của mô hình sẽ có lỗi vì mô hình được cập nhật có khóa bộ nhớ cache giống với mô hình ban đầu.
Vì lý do đó, bạn nên viết Function
của mình để tránh phụ thuộc vào thuộc tính đối tượng có thể thay đổi hoặc tạo đối tượng mới.
Nếu không thể, một cách giải quyết là tạo các Function
mới mỗi khi bạn sửa đổi đối tượng của mình để buộc thực hiện lại:
def evaluate(model, x):
return model.weight * x + model.bias
new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
Vì việc thử lại có thể tốn kém , bạn có thể sử dụng tf.Variable
s làm thuộc tính đối tượng, có thể bị thay đổi (nhưng không được thay đổi, cẩn thận!) Để có hiệu ứng tương tự mà không cần truy xuất lại.
class BetterModel:
def __init__(self):
self.bias = tf.Variable(0.)
self.weight = tf.Variable(2.)
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
Tạo tf.Variables
Function
chỉ hỗ trợ singleton tf.Variable
được tạo một lần trong lần gọi đầu tiên và được sử dụng lại trong các lần gọi hàm tiếp theo. Đoạn mã bên dưới sẽ tạo một tf.Variable
mới trong mọi lệnh gọi hàm, dẫn đến ngoại lệ ValueError
.
Thí dụ:
@tf.function
def f(x):
v = tf.Variable(1.0)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module> f(1.0) ValueError: in user code: File "/tmp/ipykernel_26244/3018268426.py", line 3, in f * v = tf.Variable(1.0) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
Một mẫu phổ biến được sử dụng để khắc phục hạn chế này là bắt đầu bằng giá trị Không có trong Python, sau đó tạo tf.Variable
có điều kiện nếu giá trị là Không:
class Count(tf.Module):
def __init__(self):
self.count = None
@tf.function
def __call__(self):
if self.count is None:
self.count = tf.Variable(0)
return self.count.assign_add(1)
c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32)
Sử dụng với nhiều trình tối ưu hóa Keras
Bạn có thể gặp phải ValueError: tf.function only supports singleton tf.Variables created on the first call.
khi sử dụng nhiều hơn một trình tối ưu hóa Keras có tf.function
. Lỗi này xảy ra do trình tối ưu hóa tạo tf.Variables
bên trong khi chúng áp dụng gradient lần đầu tiên.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
train_step(w, x, y, opt2)
Calling `train_step` with different optimizer... Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module> train_step(w, x, y, opt2) ValueError: in user code: File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step * optimizer.apply_gradients(zip(gradients, [w])) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients ** self._create_all_weights(var_list) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights _ = self.iterations File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__ return super(OptimizerV2, self).__getattribute__(name) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight aggregation=aggregation) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable shape=variable_shape if variable_shape else None) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
Nếu bạn cần thay đổi trình tối ưu hóa trong quá trình đào tạo, cách giải quyết là tạo một Function
mới cho mỗi trình tối ưu hóa, gọi trực tiếp ConcreteFunction
.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
# Not a tf.function.
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
if i % 2 == 0:
train_step_1(w, x, y) # `opt1` is not used as a parameter.
else:
train_step_2(w, x, y) # `opt2` is not used as a parameter.
Sử dụng với nhiều mô hình Keras
Bạn cũng có thể gặp phải ValueError: tf.function only supports singleton tf.Variables created on the first call.
khi chuyển các thể hiện mô hình khác nhau cho cùng một Function
.
Lỗi này xảy ra do các mô hình Keras ( không có hình dạng đầu vào được xác định ) và các lớp Keras tạo tf.Variables
khi chúng được gọi lần đầu tiên. Bạn có thể đang cố gắng khởi tạo các biến đó bên trong một Function
, đã được gọi. Để tránh lỗi này, hãy thử gọi model.build(input_shape)
để khởi tạo tất cả các trọng số trước khi huấn luyện mô hình.
đọc thêm
Để tìm hiểu về cách xuất và tải một Function
, hãy xem hướng dẫn SavedModel . Để tìm hiểu thêm về các tối ưu hóa biểu đồ được thực hiện sau khi theo dõi, hãy xem hướng dẫn Grappler . Để tìm hiểu cách tối ưu hóa đường ống dữ liệu và lập hồ sơ cho mô hình của bạn, hãy xem hướng dẫn về Hồ sơ.