Warning: This project is deprecated. Swift for TensorFlow was an experiment in the
next-generation platform for machine learning, incorporating the latest research across
machine learning, compilers, differentiable programming, systems design, and beyond. It was
archived in February 2021.
Embedding
An embedding layer.
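Embedding is effectively a lookup table that maps indices from a fixed vocabulary to fixed-size (dense) vector representations, e.g. [[0], [3]] -> [[[0.25, 0.1]], [[0.6, -0.2]]]. Each index is replaced by a vector, so the output has the input's shape with embeddingSize appended.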
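The lookup can be illustrated without the (archived) TensorFlow module. The following is a minimal plain-Swift sketch, not the library's implementation: the table is modeled as a hypothetical [[Float]] array with one row per vocabulary index, and the helper name lookup(_:in:) is invented for illustration. The table values are made up to reproduce the example output.

```swift
// Plain-Swift sketch of an embedding lookup (no TensorFlow dependency).
// `EmbeddingTable` and `lookup(_:in:)` are hypothetical names for illustration.
typealias EmbeddingTable = [[Float]]  // shape: (vocabularySize, embeddingSize)

/// Replaces each index with the corresponding row of `table`.
/// Input shape [batch, n] produces output shape [batch, n, embeddingSize].
func lookup(_ indices: [[Int]], in table: EmbeddingTable) -> [[[Float]]] {
    return indices.map { row in
        row.map { (index: Int) -> [Float] in
            precondition(index >= 0 && index < table.count, "index \(index) out of range")
            return table[index]
        }
    }
}

// vocabularySize = 4, embeddingSize = 2 (hypothetical values).
let table: EmbeddingTable = [
    [0.25, 0.1],  // index 0
    [0.0, 0.0],   // index 1
    [0.0, 0.0],   // index 2
    [0.6, -0.2],  // index 3
]

let output = lookup([[0], [3]], in: table)
print(output)  // [[[0.25, 0.1]], [[0.6, -0.2]]]
```

Note that the output keeps the nesting of the input: each scalar index becomes a length-2 vector, so a [2, 1] input yields a [2, 1, 2] result.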
-
A learnable lookup table that maps vocabulary indices to their dense vector representations.
Declaration
public var embeddings: Tensor<Scalar>
-
Creates an Embedding layer with randomly initialized embeddings of shape (vocabularySize, embeddingSize) so that each vocabulary index is given a vector representation.
Declaration
public init(
vocabularySize: Int,
embeddingSize: Int,
embeddingsInitializer: ParameterInitializer<Scalar> = { Tensor(randomUniform: $0) }
)
Parameters
vocabularySize | The number of distinct indices (words) in the vocabulary. This number should be the largest integer index plus one.
embeddingSize | The number of entries in a single embedding vector representation.
embeddingsInitializer | Initializer to use for the embedding parameters.
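As a rough illustration of the sizing rule and the default random initialization, here is a plain-Swift sketch. The helper makeRandomTable is hypothetical, and drawing values uniformly from [0, 1) is an assumption about what the default { Tensor(randomUniform: $0) } initializer produces; it only mimics that behavior on ordinary arrays.

```swift
// Hypothetical helper: builds a (vocabularySize, embeddingSize) table whose
// entries are drawn uniformly at random. The [0, 1) range is an assumption.
func makeRandomTable(vocabularySize: Int, embeddingSize: Int) -> [[Float]] {
    precondition(vocabularySize > 0 && embeddingSize > 0)
    return (0..<vocabularySize).map { _ in
        (0..<embeddingSize).map { _ in Float.random(in: 0..<1) }
    }
}

// The vocabulary must cover the largest index that can appear in the input,
// so vocabularySize is that index plus one.
let corpusIndices = [3, 0, 7, 2, 7]
let vocabularySize = corpusIndices.max()! + 1  // 8
let table = makeRandomTable(vocabularySize: vocabularySize, embeddingSize: 4)
print(table.count, table[0].count)  // 8 4
```

Choosing vocabularySize smaller than the largest index plus one would leave some indices without a row, which is why the lookup needs every index to fall in 0..<vocabularySize.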
-
Creates an Embedding layer from the provided embeddings. This is useful for introducing pretrained embeddings into a model.
Declaration
public init(embeddings: Tensor<Scalar>)
Parameters
embeddings | The pretrained embeddings table, of shape (vocabularySize, embeddingSize).
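Pretrained vectors usually arrive keyed by word, so they must first be arranged into a rectangular table with a stable word-to-index mapping. The sketch below does this in plain Swift; the data and the buildTable(from:) helper are hypothetical. In Swift for TensorFlow the resulting rows would then, presumably, be flattened into a Tensor<Float> of shape (vocabularySize, embeddingSize) and passed to init(embeddings:); that step is not shown here.

```swift
// Hypothetical pretrained vectors (e.g. parsed from a word-vector file).
let pretrained: [String: [Float]] = [
    "cat": [0.1, 0.3, -0.2],
    "dog": [0.2, 0.1, -0.1],
    "fish": [-0.4, 0.0, 0.5],
]

/// Builds a (vocabularySize, embeddingSize) table plus a word -> index map.
/// Sorting the words makes the index assignment deterministic.
func buildTable(from vectors: [String: [Float]]) -> (vocabulary: [String: Int], table: [[Float]]) {
    let words = vectors.keys.sorted()
    guard let embeddingSize = words.first.flatMap({ vectors[$0]?.count }) else {
        return ([:], [])
    }
    var vocabulary: [String: Int] = [:]
    var table: [[Float]] = []
    for (index, word) in words.enumerated() {
        let vector = vectors[word]!
        // Every row must have the same length for the table to be rectangular.
        precondition(vector.count == embeddingSize, "ragged vector for \(word)")
        vocabulary[word] = index
        table.append(vector)
    }
    return (vocabulary, table)
}

let (wordToIndex, pretrainedTable) = buildTable(from: pretrained)
print(wordToIndex["dog"]!, pretrainedTable[wordToIndex["dog"]!])  // 1 [0.2, 0.1, -0.1]
```

The same word-to-index map must be used when converting text to indices for the forward pass; otherwise indices would point at the wrong rows.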
-
Returns an output by replacing each index in the input with its corresponding dense vector representation.
Declaration
@differentiable(wrt: self)
public func forward(_ input: Tensor<Int32>) -> Tensor<Scalar>
Return Value
The tensor created by replacing input indices with their vector representations.
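Because forward is marked @differentiable(wrt: self), gradients flow into the embeddings table but not into the integer indices. Mathematically, the gradient of a lookup with respect to the table is a scatter-add: each upstream gradient is added to the row it was read from, rows never looked up receive zero, and repeated indices accumulate. The plain-Swift sketch below illustrates that rule with a hypothetical embeddingGradient helper; it is not how the library computes the derivative internally.

```swift
/// Sketch of the gradient of an embedding lookup with respect to the table.
/// `upstream[i]` is the gradient flowing into the i-th looked-up vector.
func embeddingGradient(indices: [Int], upstream: [[Float]],
                       vocabularySize: Int, embeddingSize: Int) -> [[Float]] {
    precondition(indices.count == upstream.count)
    // Start from zeros: rows that are never looked up get no gradient.
    var grad = Array(repeating: Array(repeating: Float(0), count: embeddingSize),
                     count: vocabularySize)
    for (index, g) in zip(indices, upstream) {
        precondition(g.count == embeddingSize)
        // Scatter-add: repeated indices accumulate their upstream gradients.
        for j in 0..<embeddingSize {
            grad[index][j] += g[j]
        }
    }
    return grad
}

// Index 2 is looked up twice, so its row is the sum of both upstream rows.
let grad = embeddingGradient(
    indices: [2, 0, 2],
    upstream: [[1, 1], [0.5, 0.5], [2, -1]],
    vocabularySize: 4,
    embeddingSize: 2
)
print(grad)  // [[0.5, 0.5], [0.0, 0.0], [3.0, 0.0], [0.0, 0.0]]
```

This sparsity is why an optimizer step only changes the rows of the table for indices that actually appeared in the batch.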
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.
Last updated 2021-09-28 UTC.