The editor would like to share with you what you need to pay attention to when using a custom loss function in TensorFlow 2. Most people are not very familiar with this topic, so this article is shared for your reference; I hope you learn a lot after reading it. Let's get to it!
A core principle of Keras is the progressive disclosure of complexity: you keep the convenience of the high-level API while gaining finer control over the details of training when you need it. When we want to customize the training algorithm used by fit, we can override the model's train_step method and then call fit as usual to train the model.
Here is an example from the official TensorFlow 2 documentation:
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Toy data
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))

# Fix the random seed so results are reproducible
tf.random.set_seed(1000)

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss=tf.losses.MSE, metrics=["mae"])

# Just use `fit` as usual
model.fit(x, y, epochs=1, shuffle=False)

32/32 [==============================] - 0s 1ms/step - loss: 0.2783 - mae: 0.4257
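One detail worth keeping in mind about the loss used above: tf.losses.MSE returns one value per sample (it reduces over the last axis only), and fit() then averages those values across the batch. A quick shape check, using small toy tensors of my own rather than anything from the article:

# Toy tensors (illustrative only): two samples with two features each.
batch_true = tf.constant([[1.0, 2.0],
                          [3.0, 4.0]])
batch_pred = tf.constant([[1.5, 2.0],
                          [3.0, 3.0]])
per_sample = tf.losses.MSE(batch_true, batch_pred)
print(per_sample)  # shape (2,): [0.125, 0.5] -- one value per sample, not a scalar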
The loss used here is the MSE implemented in the TensorFlow library. If we want to use a custom loss function instead and pass it to model.compile, will it work the way we expect?
It turns out the answer is no. There is no error message, but the loss calculation will not match our expectations.
def custom_mse(y_true, y_pred):
    return tf.reduce_mean((y_true - y_pred) ** 2, axis=-1)

a_true = tf.constant([1.1, 1.5, 1.2])
a_pred = tf.constant([1.1, 2, 1.5])
custom_mse(a_true, a_pred)
tf.losses.MSE(a_true, a_pred)
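Printing the two calls makes the check explicit; the expected value below is computed by hand from the constants above ((0² + 0.5² + 0.3²) / 3 ≈ 0.1133):

print(custom_mse(a_true, a_pred))     # ~0.1133
print(tf.losses.MSE(a_true, a_pred))  # ~0.1133 -- identical to our custom version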
The results above confirm that our custom loss is correct. Now let's pass the custom loss directly to the loss parameter of compile and see what happens.
my_model = CustomModel(inputs, outputs)
my_model.compile(optimizer="adam", loss=custom_mse, metrics=["mae"])
my_model.fit(x, y, epochs=1, shuffle=False)

32/32 [==============================] - 0s 820us/step - loss: 0.1628 - mae: 0.3257
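If you want to read the reported number back programmatically rather than from the progress bar, the History object returned by fit exposes it. A minimal sketch, reusing my_model from above (the exact numbers will vary between runs):

history = my_model.fit(x, y, epochs=1, shuffle=False, verbose=0)
print(history.history["loss"])  # the epoch-level loss value that fit() logged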
The loss reported here is clearly different from both our own result and the one obtained with the standard tf.losses.MSE. This shows that passing our custom loss directly into model.compile in this way does not work as intended.
So what is the correct way to use a custom loss? Here is the answer.
loss_tracker = keras.metrics.Mean(name="loss")
mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

# Fix the random seed so results are reproducible
tf.random.set_seed(100)

class MyCustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss value (no loss is configured in `compile()`)
            loss = custom_mse(y, y_pred)
            # Optionally add regularization losses, e.g. loss += tf.add_n(self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Compute our own metrics
        loss_tracker.update_state(loss)
        mae_metric.update_state(y, y_pred)
        return {"loss": loss_tracker.result(), "mae": mae_metric.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [loss_tracker, mae_metric]

# Construct and compile an instance of MyCustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
my_model_beta = MyCustomModel(inputs, outputs)
my_model_beta.compile(optimizer="adam")

# Just use `fit` as usual
my_model_beta.fit(x, y, epochs=1, shuffle=False)

32/32 [==============================] - 0s 960us/step - loss: 0.2783 - mae: 0.4257
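One caveat with this approach: since compile() no longer receives a loss or metrics, evaluate() has nothing to compute by default. A sketch of a matching test_step override in the same style as train_step (this method is not in the original article; it follows the same pattern from the Keras documentation and would be added inside MyCustomModel):

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)  # Forward pass in inference mode
        # Reuse the same custom loss and metric objects as in train_step
        loss_tracker.update_state(custom_mse(y, y_pred))
        mae_metric.update_state(y, y_pred)
        return {"loss": loss_tracker.result(), "mae": mae_metric.result()}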
Finally, by skipping the loss argument in compile() and doing all of the calculation manually in train_step, we get exactly the same output as the earlier run with the default tf.losses.MSE, which is what we want.
That is all of the content of this article on what to pay attention to when using a custom loss function in TensorFlow 2. Thank you for reading! I hope the shared content has been helpful; if you want to learn more, you are welcome to follow the industry information channel!