Julian Dolby
IBM Thomas J. Watson Research Center
PLDI PC Meeting Workshop, February 2018
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST splits stored under /tmp/data/ with the TF tutorial helper.
# one_hot=False: labels are not one-hot encoded (presumably integer class ids).
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)
# NOTE(review): later in this excerpt the training split is read through
# mnist.train.images / mnist.train.labels. Confirm that a `.data` attribute
# really exists here, or whether this line is only a slide label for the
# image array.
mnist.train.data
image array
Recognize digits from images
reshape (line 42) needs the original shape
reshape (line 42) is for conv2d (line 45)
This shape information is never explicit in the code
mnist : {
train : {images : [channel; x(28)*y(28); batch]},
test : {images : [channel; x(28)*y(28); batch]}
}
# Same loader call, with its arguments elided (`...`) on the slide; the
# notation above describes the value `mnist` receives from it.
mnist = input_data.read_data_sets(...)
[channel; x(28)*y(28); batch] ->
[channel; 1; x(28); y(28); batch]
# -1 lets tf.reshape infer the batch dimension from the element count.
# The flattened 28*28 image vectors become 28x28 images with one trailing
# size-1 channel axis (the shape change written in the notation above).
# Presumably this matches the layout conv2d expects (NHWC by default);
# confirm against the full source.
x = tf.reshape(x, shape=[-1, 28, 28, 1])
var : { ("field" : type)* }
[ type ; label?(size)* ]
def conv_net(x_dict, n_classes, dropout, reuse, is_training):
    """Build the 'ConvNet' network for the inputs held in ``x_dict`` (excerpt).

    The body appears truncated in this excerpt: only its opening lines are
    shown and no return value is visible.

    Args:
        x_dict: dict holding the input batch under the key ``'images'``.
            In this excerpt it is ``model_fn``'s ``features``, which is
            presumably the ``x`` dict given to ``numpy_input_fn``.
        n_classes: not used in the visible lines; presumably the number of
            output classes. Confirm against the full source.
        dropout: not used in the visible lines; presumably a dropout rate.
        reuse: forwarded to ``tf.variable_scope`` to control whether the
            existing 'ConvNet' variables are reused.
        is_training: not used in the visible lines; presumably toggles
            training-only behaviour. Confirm against the full source.
    """
    with tf.variable_scope('ConvNet', reuse=reuse):
        # The image tensor arrives wrapped in a dict; unpack it by key.
        x = x_dict['images']
def model_fn(features, labels, mode):
    """Model function handed to ``tf.estimator.Estimator`` (excerpt).

    Only the first statement is shown; the rest of the body is not part of
    this excerpt.

    Args:
        features: passed straight to ``conv_net`` as its ``x_dict``, so it is
            expected to be a dict with an ``'images'`` entry.
        labels: not used in the visible lines.
        mode: not used in the visible lines.
    """
    # Training-mode network: reuse=False, so no existing 'ConvNet' variables
    # are reused by this call.
    # NOTE(review): num_classes and dropout are not defined in this excerpt;
    # presumably they are module-level settings.
    logits_train = conv_net(features, num_classes, dropout, reuse=False,
                            is_training=True)
# Wrap model_fn in an Estimator; model.train below ends up invoking it.
model = tf.estimator.Estimator(model_fn)
# Feed the training split from in-memory numpy arrays. The dict key 'images'
# is the same key conv_net reads from x_dict. num_epochs=None keeps cycling
# over the data, so the step count passed to train() decides when to stop.
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'images': mnist.train.images}, y=mnist.train.labels,
    batch_size=batch_size, num_epochs=None, shuffle=True)
# NOTE(review): batch_size and num_steps are not defined in this excerpt;
# presumably they are module-level hyperparameters.
model.train(input_fn, steps=num_steps)
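model.train ultimately calls model_fn, passing it the features dict built by
input_fn; this is how the 'images' array reaches conv_net's x_dict['images']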