Source code for gradgraph.optimization.tf.layers

#!/usr/bin/env python3
# 
# layers.py
# 
# Created by Nicolas Fricker on 08/22/2025.
# 

from __future__ import annotations

import tensorflow as tf

@tf.keras.utils.register_keras_serializable()
class Embedding(tf.keras.layers.Embedding):
    """
    Embedding layer derived from `tensorflow.keras.layers.Embedding`.

    Integer indices are mapped to dense vectors of a fixed size, exactly as
    in the parent class. The only addition is that, once the layer is built,
    the underlying embedding table is stamped with a few of the layer's own
    attributes (see `build`).

    Parameters
    ----------
    input_dim : int
        Size of the vocabulary, i.e., maximum integer index + 1.
    output_dim : int
        Dimension of the dense embedding.
    *args : tuple
        Additional positional arguments forwarded to the parent class.
    **kwargs : dict
        Additional keyword arguments forwarded to the parent class.

    Methods
    -------
    build(input_shape)
        Builds the parent layer, then copies the layer's `trainable` flag,
        its `_is_local` marker (``False`` when unset) and its name onto the
        embedding table.

    Examples
    --------
    >>> import tensorflow as tf
    >>> embedding_layer = Embedding(input_dim=1000, output_dim=64)
    >>> input_data = tf.constant([[1, 2, 3], [4, 5, 6]])
    >>> output = embedding_layer(input_data)
    >>> output.shape
    TensorShape([2, 3, 64])
    """

    def __init__(self, input_dim, output_dim, *args, **kwargs):
        # Pure pass-through: every argument goes to the Keras constructor.
        super().__init__(input_dim, output_dim, *args, **kwargs)

    def build(self, input_shape):
        """
        Build the parent layer and synchronize the embedding table with it.

        Parameters
        ----------
        input_shape : object
            Shape description forwarded unchanged to the parent `build`.
        """
        super().build(input_shape)

        # The table lives in a private attribute of the parent layer; when
        # that attribute is absent there is nothing to synchronize.
        if not hasattr(self, '_embeddings'):
            return

        table = self._embeddings
        table.trainable = self.trainable
        # Layers that were never tagged count as global (not local).
        table._is_local = getattr(self, '_is_local', False)
        # NOTE(review): assumes the variable object accepts assignment to
        # `name` -- confirm against the Keras version in use.
        table.name = self.name
class BasePDESystemLayer(tf.keras.layers.Layer):
    """
    Base Keras layer that tags its weights as either "local" or "global".

    Weights and embeddings created through the ``add_local_*`` /
    ``add_global_*`` helpers receive a boolean ``_is_local`` attribute. The
    ``local_*`` and ``global_*`` properties then filter the layer's weight
    lists on that attribute; anything without the attribute counts as global.

    `body`, `cond`, `call` and `compute_output_shape` raise
    `NotImplementedError` here and are meant to be supplied by subclasses.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the Keras `Layer` constructor.
        super().__init__(*args, **kwargs)

    @staticmethod
    def _marked_local(variable) -> bool:
        # A missing `_is_local` attribute is treated as "not local".
        return getattr(variable, '_is_local', False)

    def add_local_weight(self, name, **kwargs):
        """Create a weight via `add_weight` and tag it as local."""
        weight = self.add_weight(name=name, **kwargs)
        weight._is_local = True
        return weight

    def add_global_weight(self, name, **kwargs):
        """Create a weight via `add_weight` and tag it as global."""
        weight = self.add_weight(name=name, **kwargs)
        weight._is_local = False
        return weight

    def add_local_embedding(self, name, **kwargs):
        """Instantiate an `Embedding` layer and tag it as local."""
        embedding = Embedding(name=name, **kwargs)
        embedding._is_local = True
        return embedding

    def add_global_embedding(self, name, **kwargs):
        """Instantiate an `Embedding` layer and tag it as global."""
        embedding = Embedding(name=name, **kwargs)
        embedding._is_local = False
        return embedding

    @property
    def global_weights(self):
        """Entries of `weights` that are not tagged local."""
        return [w for w in self.weights if not self._marked_local(w)]

    @property
    def global_trainable_weights(self):
        """Entries of `trainable_weights` that are not tagged local."""
        return [
            w for w in self.trainable_weights if not self._marked_local(w)
        ]

    @property
    def global_non_trainable_weights(self):
        """Entries of `non_trainable_weights` that are not tagged local."""
        return [
            w for w in self.non_trainable_weights if not self._marked_local(w)
        ]

    @property
    def local_weights(self):
        """Entries of `weights` that are tagged local."""
        return [w for w in self.weights if self._marked_local(w)]

    @property
    def local_trainable_weights(self):
        """Entries of `trainable_weights` that are tagged local."""
        return [w for w in self.trainable_weights if self._marked_local(w)]

    @property
    def local_non_trainable_weights(self):
        """Entries of `non_trainable_weights` that are tagged local."""
        return [
            w for w in self.non_trainable_weights if self._marked_local(w)
        ]

    def build(self, input_shape) -> None:
        """
        Delegate to the parent `build`.

        Template kept from the original author for subclasses that override
        this method:

        * optionally return early when ``self.built`` is already set;
        * create weights with the ``add_(local|global)_(weight|embedding)``
          helpers;
        * build them, e.g.::

              ids = input_shape[0]
              for v in self.local_weights:
                  _ = v.build(ids)
        """
        super().build(input_shape)

    def body(self, *args, **kwargs):
        """Loop body hook; must be provided by a subclass."""
        raise NotImplementedError()

    def cond(self, *args, **kwargs):
        """Loop condition hook; must be provided by a subclass."""
        raise NotImplementedError()

    def call(self, inputs, training=False):
        """
        Forward pass; must be provided by a subclass.

        Original author's hint for implementations:
        ``tf.while_loop(self.cond, self.body, loop_vars=...,
        parallel_iterations=1)``.
        """
        raise NotImplementedError()

    def compute_output_shape(self, input_shape):
        """Output shape inference; must be provided by a subclass."""
        raise NotImplementedError()

    def get_config(self):
        """Return the parent layer's config unchanged."""
        config = super().get_config()
        return config