How to use tf.contrib.model_pruning on MNIST?
I'm struggling to use TensorFlow's pruning library and haven't found many helpful examples, so I'm looking for help pruning a simple model trained on the MNIST dataset. If anyone can help or point me to a working example, I'd appreciate it.
Solution 1:
This is the simplest pruning-library example I could get working. I figured I'd post it here in case it helps another newcomer who has a hard time with the documentation.
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data
# --- Run configuration ----------------------------------------------------
epochs = 250
batch_size = 55000  # Entire training set in a single batch

# --- Dataset ----------------------------------------------------------------
# Downloads (if needed) and loads MNIST with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Number of full batches per epoch (floor division).
batches = len(mnist.train.images) // batch_size
# --- Placeholders -------------------------------------------------------------
# Flattened 784-pixel images and 10-way one-hot labels; the leading None
# dimension accepts any batch size.
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

# --- Model --------------------------------------------------------------------
# Two 300-unit hidden layers and a 10-unit output layer, all built with
# masked_fully_connected so the pruning library can mask their weights.
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
# The output layer must be linear. masked_fully_connected defaults to a ReLU
# activation, which would clip every negative logit to 0 before the softmax
# cross-entropy and argmax that consume these logits.
logits = layers.masked_fully_connected(layer2, 10, activation_fn=None)
# --- Training graph -----------------------------------------------------------
# The global step drives the pruning schedule: it is passed to pruning.Pruning,
# and train_op increments it on every run.
global_step = tf.train.get_or_create_global_step()
# Op used to restart the step counter at 0 before the pruning phase.
reset_global_step_op = tf.assign(global_step, 0)

# Mean softmax cross-entropy over the batch.
per_example_xent = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label)
loss = tf.reduce_mean(per_example_xent)

# Passing global_step makes each run of train_op increment it; this must be the
# same tensor handed to pruning.Pruning or the pruning schedule never advances.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
train_op = optimizer.minimize(loss, global_step=global_step)

# Fraction of examples whose arg-max prediction matches the one-hot label.
predicted_class = tf.argmax(logits, 1)
true_class = tf.argmax(label, 1)
correct_prediction = tf.equal(predicted_class, true_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# --- Pruning configuration ----------------------------------------------------
# Start from the library defaults and print them for reference.
pruning_hparams = pruning.get_pruning_hparams()
print("Pruning Hyperparameters:", pruning_hparams)

# Override the fields this example cares about: prune between steps 0 and 250,
# updating masks every step, with a 90% target sparsity.
hparam_overrides = {
    "begin_pruning_step": 0,
    "end_pruning_step": 250,
    "pruning_frequency": 1,
    "sparsity_function_end_step": 250,
    "target_sparsity": .9,
}
for hparam_name, hparam_value in hparam_overrides.items():
    setattr(pruning_hparams, hparam_name, hparam_value)

# Build the Pruning object on the same global step that train_op increments.
# NOTE(review): the explicit sparsity=.9 appears to take priority over the
# hparam-driven schedule (target_sparsity) -- confirm against the library.
p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
# Mask-update op; presumably it only acts on steps permitted by the
# begin/end/frequency hparams -- confirm in the pruning library.
prune_op = p.conditional_mask_update_op()
# Build the ops used inside the session once, up front. Calling
# tf.global_variables_initializer() or get_weight_sparsity() inside the loops
# would add fresh ops to the graph on every call.
init_op = tf.global_variables_initializer()
weight_sparsity_op = tf.contrib.model_pruning.get_weight_sparsity()
test_feed = {image: mnist.test.images, label: mnist.test.labels}

with tf.Session() as sess:
    sess.run(init_op)

    # Train the model before pruning (optional).
    for epoch in range(epochs):
        for _ in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})
        # Calculate test accuracy every 10 epochs.
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict=test_feed)
            print("Un-pruned model step %d test accuracy %g" % (epoch, acc_print))

    acc_print = sess.run(accuracy, feed_dict=test_feed)
    print("Pre-Pruning accuracy:", acc_print)
    print("Sparsity of layers (should be 0)", sess.run(weight_sparsity_op))

    # Reset the global step counter so the pruning schedule
    # (begin_pruning_step .. end_pruning_step) starts from step 0.
    sess.run(reset_global_step_op)
    for epoch in range(epochs):
        for _ in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Prune (update the masks), then take one training step on the
            # masked model; train_op also advances the global step.
            sess.run(prune_op)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})
        # Calculate test accuracy every 10 epochs.
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict=test_feed)
            print("Pruned model step %d test accuracy %g" % (epoch, acc_print))
            print("Weight sparsities:", sess.run(weight_sparsity_op))

    # Print final accuracy; after pruning each layer should be ~90% zeros.
    acc_print = sess.run(accuracy, feed_dict=test_feed)
    print("Final accuracy:", acc_print)
    print("Final sparsity by layer (should be ~0.9)", sess.run(weight_sparsity_op))
Solution 2:
Roman Nikishin requested code that saves the model; this is a slight extension of my original answer.
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data
# --- Run configuration ----------------------------------------------------
epochs = 250
batch_size = 55000  # Entire training set in a single batch

# Checkpoint locations for the model before and after pruning.
model_path_unpruned = "Model_Saves/Unpruned.ckpt"
model_path_pruned = "Model_Saves/Pruned.ckpt"

# --- Dataset ----------------------------------------------------------------
# Downloads (if needed) and loads MNIST with one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Number of full batches per epoch (floor division).
batches = len(mnist.train.images) // batch_size
# --- Placeholders -------------------------------------------------------------
# Flattened 784-pixel images and 10-way one-hot labels; the leading None
# dimension accepts any batch size.
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

# --- Model --------------------------------------------------------------------
# Two 300-unit hidden layers and a 10-unit output layer, all built with
# masked_fully_connected so the pruning library can mask their weights.
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
# The output layer must be linear. masked_fully_connected defaults to a ReLU
# activation, which would clip every negative logit to 0 before the softmax
# cross-entropy and argmax that consume these logits.
logits = layers.masked_fully_connected(layer2, 10, activation_fn=None)
# --- Training graph -----------------------------------------------------------
# The global step drives the pruning schedule: it is passed to pruning.Pruning,
# and train_op increments it on every run.
global_step = tf.train.get_or_create_global_step()
# Op used to restart the step counter at 0 before the pruning phase.
reset_global_step_op = tf.assign(global_step, 0)

# Mean softmax cross-entropy over the batch.
per_example_xent = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label)
loss = tf.reduce_mean(per_example_xent)

# Passing global_step makes each run of train_op increment it; this must be the
# same tensor handed to pruning.Pruning or the pruning schedule never advances.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
train_op = optimizer.minimize(loss, global_step=global_step)

# Fraction of examples whose arg-max prediction matches the one-hot label.
predicted_class = tf.argmax(logits, 1)
true_class = tf.argmax(label, 1)
correct_prediction = tf.equal(predicted_class, true_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# --- Pruning configuration ----------------------------------------------------
# Start from the library defaults and print them for reference.
pruning_hparams = pruning.get_pruning_hparams()
print("Pruning Hyperparameters:", pruning_hparams)

# Override the fields this example cares about: prune between steps 0 and 250,
# updating masks every step, with a 90% target sparsity.
hparam_overrides = {
    "begin_pruning_step": 0,
    "end_pruning_step": 250,
    "pruning_frequency": 1,
    "sparsity_function_end_step": 250,
    "target_sparsity": .9,
}
for hparam_name, hparam_value in hparam_overrides.items():
    setattr(pruning_hparams, hparam_name, hparam_value)

# Build the Pruning object on the same global step that train_op increments.
# NOTE(review): the explicit sparsity=.9 appears to take priority over the
# hparam-driven schedule (target_sparsity) -- confirm against the library.
p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
# Mask-update op; presumably it only acts on steps permitted by the
# begin/end/frequency hparams -- confirm in the pruning library.
prune_op = p.conditional_mask_update_op()
import os

# Create a saver for writing training checkpoints. With no var_list it covers
# every global variable in the graph built above.
saver = tf.train.Saver()

# Build the ops used inside the session once, up front. Calling
# tf.global_variables_initializer() or get_weight_sparsity() inside the loops
# would add fresh ops to the graph on every call.
init_op = tf.global_variables_initializer()
weight_sparsity_op = tf.contrib.model_pruning.get_weight_sparsity()
test_feed = {image: mnist.test.images, label: mnist.test.labels}

# Saver.save does not create missing directories, so make sure the checkpoint
# folder(s) exist before training starts.
for checkpoint_path in (model_path_unpruned, model_path_pruned):
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

with tf.Session() as sess:
    # Initialize from scratch and pre-train. If an unpruned checkpoint already
    # exists, the initialization and pre-training loop can be skipped; the
    # restore below loads the saved weights.
    sess.run(init_op)

    # Train the model before pruning (optional).
    for epoch in range(epochs):
        for _ in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})
        # Calculate test accuracy every 10 epochs.
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict=test_feed)
            print("Un-pruned model step %d test accuracy %g" % (epoch, acc_print))

    acc_print = sess.run(accuracy, feed_dict=test_feed)
    print("Pre-Pruning accuracy:", acc_print)
    print("Sparsity of layers (should be 0)", sess.run(weight_sparsity_op))

    # Save the model before pruning.
    saver.save(sess, model_path_unpruned)

    # Reset every variable, then restore the checkpoint just written, so the
    # pruning phase starts from the saved unpruned model.
    sess.run(init_op)
    saver.restore(sess, model_path_unpruned)

    # Reset the global step counter so the pruning schedule
    # (begin_pruning_step .. end_pruning_step) starts from step 0.
    sess.run(reset_global_step_op)
    for epoch in range(epochs):
        for _ in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Prune (update the masks), then take one training step on the
            # masked model; train_op also advances the global step.
            sess.run(prune_op)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})
        # Calculate test accuracy every 10 epochs.
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict=test_feed)
            print("Pruned model step %d test accuracy %g" % (epoch, acc_print))
            print("Weight sparsities:", sess.run(weight_sparsity_op))

    # Save the model after pruning.
    saver.save(sess, model_path_pruned)

    # Print final accuracy; after pruning each layer should be ~90% zeros.
    acc_print = sess.run(accuracy, feed_dict=test_feed)
    print("Final accuracy:", acc_print)
    print("Final sparsity by layer (should be ~0.9)", sess.run(weight_sparsity_op))
Post a Comment for "How To Use Tf.contrib.model_pruning On Mnist?"