Efficiently reading multiple files in Tensorflow 2

Note: Whether this method is efficient or not is contestable. The efficiency of a data input pipeline depends on many factors: how efficiently the data are loaded, the computer architecture on which computations are done, whether a GPU is available, and so on. So readers might get different performance results when they apply this method to their own problems. For the simple (and small) problem considered in this post, we got no perceivable performance improvement. But for one personal application involving moderately sized data (3-4 GB), I achieved a 10x performance improvement. So I hope this method can be of help to others as well. The system on which we ran this notebook has 44 CPU cores; the Tensorflow version was 2.4.0 and we did not use any GPU. Please note that, for some weird reason, the speedup technique doesn't work in Google Colab, but I have checked that it works on GPU-enabled personal systems.

Update: Along with the method described in this post, readers should also try using a Tensorflow Sequence and see if it improves input pipeline efficiency. Define all the complex transformations inside the __getitem__ method of the Sequence class and then choose max_queue_size, workers, and use_multiprocessing in model.fit() suitably.
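
For reference, a minimal sketch of such a Sequence is given below. This is not the code used in the rest of this post; the class name, file-reading logic, and fit() arguments shown here are only illustrative.

import numpy as np
import pandas as pd
import tensorflow as tf

class FileSequence(tf.keras.utils.Sequence):
    # Reads one batch of .csv files per __getitem__ call; heavy transformations go here.
    def __init__(self, file_list, labels, batch_size = 32):
        self.file_list = file_list
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch
        return int(np.ceil(len(self.file_list) / self.batch_size))

    def __getitem__(self, idx):
        batch_files = self.file_list[idx * self.batch_size : (idx + 1) * self.batch_size]
        batch_labels = self.labels[idx * self.batch_size : (idx + 1) * self.batch_size]
        data = [pd.read_csv(f).values.reshape(-1) for f in batch_files]
        # Apply any complex transformations (STFT, normalization, etc.) here.
        return np.asarray(data, dtype = np.float32), np.asarray(batch_labels)

# model.fit(FileSequence(train_files, train_labels), epochs = 5,
#           max_queue_size = 10, workers = 4, use_multiprocessing = True)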


This post is a sequel to an older post. In the previous post, we discussed ways in which we can read multiple files in Tensorflow 2. If our aim is only to read files without doing any transformation on the data, that method may work well for most applications. But if we need to apply complex transformations to the data before training our deep learning algorithm, the old method might turn out to be slow. In this post, we describe a way to speed up that process. The transformations that we will consider are a spectrogram and normalization (converting each value to a standard normal value). We have chosen these transformations just to illustrate the point; readers can use any transformation (or no transformation) of their choice. More details on improving data performance can be found in this tensorflow guide.
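
To make the transformations concrete before diving in: each file holds a 1024-point series, the STFT (frame length 64, frame step 32) turns it into a (31, 33) spectrogram, and per-image standardization normalizes it. Here is a quick stand-alone check using the same calls that appear in the generator later:

import numpy as np
import tensorflow as tf

signal = tf.constant(np.random.rand(1024), dtype = tf.float32)
spectrogram = tf.math.abs(tf.signal.stft(signal, frame_length = 64, frame_step = 32, fft_length = 64))
print(spectrogram.shape)   # (31, 33)
normalized = tf.image.per_image_standardization(tf.reshape(spectrogram, (-1, 31, 33, 1)))
print(normalized.shape)    # (1, 31, 33, 1)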

As this post is a sequel, we expect readers to be familiar with the old post. We will not elaborate on points that have already been discussed. Rather, we will focus on section 4, which is the main topic of this post.

Outline:

  1. Create 500 ".csv" files and save them in the folder “random_data” in the current directory.
  2. Write a generator that reads data from the folder in chunks and transforms it.
  3. Build data pipeline and train a CNN model.
  4. How to make the code run faster?
  5. How to make predictions?

1. Create 500 .csv files of random data

As we intend to train a CNN model for classification using our data, we will generate data for 5 different classes. The process is as follows.

  • Each .csv file will have one column of data with 1024 entries.
  • Each file name will begin with one of the following class names: Fault_1, Fault_2, Fault_3, Fault_4, Fault_5. The dataset is balanced, meaning that each category has approximately the same number of observations. Data files in the “Fault_1” category will have the names “Fault_1_001.csv”, “Fault_1_002.csv”, “Fault_1_003.csv”, …, “Fault_1_100.csv”, and similarly for the other classes.
import numpy as np
import os
import glob
np.random.seed(1111)

First create a function that will generate random files.

def create_random_csv_files(fault_classes, number_of_files_in_each_class):
    os.mkdir("./random_data/")  # Make a directory to save created files.
    for fault_class in fault_classes:
        for i in range(number_of_files_in_each_class):
            data = np.random.rand(1024,)
            file_name = "./random_data/" + fault_class + "_" + "{0:03}".format(i+1) + ".csv"  # e.g. "./random_data/Fault_1_001.csv"
            np.savetxt(file_name, data, delimiter = ",", header = "V1", comments = "")
        print(str(number_of_files_in_each_class) + " " + fault_class + " files created.")

Now use the function to create 100 files each for five fault types.

create_random_csv_files(["Fault_1", "Fault_2", "Fault_3", "Fault_4", "Fault_5"], number_of_files_in_each_class = 100)
100 Fault_1 files created.
100 Fault_2 files created.
100 Fault_3 files created.
100 Fault_4 files created.
100 Fault_5 files created.
files = np.sort(glob.glob("./random_data/*"))
print("Total number of files: ", len(files))
print("Showing first 10 files...")
files[:10]
Total number of files:  500
Showing first 10 files...





array(['./random_data/Fault_1_001.csv', './random_data/Fault_1_002.csv',
       './random_data/Fault_1_003.csv', './random_data/Fault_1_004.csv',
       './random_data/Fault_1_005.csv', './random_data/Fault_1_006.csv',
       './random_data/Fault_1_007.csv', './random_data/Fault_1_008.csv',
       './random_data/Fault_1_009.csv', './random_data/Fault_1_010.csv'],
      dtype='<U29')

To extract labels from a file name, we take the slice of the file name that corresponds to the fault type.

print(files[0])
./random_data/Fault_1_001.csv
print(files[0][14:21])
Fault_1
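
The hard-coded slice works here only because every path has the same length. If your paths vary, extracting the class from the base name with a regular expression is more robust; a small sketch (not used in the rest of this post):

import os
import re

def extract_fault_class(path):
    # "./random_data/Fault_1_001.csv" -> "Fault_1"
    match = re.search(r"Fault_\d", os.path.basename(path))
    return match.group(0) if match else None

print(extract_fault_class(files[0]))   # Fault_1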

Now that the data have been created, we will go to the next step: define a generator that preprocesses the time-series-like data into a matrix-like shape that a 2-D CNN can ingest.

2. Write a generator that reads data in chunks and preprocesses it

These are the things that we want our generator to do.

  1. It should run indefinitely, i.e., it is an infinite loop.
  2. Inside generator loop, read individual files using pandas.
  3. Do transformations on data if required.
  4. Yield the data.

As we will be solving a classification problem, we have to assign a label to each raw data file. We will use the following labels for convenience.

Class Label
Fault_1 0
Fault_2 1
Fault_3 2
Fault_4 3
Fault_5 4
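
One convenient way to encode this table is a plain dictionary keyed on the class name. The generator below instead matches class names with re.match, which gives the same labels; the dictionary version is shown only as an illustration.

label_map = {"Fault_1": 0, "Fault_2": 1, "Fault_3": 2, "Fault_4": 3, "Fault_5": 4}
print(label_map["./random_data/Fault_3_017.csv"[14:21]])   # 2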

The generator will yield both data and labels. It takes a list of file names as its first argument and batch_size as its second.

import tensorflow as tf
print("Tensorflow Version: ", tf.__version__)
import pandas as pd
import re
Tensorflow Version:  2.4.0
def tf_data_generator(file_list, batch_size = 20):
    i = 0
    while True:    # This loop makes the generator an infinite loop
        if i*batch_size >= len(file_list):  
            i = 0
            np.random.shuffle(file_list)
        else:
            file_chunk = file_list[i*batch_size:(i+1)*batch_size] 
            data = []
            labels = []
            label_classes = tf.constant(["Fault_1", "Fault_2", "Fault_3", "Fault_4", "Fault_5"]) 
            for file in file_chunk:
                temp = pd.read_csv(open(file,'r')).astype(np.float32)    # Read data
                #########################################################################################################
                # Apply transformations. Comment this portion if you don't have to do any.
                # Try to use Tensorflow transformations as much as possible. First compute a spectrogram.
                temp = tf.math.abs(tf.signal.stft(tf.reshape(temp.values, shape = (1024,)),frame_length = 64, frame_step = 32, fft_length = 64))
                # After STFT transformation with given parameters, shape = (31,33)
                temp = tf.image.per_image_standardization(tf.reshape(temp, shape = (-1,31,33,1))) # Image Normalization
                ##########################################################################################################
                # temp = tf.reshape(temp, (32,32,1)) # Uncomment this line if you have not done any transformation.
                data.append(temp)
                pattern = file[14:21]   # e.g. b"Fault_1"; file paths arrive as bytes from tf.data
                for j in range(len(label_classes)):
                    if re.match(pattern, label_classes[j].numpy()):
                        labels.append(j)
            data = np.asarray(data).reshape(-1,31,33,1) 
            labels = np.asarray(labels)
            yield data, labels
            i = i + 1
batch_size = 15
dataset = tf.data.Dataset.from_generator(tf_data_generator,args= [files, batch_size],output_types = (tf.float32, tf.float32),
                                                output_shapes = ((None,31,33,1),(None,)))
for data, labels in dataset.take(7):
  print(data.shape)
  print(labels)
(15, 31, 33, 1)
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(15,), dtype=float32)
(15, 31, 33, 1)
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(15,), dtype=float32)
(15, 31, 33, 1)
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(15,), dtype=float32)
(15, 31, 33, 1)
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(15,), dtype=float32)
(15, 31, 33, 1)
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(15,), dtype=float32)
(15, 31, 33, 1)
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(15,), dtype=float32)
(15, 31, 33, 1)
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.], shape=(15,), dtype=float32)

The generator works fine. Now we will train a full CNN model using the generator. As is done for every model, we will first shuffle the data files and split them into train, validation, and test sets. Using tf_data_generator, we will create three tensorflow datasets corresponding to the train, validation, and test data respectively. Finally, we will create a simple CNN model, train it on the train dataset, check its performance on the validation dataset, and obtain predictions on the test dataset. Keep in mind that our aim is not to improve the performance of the model. As the data are random, don't expect to see good performance. The aim is only to create a pipeline.

3. Building data pipeline and training a CNN model

Before building the data pipeline, we will first move the files corresponding to each fault class into a different folder. This will make it convenient to split the data into training, validation, and test sets while keeping the balanced nature of the dataset intact.

import shutil

Create five different folders.

fault_folders = ["Fault_1", "Fault_2", "Fault_3", "Fault_4", "Fault_5"]
for folder_name in fault_folders:
    os.mkdir(os.path.join("./random_data", folder_name))

Move files into those folders.

for file in files:
    pattern = "^" + file[14:21]
    for j in range(len(fault_folders)):
        if re.match(pattern, fault_folders[j]):
            dest = os.path.join("./random_data/", fault_folders[j])
            shutil.move(file, dest)
glob.glob("./random_data/*")
['./random_data/Fault_1',
 './random_data/Fault_2',
 './random_data/Fault_3',
 './random_data/Fault_4',
 './random_data/Fault_5']
np.sort(glob.glob("./random_data/Fault_1/*"))[:10] # Showing first 10 files of Fault_1 folder
array(['./random_data/Fault_1/Fault_1_001.csv',
       './random_data/Fault_1/Fault_1_002.csv',
       './random_data/Fault_1/Fault_1_003.csv',
       './random_data/Fault_1/Fault_1_004.csv',
       './random_data/Fault_1/Fault_1_005.csv',
       './random_data/Fault_1/Fault_1_006.csv',
       './random_data/Fault_1/Fault_1_007.csv',
       './random_data/Fault_1/Fault_1_008.csv',
       './random_data/Fault_1/Fault_1_009.csv',
       './random_data/Fault_1/Fault_1_010.csv'], dtype='<U37')
np.sort(glob.glob("./random_data/Fault_3/*"))[:10] # Showing first 10 files of Fault_3 folder
array(['./random_data/Fault_3/Fault_3_001.csv',
       './random_data/Fault_3/Fault_3_002.csv',
       './random_data/Fault_3/Fault_3_003.csv',
       './random_data/Fault_3/Fault_3_004.csv',
       './random_data/Fault_3/Fault_3_005.csv',
       './random_data/Fault_3/Fault_3_006.csv',
       './random_data/Fault_3/Fault_3_007.csv',
       './random_data/Fault_3/Fault_3_008.csv',
       './random_data/Fault_3/Fault_3_009.csv',
       './random_data/Fault_3/Fault_3_010.csv'], dtype='<U37')

Prepare the data for the training set, validation set, and test set. For each fault type, we will keep 70 files for training, 10 files for validation, and 20 files for testing.

fault_1_files = glob.glob("./random_data/Fault_1/*")
fault_2_files = glob.glob("./random_data/Fault_2/*")
fault_3_files = glob.glob("./random_data/Fault_3/*")
fault_4_files = glob.glob("./random_data/Fault_4/*")
fault_5_files = glob.glob("./random_data/Fault_5/*")
from sklearn.model_selection import train_test_split
fault_1_train, fault_1_test = train_test_split(fault_1_files, test_size = 20, random_state = 5)
fault_2_train, fault_2_test = train_test_split(fault_2_files, test_size = 20, random_state = 54)
fault_3_train, fault_3_test = train_test_split(fault_3_files, test_size = 20, random_state = 543)
fault_4_train, fault_4_test = train_test_split(fault_4_files, test_size = 20, random_state = 5432)
fault_5_train, fault_5_test = train_test_split(fault_5_files, test_size = 20, random_state = 54321)
fault_1_train, fault_1_val = train_test_split(fault_1_train, test_size = 10, random_state = 1)
fault_2_train, fault_2_val = train_test_split(fault_2_train, test_size = 10, random_state = 12)
fault_3_train, fault_3_val = train_test_split(fault_3_train, test_size = 10, random_state = 123)
fault_4_train, fault_4_val = train_test_split(fault_4_train, test_size = 10, random_state = 1234)
fault_5_train, fault_5_val = train_test_split(fault_5_train, test_size = 10, random_state = 12345)
train_file_names = fault_1_train + fault_2_train + fault_3_train + fault_4_train + fault_5_train
validation_file_names = fault_1_val + fault_2_val + fault_3_val + fault_4_val + fault_5_val
test_file_names = fault_1_test + fault_2_test + fault_3_test + fault_4_test + fault_5_test

# Shuffle files
np.random.shuffle(train_file_names)
print("Number of train_files:" ,len(train_file_names))
print("Number of validation_files:" ,len(validation_file_names))
print("Number of test_files:" ,len(test_file_names))
Number of train_files: 350
Number of validation_files: 50
Number of test_files: 100
batch_size = 32
train_dataset = tf.data.Dataset.from_generator(tf_data_generator, args = [train_file_names, batch_size], 
                                              output_shapes = ((None,31,33,1),(None,)),
                                              output_types = (tf.float32, tf.float32))

validation_dataset = tf.data.Dataset.from_generator(tf_data_generator, args = [validation_file_names, batch_size],
                                                   output_shapes = ((None,31,33,1),(None,)),
                                                   output_types = (tf.float32, tf.float32))

test_dataset = tf.data.Dataset.from_generator(tf_data_generator, args = [test_file_names, batch_size],
                                             output_shapes = ((None,31,33,1),(None,)),
                                             output_types = (tf.float32, tf.float32))

Now create the model.

from tensorflow.keras import layers
model = tf.keras.Sequential([
    layers.Conv2D(16, 3, activation = "relu", input_shape = (31,33,1)),
    layers.MaxPool2D(2),
    layers.Conv2D(32, 3, activation = "relu"),
    layers.MaxPool2D(2),
    layers.Flatten(),
    layers.Dense(16, activation = "relu"),
    layers.Dense(5, activation = "softmax")
])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 29, 31, 16)        160       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 15, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 12, 13, 32)        4640      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 32)          0         
_________________________________________________________________
flatten (Flatten)            (None, 1152)              0         
_________________________________________________________________
dense (Dense)                (None, 16)                18448     
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 85        
=================================================================
Total params: 23,333
Trainable params: 23,333
Non-trainable params: 0
_________________________________________________________________

Compile the model.

model.compile(loss = "sparse_categorical_crossentropy", optimizer = "adam", metrics = ["accuracy"])

Before we fit the model, we have to do one important calculation. Remember that our generators are infinite loops, so if no stopping criterion is given, a generator will run indefinitely. But we want our model to run for, say, 10 epochs, so the generator should loop over the data files just 10 times and no more. This is achieved by setting the arguments steps_per_epoch and validation_steps to the desired numbers in model.fit(). Similarly, while evaluating the model, we need to set the argument steps to a desired number in model.evaluate().

There are 350 files in the training set and batch_size is 32, so one epoch corresponds to ceil(350/32) = 11 generator calls. Therefore, we should set steps_per_epoch to 11. Similarly, validation_steps = ceil(50/32) = 2 and, in model.evaluate(), steps = ceil(100/32) = 4.

steps_per_epoch = int(np.ceil(len(train_file_names)/batch_size))
validation_steps = int(np.ceil(len(validation_file_names)/batch_size))
steps = int(np.ceil(len(test_file_names)/batch_size))
print("steps_per_epoch = ", steps_per_epoch)
print("validation_steps = ", validation_steps)
print("steps = ", steps)
steps_per_epoch =  11
validation_steps =  2
steps =  4
model.fit(train_dataset, validation_data = validation_dataset, steps_per_epoch = steps_per_epoch,
         validation_steps = validation_steps, epochs = 5)
Epoch 1/5
11/11 [==============================] - 2s 202ms/step - loss: 1.6211 - accuracy: 0.1585 - val_loss: 1.6088 - val_accuracy: 0.1400
Epoch 2/5
11/11 [==============================] - 2s 164ms/step - loss: 1.6080 - accuracy: 0.2110 - val_loss: 1.6097 - val_accuracy: 0.2200
Epoch 3/5
11/11 [==============================] - 2s 164ms/step - loss: 1.6084 - accuracy: 0.1907 - val_loss: 1.6093 - val_accuracy: 0.1200
Epoch 4/5
11/11 [==============================] - 2s 163ms/step - loss: 1.6038 - accuracy: 0.2405 - val_loss: 1.6101 - val_accuracy: 0.1800
Epoch 5/5
11/11 [==============================] - 2s 162ms/step - loss: 1.6025 - accuracy: 0.2750 - val_loss: 1.6101 - val_accuracy: 0.1400





<tensorflow.python.keras.callbacks.History at 0x7f28cc0cfd60>
test_loss, test_accuracy = model.evaluate(test_dataset, steps = steps)
4/4 [==============================] - 0s 101ms/step - loss: 1.6099 - accuracy: 0.1900
print("Test loss: ", test_loss)
print("Test accuracy:", test_accuracy)
Test loss:  1.6099034547805786
Test accuracy: 0.1899999976158142

As expected, the model performs terribly.

4. How to make the code run faster?

If no transformations are used, just using prefetch might improve performance. In deep learning, GPUs are usually used for training, but all the data processing is done on the CPU. In the naive approach, we first prepare a batch of data on the CPU, then send the processed data to the GPU, and only after the training step finishes do we prepare the next batch. This approach is not efficient because the GPU has to wait for the data to be prepared. Using prefetch, we prepare batches of data and keep them ready while training continues. In this way, the waiting time of the GPU is minimized.
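
Prefetching itself is a one-line change on any tf.data.Dataset; for example, for the datasets built in section 3 (a sketch; tf.data.AUTOTUNE lets tensorflow pick the buffer size):

# Overlap data preparation on the CPU with training on the GPU.
train_dataset = train_dataset.prefetch(buffer_size = tf.data.AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size = tf.data.AUTOTUNE)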

When data transformations are used, our aim should always be to use the parallel processing capabilities of tensorflow. We can achieve this using the map function. Inside the map function, all transformations are defined. Then we can prefetch batches to further improve performance. The whole pipeline is as follows.

1. def transformation_function(...):
    # Define all transformations (STFT, Normalization, etc.)
    
2. def generator(...):
    
       # Read data
    
       # Call transformation_function using tf.data.Dataset.map so that it can parallelize operations.
    
       # Finally yield the processed data

3. Create tf.data.Datasets.

4. Prefetch datasets.

5. Create model and train it.

We will use one extra library, tensorflow_datasets, that allows us to switch from a tf.data.Dataset back to numpy. If tensorflow_datasets is not installed on your system, use pip install tensorflow-datasets to install it and then run the following code.
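
tfds.as_numpy simply wraps a tf.data.Dataset so that iterating over it yields numpy arrays instead of tensors. A tiny illustration (assuming tensorflow-datasets is installed):

import tensorflow as tf
import tensorflow_datasets as tfds

small_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in tfds.as_numpy(small_ds):
    print(element)   # 1, 2, 3 as numpy scalars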

import tensorflow_datasets as tfds
def data_transformation_func(data):
  transformed_data = tf.math.abs(tf.signal.stft(data,frame_length = 64, frame_step = 32, fft_length = 64))
  transformed_data = tf.image.per_image_standardization(tf.reshape(transformed_data, shape = (-1,31,33,1))) # Normalization
  return transformed_data
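
As a quick, purely illustrative sanity check of data_transformation_func, a batch of 4 random 1024-point signals should come out with shape (4, 31, 33, 1):

dummy_batch = tf.constant(np.random.rand(4, 1024), dtype = tf.float32)
print(data_transformation_func(dummy_batch).shape)   # (4, 31, 33, 1)
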
def tf_data_generator_new(file_list, batch_size = 4):
    i = 0
    while True:
        if i*batch_size >= len(file_list):  
            i = 0
            np.random.shuffle(file_list)
        else:
            file_chunk = file_list[i*batch_size:(i+1)*batch_size]
            data = []
            labels = []
            label_classes = tf.constant(["Fault_1", "Fault_2", "Fault_3", "Fault_4", "Fault_5"]) 
            for file in file_chunk:
                temp = pd.read_csv(open(file,'r')).astype(np.float32)    # Read data
                data.append(tf.reshape(temp.values, shape = (1,1024)))
                pattern = file[22:29]   # e.g. b"Fault_3"; file paths arrive as bytes from tf.data
                for j in range(len(label_classes)):
                    if re.match(pattern, label_classes[j].numpy()):
                        labels.append(j)
                    
            data = np.asarray(data)
            labels = np.asarray(labels)
            first_dim = data.shape[0]
            # Create tensorflow dataset so that we can use `map` function that can do parallel computation.
            data_ds = tf.data.Dataset.from_tensor_slices(data)
            data_ds = data_ds.batch(batch_size = first_dim).map(data_transformation_func,
                                                                num_parallel_calls = tf.data.experimental.AUTOTUNE)
            # Convert the dataset to a generator and subsequently to numpy array
            data_ds = tfds.as_numpy(data_ds)   # This is where tensorflow-datasets library is used.
            data = np.asarray([data for data in data_ds]).reshape(first_dim,31,33,1)
            
            yield data, labels
            i = i + 1
train_file_names[:10]
['./random_data/Fault_3/Fault_3_045.csv',
 './random_data/Fault_1/Fault_1_032.csv',
 './random_data/Fault_1/Fault_1_025.csv',
 './random_data/Fault_2/Fault_2_013.csv',
 './random_data/Fault_3/Fault_3_053.csv',
 './random_data/Fault_1/Fault_1_087.csv',
 './random_data/Fault_5/Fault_5_053.csv',
 './random_data/Fault_4/Fault_4_019.csv',
 './random_data/Fault_3/Fault_3_034.csv',
 './random_data/Fault_2/Fault_2_044.csv']
train_file_names[0][22:29]
'Fault_3'
batch_size = 20
dataset_check = tf.data.Dataset.from_generator(tf_data_generator_new,args= [train_file_names, batch_size],output_types = (tf.float32, tf.float32),
                                                output_shapes = ((None,31,33,1),(None,)))
for data, labels in dataset_check.take(7):
  print(data.shape)
  print(labels)
(20, 31, 33, 1)
tf.Tensor([2. 0. 0. 1. 2. 0. 4. 3. 2. 1. 1. 0. 3. 3. 2. 3. 1. 4. 2. 4.], shape=(20,), dtype=float32)
(20, 31, 33, 1)
tf.Tensor([3. 1. 1. 3. 4. 4. 2. 3. 4. 3. 3. 0. 1. 2. 0. 3. 2. 2. 2. 4.], shape=(20,), dtype=float32)
(20, 31, 33, 1)
tf.Tensor([2. 3. 0. 2. 2. 4. 3. 0. 4. 1. 0. 0. 2. 0. 0. 1. 0. 3. 2. 1.], shape=(20,), dtype=float32)
(20, 31, 33, 1)
tf.Tensor([4. 2. 2. 2. 0. 3. 4. 2. 0. 1. 2. 2. 3. 4. 0. 4. 2. 0. 4. 4.], shape=(20,), dtype=float32)
(20, 31, 33, 1)
tf.Tensor([1. 0. 4. 4. 0. 1. 0. 4. 0. 2. 1. 4. 3. 2. 1. 4. 4. 2. 4. 3.], shape=(20,), dtype=float32)
(20, 31, 33, 1)
tf.Tensor([2. 2. 0. 1. 3. 2. 2. 2. 1. 3. 3. 4. 0. 1. 4. 1. 3. 2. 1. 3.], shape=(20,), dtype=float32)
(20, 31, 33, 1)
tf.Tensor([2. 1. 2. 2. 4. 4. 1. 0. 2. 2. 1. 2. 3. 0. 0. 2. 2. 0. 3. 3.], shape=(20,), dtype=float32)
batch_size = 32
train_dataset_new = tf.data.Dataset.from_generator(tf_data_generator_new, args = [train_file_names, batch_size], 
                                                  output_shapes = ((None,31,33,1),(None,)),
                                                  output_types = (tf.float32, tf.float32))

validation_dataset_new = tf.data.Dataset.from_generator(tf_data_generator_new, args = [validation_file_names, batch_size],
                                                       output_shapes = ((None,31,33,1),(None,)),
                                                       output_types = (tf.float32, tf.float32))

test_dataset_new = tf.data.Dataset.from_generator(tf_data_generator_new, args = [test_file_names, batch_size],
                                                 output_shapes = ((None,31,33,1),(None,)),
                                                 output_types = (tf.float32, tf.float32))

Prefetch datasets.

train_dataset_new = train_dataset_new.prefetch(buffer_size = tf.data.AUTOTUNE)
validation_dataset_new = validation_dataset_new.prefetch(buffer_size = tf.data.AUTOTUNE)
model.compile(loss = "sparse_categorical_crossentropy", optimizer = "adam", metrics = ["accuracy"])
model.fit(train_dataset_new, validation_data = validation_dataset_new, steps_per_epoch = steps_per_epoch,
         validation_steps = validation_steps, epochs = 5)
Epoch 1/5
11/11 [==============================] - 3s 226ms/step - loss: 1.6027 - accuracy: 0.1989 - val_loss: 1.6112 - val_accuracy: 0.1600
Epoch 2/5
11/11 [==============================] - 2s 214ms/step - loss: 1.5986 - accuracy: 0.2520 - val_loss: 1.6104 - val_accuracy: 0.2400
Epoch 3/5
11/11 [==============================] - 2s 200ms/step - loss: 1.5954 - accuracy: 0.3161 - val_loss: 1.6122 - val_accuracy: 0.1800
Epoch 4/5
11/11 [==============================] - 2s 209ms/step - loss: 1.5892 - accuracy: 0.3650 - val_loss: 1.6101 - val_accuracy: 0.1600
Epoch 5/5
11/11 [==============================] - 2s 196ms/step - loss: 1.5816 - accuracy: 0.2972 - val_loss: 1.6148 - val_accuracy: 0.1600





<tensorflow.python.keras.callbacks.History at 0x7f2888147940>
test_loss_new, test_acc_new = model.evaluate(test_dataset_new, steps = steps)
4/4 [==============================] - 1s 139ms/step - loss: 1.6089 - accuracy: 0.2000

5. How to make predictions?

In the generator used for prediction, we could also use the map function to parallelize data preprocessing. But in practice, inference is much faster than training, so we can make fast predictions using the naive method as well. We show the naive implementation below.

def create_prediction_set(num_files = 20):
    os.mkdir("./random_data/prediction_set")
    for i in range(num_files):
        data = np.random.randn(1024,)
        file_name = "./random_data/prediction_set/" + "file_" + "{0:03}".format(i+1) + ".csv"  # e.g. "./random_data/prediction_set/file_001.csv"
        np.savetxt(file_name, data, delimiter = ",", header = "V1", comments = "")
    print(str(num_files) + " files created in prediction set.")

Create some files for prediction set.

create_prediction_set(num_files = 55)
55 files created in prediction set.
prediction_files = glob.glob("./random_data/prediction_set/*")
print("Total number of files: ", len(prediction_files))
print("Showing first 10 files...")
prediction_files[:10]
Total number of files:  55
Showing first 10 files...





['./random_data/prediction_set/file_001.csv',
 './random_data/prediction_set/file_002.csv',
 './random_data/prediction_set/file_003.csv',
 './random_data/prediction_set/file_004.csv',
 './random_data/prediction_set/file_005.csv',
 './random_data/prediction_set/file_006.csv',
 './random_data/prediction_set/file_007.csv',
 './random_data/prediction_set/file_008.csv',
 './random_data/prediction_set/file_009.csv',
 './random_data/prediction_set/file_010.csv']

Now, we will create a generator to read these files in chunks. This generator will be slightly different from our previous generators. Firstly, we don't want the generator to run indefinitely. Secondly, we don't have any labels, so this generator should only yield data. This is how we achieve that.

def generator_for_prediction(file_list, batch_size = 20):
    i = 0
    while i <= (len(file_list)/batch_size):
        if i == np.floor(len(file_list)/batch_size):
            file_chunk = file_list[i*batch_size:len(file_list)]
            if len(file_chunk)==0:
                break
        else:
            file_chunk = file_list[i*batch_size:(i+1)*batch_size] 
        data = []
        for file in file_chunk:
            temp = pd.read_csv(open(file,'r')).astype(np.float32)
            temp = tf.math.abs(tf.signal.stft(tf.reshape(temp.values, shape = (1024,)),frame_length = 64, frame_step = 32, fft_length = 64))
            # After STFT transformation with given parameters, shape = (31,33)
            temp = tf.image.per_image_standardization(tf.reshape(temp, shape = (-1,31,33,1))) # Image Normalization
            data.append(temp) 
        data = np.asarray(data).reshape(-1,31,33,1)
        yield data
        i = i + 1

Check whether the generator works or not.

pred_gen = generator_for_prediction(prediction_files,  batch_size = 10)
for data in pred_gen:
    print(data.shape)
(10, 31, 33, 1)
(10, 31, 33, 1)
(10, 31, 33, 1)
(10, 31, 33, 1)
(10, 31, 33, 1)
(5, 31, 33, 1)

Create a tensorflow dataset.

batch_size = 10
prediction_dataset = tf.data.Dataset.from_generator(generator_for_prediction,args=[prediction_files, batch_size],
                                                 output_shapes=(None,31,33,1), output_types=(tf.float32))
steps = int(np.ceil(len(prediction_files)/batch_size))
predictions = model.predict(prediction_dataset,steps = steps)
print("Shape of prediction array: ", predictions.shape)
predictions
Shape of prediction array:  (55, 5)


array([[0.13616945, 0.22521223, 0.29032916, 0.11108191, 0.23720728],
       [0.13136739, 0.1767847 , 0.2776762 , 0.11521462, 0.29895717],
       [0.12264969, 0.17929009, 0.2746509 , 0.11005757, 0.3133517 ],
       [0.12557542, 0.14570946, 0.20385396, 0.15842426, 0.3664368 ],
       [0.13804483, 0.13592169, 0.24367407, 0.13454145, 0.347818  ],
       [0.1430458 , 0.18991745, 0.28873193, 0.11874774, 0.25955713],
       [0.13853352, 0.17857482, 0.31380644, 0.10534842, 0.26373684],
       [0.10823276, 0.22618511, 0.32452983, 0.0847131 , 0.25633916],
       [0.1383514 , 0.16129029, 0.25447774, 0.13601685, 0.30986372],
       [0.13583152, 0.1730804 , 0.25602627, 0.12829432, 0.30676743],
       [0.12959503, 0.1772274 , 0.30786148, 0.10242429, 0.28289178],
       [0.13454609, 0.16846487, 0.2601272 , 0.12760776, 0.30925405],
       [0.14779252, 0.14842029, 0.2833778 , 0.11843931, 0.3019701 ],
       [0.11896624, 0.21513633, 0.25535005, 0.11902266, 0.29152465],
       [0.13734229, 0.13935044, 0.2748529 , 0.11947   , 0.3289844 ],
       [0.15610486, 0.20303495, 0.30530566, 0.11581586, 0.21973862],
       [0.13397609, 0.15995616, 0.28893223, 0.11423217, 0.30290335],
       [0.12130069, 0.22576565, 0.2828214 , 0.10909654, 0.26101568],
       [0.15606783, 0.14581656, 0.25918248, 0.1374906 , 0.3014426 ],
       [0.13284522, 0.15171063, 0.23527463, 0.13531938, 0.34485018],
       [0.15039593, 0.18859874, 0.2730181 , 0.13246077, 0.25552657],
       [0.13006356, 0.23040107, 0.31713945, 0.09858881, 0.2238071 ],
       [0.14617579, 0.1553615 , 0.24506982, 0.14371207, 0.30968076],
       [0.15057516, 0.18175651, 0.26442483, 0.13344486, 0.26979864],
       [0.12942658, 0.17502321, 0.27020454, 0.11938909, 0.3059566 ],
       [0.1362109 , 0.18799251, 0.2874499 , 0.11758988, 0.27075684],
       [0.12986937, 0.20687391, 0.28071418, 0.11440149, 0.26814106],
       [0.13393444, 0.1739722 , 0.27028513, 0.12331051, 0.2984978 ],
       [0.12112842, 0.13956769, 0.22920072, 0.12982692, 0.3802763 ],
       [0.11119145, 0.23633307, 0.32426968, 0.08549411, 0.2427116 ],
       [0.13327955, 0.18379854, 0.2872899 , 0.11320265, 0.28242943],
       [0.12073855, 0.20085782, 0.2646106 , 0.11651796, 0.29727504],
       [0.11189438, 0.20137395, 0.27387396, 0.10702953, 0.30582818],
       [0.18017001, 0.16150263, 0.28068233, 0.1368001 , 0.2408449 ],
       [0.10944357, 0.17276171, 0.25993338, 0.10688126, 0.35098007],
       [0.13728923, 0.1559456 , 0.25643092, 0.13189963, 0.31843463],
       [0.15782082, 0.1793215 , 0.28856605, 0.12700985, 0.24728177],
       [0.13353582, 0.20542818, 0.32362464, 0.0972899 , 0.2401215 ],
       [0.1327661 , 0.19204186, 0.29048327, 0.11032179, 0.274387  ],
       [0.15101205, 0.16577183, 0.28014943, 0.12510163, 0.27796513],
       [0.10811005, 0.24937892, 0.2825413 , 0.09674085, 0.26322895],
       [0.12502007, 0.17934126, 0.23135227, 0.1389524 , 0.32533407],
       [0.15938489, 0.12479166, 0.2140554 , 0.16871263, 0.33305547],
       [0.13133633, 0.15853986, 0.2776162 , 0.11680949, 0.31569818],
       [0.13070984, 0.20629251, 0.32593974, 0.09547149, 0.24158639],
       [0.12578759, 0.13497958, 0.2329479 , 0.13350599, 0.37277895],
       [0.11535928, 0.18584532, 0.25516343, 0.11577264, 0.3278593 ],
       [0.12250951, 0.16624808, 0.22112629, 0.14213458, 0.34798154],
       [0.11505162, 0.22170952, 0.29335403, 0.09799453, 0.27189022],
       [0.15128227, 0.18352163, 0.26057395, 0.1377367 , 0.2668855 ],
       [0.10571367, 0.14169416, 0.2291365 , 0.12079947, 0.40265626],
       [0.11637849, 0.2011716 , 0.28354362, 0.10431051, 0.29459572],
       [0.12311503, 0.1520483 , 0.26735488, 0.11320169, 0.34428006],
       [0.13812302, 0.23593263, 0.27522495, 0.12361279, 0.22710668],
       [0.13761182, 0.17356406, 0.23145248, 0.14515112, 0.3122205 ]],
      dtype=float32)

The outputs of prediction are 5-dimensional vectors. This is because we have used 5 neurons in the output layer and our activation function is softmax. The 5-dimensional output vector for each input sums to 1, so it can be interpreted as a probability distribution over the classes. Thus we should assign an input to the class for which the predicted probability is maximum. To get the class corresponding to the maximum probability, we can use np.argmax().

np.argmax(predictions, axis = 1)
array([2, 4, 4, 4, 4, 2, 2, 2, 4, 4, 2, 4, 4, 4, 4, 2, 4, 2, 4, 4, 2, 2,
       4, 4, 4, 2, 2, 4, 4, 2, 2, 4, 4, 2, 4, 4, 2, 2, 2, 2, 2, 4, 4, 4,
       2, 4, 4, 4, 2, 4, 4, 4, 4, 2, 4])
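
To map these indices back to class names, we can reuse the same class ordering that was used for the labels (a small sketch):

label_classes = ["Fault_1", "Fault_2", "Fault_3", "Fault_4", "Fault_5"]
predicted_classes = [label_classes[index] for index in np.argmax(predictions, axis = 1)]
print(predicted_classes[:5])   # ['Fault_3', 'Fault_5', 'Fault_5', 'Fault_5', 'Fault_5']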

As a final comment, read the note at the beginning of this post.

Biswajit Sahoo
Machine Learning Engineer

My research interests include machine learning, deep learning, signal processing and data-driven machinery condition monitoring.
