How To Get String Value Out Of tf.Tensor Whose Dtype Is String
Solution 1:
You can use tf.py_func to wrap load_audio_file().
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x only; eager execution is the default in TF 2

def load_audio_file(file_path):
    # Inside tf.py_func the argument arrives as bytes,
    # so decode it to get a Python str
    path_str = file_path.decode("utf-8")
    print("file_path:", path_str, type(path_str))
    return file_path

train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
# Wrap the Python function as a TF op; [tf.string] is the list of output dtypes
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))
for one_element in train_dataset:
    print(one_element)
Output (the order may vary, because list_files shuffles the file names by default):
file_path: clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path: clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path: clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)
UPDATE for TF 2
The above solution will not work with TF 2 (tested with 2.2.0), even when replacing tf.py_func with tf.py_function; it fails with:
InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'
To make it work in TF 2, make the following changes:
- Remove tf.enable_eager_execution() (eager execution is enabled by default in TF 2, which you can verify with tf.executing_eagerly() returning True)
- Replace tf.py_func with tf.py_function
- Replace all in-function references of file_path with file_path.numpy() (tf.py_function passes its inputs as EagerTensor objects, and .numpy() returns the underlying bytes)
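Applying those three changes gives roughly the following TF 2 version. It is a sketch: so that it runs without the question's clean_4s_val directory, it builds the dataset from a hypothetical hard-coded list of paths with tf.data.Dataset.from_tensor_slices instead of list_files, and it also collects the decoded names into a list called results (also an addition for illustration).

```python
import tensorflow as tf  # TF 2.x: eager execution is on by default

def load_audio_file(file_path):
    # Under tf.py_function the argument is an EagerTensor, so .numpy()
    # yields the raw bytes, which are then decoded to a Python str.
    path_str = bytes.decode(file_path.numpy())
    print("file_path:", path_str, type(path_str))
    return file_path.numpy()

# Hypothetical paths; with real files on disk you would use
# tf.data.Dataset.list_files('clean_4s_val/*.wav') as in the answer above.
paths = ["clean_4s_val/1.wav", "clean_4s_val/2.wav", "clean_4s_val/3.wav"]
train_dataset = tf.data.Dataset.from_tensor_slices(paths)
train_dataset = train_dataset.map(
    lambda x: tf.py_function(load_audio_file, [x], [tf.string]))

# Tout was a list, so each element is a 1-tuple holding a scalar string tensor.
results = []
for one_element in train_dataset:
    print(one_element)
    results.append(one_element[0].numpy().decode("utf-8"))
```

Unlike list_files, from_tensor_slices does not shuffle, so the elements come out in the order of the paths list.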
Solution 2:
If you want to do something completely custom, wrap your code in tf.py_function. Keep in mind that this generally performs worse than native TensorFlow ops: the wrapped code runs as ordinary Python and must hold the Python GIL, which limits parallelism. See the documentation and examples here:
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
On the other hand, if you are doing something generic, you don't need to wrap your code in py_function; use the functions provided in the tf.strings module instead. They operate directly on string tensors and cover many common operations such as split, join and length (tf.strings.split, tf.strings.join, tf.strings.length). Because they run as native TensorFlow ops on the tensor and return a new tensor, they avoid the overhead of py_function.
See the tf.strings documentation here: https://www.tensorflow.org/api_docs/python/tf/strings
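As a sketch of the tf.strings approach applied to the question's file paths, the function below extracts the base file name, measures its length and builds a new path, all without py_function. The paths, the describe function and the "processed" prefix are hypothetical; everything stays a tensor until the final .numpy() calls.

```python
import tensorflow as tf

# Hypothetical file paths standing in for the question's .wav files.
paths = ["clean_4s_val/1.wav", "clean_4s_val/22.wav"]
ds = tf.data.Dataset.from_tensor_slices(paths)

def describe(path):
    # Native string ops: they run inside the tf.data graph, no Python callback.
    parts = tf.strings.split(path, sep="/")  # scalar input -> 1-D string tensor
    name = parts[-1]                         # last path component, e.g. b"1.wav"
    length = tf.strings.length(name)         # length in bytes (the default unit)
    new_path = tf.strings.join(["processed", name], separator="/")
    return name, length, new_path

results = []
for name, length, new_path in ds.map(describe):
    # Convert to Python values only at the end: .numpy() gives bytes or an int.
    results.append((name.numpy().decode("utf-8"), int(length.numpy()),
                    new_path.numpy().decode("utf-8")))
print(results)  # [('1.wav', 5, 'processed/1.wav'), ('22.wav', 6, 'processed/22.wav')]
```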
For example, say you want to extract the label from the file name; you could then write code like this:
ds.map(lambda x: tf.strings.split(x, sep='$')[1])
The above assumes that the label is the part of the file name that follows a $ separator.
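A self-contained version of that one-liner, assuming a hypothetical naming scheme "<id>$<label>" (the names below are made up for illustration). The split happens inside the dataset pipeline, and the tensors are turned into Python str values only at the end with .numpy().decode().

```python
import tensorflow as tf

# Hypothetical file names following the "<id>$<label>" convention.
names = ["0001$dog", "0002$cat", "0003$bird"]
ds = tf.data.Dataset.from_tensor_slices(names)

# Each element is a scalar string: split on '$' and keep the second piece.
labels_ds = ds.map(lambda x: tf.strings.split(x, sep="$")[1])

# Convert to Python str only at the end: .numpy() -> bytes -> str.
labels = [t.numpy().decode("utf-8") for t in labels_ds]
print(labels)  # ['dog', 'cat', 'bird']
```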