Use TensorFlow QueueRunner to read records from a binary file, keeping only the records whose matching row in a companion CSV file satisfies given conditions

#!/usr/local/bin/python
# -- coding: utf-8 --
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import time

import math
import numpy as np
from six.moves import xrange
import tensorflow as tf
import csv

# Command-line flags for the input pipeline. The defaults describe a
# single-channel 40x40x40 volume stored as 4-byte values.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('IMAGE_SIZE', 40, "Size of input Image.")
tf.app.flags.DEFINE_integer('PIXEL_DATA_SIZE', 4, "Size of Image pixel.")
tf.app.flags.DEFINE_integer('CHANNEL_NUMBER', 1, "Number of input Image channels.")
tf.app.flags.DEFINE_integer('LABEL_NUMBER', 2, "Label number.")
tf.app.flags.DEFINE_integer('BATCH_SIZE', 128, "Size of a Batch.")
tf.app.flags.DEFINE_integer('NUM_EPOCHS', 10, "Number of epochs.")
tf.app.flags.DEFINE_integer('EVAL_BATCH_SIZE', 128, "Size of an Evalution Batch.")
tf.app.flags.DEFINE_integer('SEED', 66478, "Seed of Shuffle.")
tf.app.flags.DEFINE_string('TOWER_NAME', 'JP', "Name of tower.")
tf.app.flags.DEFINE_integer('NUM_GPU', 2, "How many GPUs to use.")
tf.app.flags.DEFINE_integer('NUM_PREPROCESS_THREADS', 8,
                            "Number of preprocessing threads.")
tf.app.flags.DEFINE_string(
    'CSV_FILE', '/home/kong/4T/official3D_110W/Shuffle.csv', "Csv file path and name.")
tf.app.flags.DEFINE_string(
    'BIN_FILE', '/home/kong/4T/official3D_110W/shuffle3D64.bin', "Bin file path and name.")
# NOTE: this flag holds the initializer expression as a *string*; nothing in
# this file evaluates it.
tf.app.flags.DEFINE_string('XAVIER_INIT',
                           'tf.contrib.layers.xavier_initializer(seed=SEED)',
                           "Initialize with XAVIER_INIT.")

# Module-level snapshots of the flag values, grouped by purpose.

# Input geometry and record layout.
IMAGE_SIZE = FLAGS.IMAGE_SIZE
NUM_CHANNELS = FLAGS.CHANNEL_NUMBER
PIXEL_DATA_SIZE = FLAGS.PIXEL_DATA_SIZE
NUM_LABELS = FLAGS.LABEL_NUMBER

# Batching and scheduling.
BATCH_SIZE = FLAGS.BATCH_SIZE
EVAL_BATCH_SIZE = FLAGS.EVAL_BATCH_SIZE
NUM_EPOCHS = FLAGS.NUM_EPOCHS
NUM_PREPROCESS_THREADS = FLAGS.NUM_PREPROCESS_THREADS
NUM_GPU = FLAGS.NUM_GPU

# Randomness, initialisation and naming.
SEED = FLAGS.SEED
XAVIER_INIT = FLAGS.XAVIER_INIT
TOWER_NAME = FLAGS.TOWER_NAME

# Data locations.
CSV_FILE = FLAGS.CSV_FILE
BIN_FILE = FLAGS.BIN_FILE

# Element type used for decoded data, and the per-record layout:
# one label value followed by a 40x40x40 volume.
DTYPE = tf.float32
label_nums = 1
image_nums = 40**3

def init_bin_file(BIN_FILE):
    """Build the filename queue and record reader for the binary data file.

    Args:
        BIN_FILE: Path of the binary file made of fixed-length records.

    Returns:
        A ``(filename_queue_bin, reader_bin)`` tuple: a single-epoch
        ``tf.train.string_input_producer`` over the file and a
        ``tf.FixedLengthRecordReader`` sized for exactly one record.

    Raises:
        ValueError: If the file does not exist.
    """
    bin_file_name = [BIN_FILE]
    for f in bin_file_name:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    # num_epochs=1 makes the queue raise OutOfRangeError once the file has
    # been handed out a single time.
    filename_queue_bin = tf.train.string_input_producer(
        bin_file_name, num_epochs=1)
    # One record = 1 label value followed by a 40x40x40 volume, every value
    # being 4 bytes wide.
    label_nums = 1
    image_nums = 40 * 40 * 40
    record_bytes = (label_nums + image_nums) * 4
    reader_bin = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    return filename_queue_bin, reader_bin

def init_csv_file(CSV_FILE):
    """Build the filename queue and line reader for the CSV index file.

    Args:
        CSV_FILE: Path of the CSV file; its first line is skipped as a header.

    Returns:
        A ``(filename_queue_csv, reader_csv)`` tuple: a single-epoch
        ``tf.train.string_input_producer`` over the file and a
        ``tf.TextLineReader`` that skips one header line.

    Raises:
        ValueError: If the file does not exist.
    """
    csv_file_name = [CSV_FILE]
    for f in csv_file_name:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    filename_queue_csv = tf.train.string_input_producer(
        csv_file_name, num_epochs=1)
    # skip_header_lines is a line count; 1 is the explicit form of the
    # original ``True``.
    reader_csv = tf.TextLineReader(skip_header_lines=1)
    return filename_queue_csv, reader_csv

def get_data_without_no(filename_queue_bin, reader_bin, record_bytes, filename_queue_csv,
                        reader_csv, label_nums, image_nums, val_no, test_no):
    """Read records until one belongs to neither the validation nor test subset.

    Each step reads one binary record and one CSV row; the subset id is taken
    from the second CSV column. Reading repeats while the id equals
    ``val_no`` or ``test_no``.

    Args:
        filename_queue_bin: Filename queue feeding ``reader_bin``.
        reader_bin: Reader yielding one fixed-length binary record per read.
        record_bytes: Unused; kept for signature compatibility.
        filename_queue_csv: Filename queue feeding ``reader_csv``.
        reader_csv: Reader yielding one CSV line per read.
        label_nums: Number of leading float32 values that form the label.
        image_nums: Number of float32 values that form the image (must be
            40 * 40 * 40 to satisfy the reshape below).
        val_no: Subset id to skip (validation).
        test_no: Subset id to skip (test).

    Returns:
        The ``tf.while_loop`` result ``(subsetid, label, image)``: an int32
        scalar, an int64 label of length ``label_nums`` and a float32 image
        of shape ``[40, 40, 40, 1]``.
    """
    def getBIN():
        def getID():
            _, value = reader_csv.read(filename_queue_csv)
            # string_split needs a rank-1 input, hence the reshape.
            value_raw = tf.reshape(value, [1])
            split_values = tf.string_split(value_raw, delimiter=',')
            subsetid = tf.string_to_number(split_values.values[1], out_type=tf.int32)
            return subsetid

        _, value = reader_bin.read(filename_queue_bin)
        record = tf.decode_raw(value, tf.float32)
        label = tf.cast(tf.slice(record, [0], [label_nums]), tf.int64)
        image = tf.reshape(tf.slice(record, [label_nums], [image_nums]),
                           shape=[40, 40, 40, 1])
        return getID(), label, image

    subsetid, label, image = getBIN()
    # Keep reading while the current record is in an excluded subset.
    cond = lambda subsetid, label, image: tf.logical_or(
        tf.equal(subsetid, tf.constant(val_no, dtype=tf.int32)),
        tf.equal(subsetid, tf.constant(test_no, dtype=tf.int32)))
    doRead = lambda subsetid, label, image: getBIN()
    result = tf.while_loop(cond, doRead, [subsetid, label, image])
    return result

def get_data_with_no(filename_queue_bin, reader_bin, record_bytes,
                     filename_queue_csv, reader_csv, label_nums, image_nums, no):
    """Read records until one belongs to subset ``no``.

    Each step reads one binary record and one CSV row; the subset id is taken
    from the second CSV column. Reading repeats while the id differs from
    ``no``.

    Args:
        filename_queue_bin: Filename queue feeding ``reader_bin``.
        reader_bin: Reader yielding one fixed-length binary record per read.
        record_bytes: Unused; kept for signature compatibility.
        filename_queue_csv: Filename queue feeding ``reader_csv``.
        reader_csv: Reader yielding one CSV line per read.
        label_nums: Number of leading float32 values that form the label.
        image_nums: Number of float32 values that form the image (must be
            40 * 40 * 40 to satisfy the reshape below).
        no: Subset id to select.

    Returns:
        The ``tf.while_loop`` result ``(subsetid, label, image)``: an int32
        scalar, an int64 label of length ``label_nums`` and a float32 image
        of shape ``[40, 40, 40, 1]``.
    """
    def getBIN():
        def getID():
            _, value = reader_csv.read(filename_queue_csv)
            # string_split needs a rank-1 input, hence the reshape.
            value_raw = tf.reshape(value, [1])
            split_values = tf.string_split(value_raw, delimiter=',')
            subsetid = tf.string_to_number(split_values.values[1], out_type=tf.int32)
            return subsetid

        _, value = reader_bin.read(filename_queue_bin)
        record = tf.decode_raw(value, tf.float32)
        label = tf.cast(tf.slice(record, [0], [label_nums]), tf.int64)
        image = tf.reshape(tf.slice(record, [label_nums], [image_nums]),
                           shape=[40, 40, 40, 1])
        return getID(), label, image

    subsetid, label, image = getBIN()
    # Keep reading while the current record is not in the requested subset.
    cond = lambda subsetid, label, image: tf.not_equal(
        subsetid, tf.constant(no, dtype=tf.int32))
    doRead = lambda subsetid, label, image: getBIN()
    result = tf.while_loop(cond, doRead, [subsetid, label, image])
    return result

def get_noaug_with_no(filename_queue_bin, reader_bin, record_bytes,
                      filename_queue_csv, reader_csv, label_nums, image_nums, no):
    """Read records until one is in subset ``no`` with class_flag 1 and noaug 0.

    Each step reads one binary record and one CSV row. From the row, the
    subset id is the second column, ``class_flag`` the second-to-last column
    and ``noaug`` the last column. Reading repeats while any of the three
    conditions (subset id == ``no``, class_flag == 1, noaug == 0) fails.

    Args:
        filename_queue_bin: Filename queue feeding ``reader_bin``.
        reader_bin: Reader yielding one fixed-length binary record per read.
        record_bytes: Unused; kept for signature compatibility.
        filename_queue_csv: Filename queue feeding ``reader_csv``.
        reader_csv: Reader yielding one CSV line per read.
        label_nums: Number of leading float32 values that form the label.
        image_nums: Number of float32 values that form the image (must be
            40 * 40 * 40 to satisfy the reshape below).
        no: Subset id to select.

    Returns:
        The ``tf.while_loop`` result
        ``(subsetid, class_flag, noaug, label, image)``: three int32 scalars,
        an int64 label of length ``label_nums`` and a float32 image of shape
        ``[40, 40, 40, 1]``.
    """
    def getBIN():
        def getID():
            _, value = reader_csv.read(filename_queue_csv)
            # string_split needs a rank-1 input, hence the reshape.
            value_raw = tf.reshape(value, [1])
            split_values = tf.string_split(value_raw, delimiter=',')
            subsetid = tf.string_to_number(split_values.values[1], out_type=tf.int32)
            class_flag = tf.string_to_number(split_values.values[-2], out_type=tf.int32)
            noaug = tf.string_to_number(split_values.values[-1], out_type=tf.int32)
            return subsetid, class_flag, noaug

        _, value = reader_bin.read(filename_queue_bin)
        record = tf.decode_raw(value, tf.float32)
        label = tf.cast(tf.slice(record, [0], [label_nums]), tf.int64)
        image = tf.reshape(tf.slice(record, [label_nums], [image_nums]),
                           shape=[40, 40, 40, 1])
        subsetid, class_flag, noaug = getID()
        return subsetid, class_flag, noaug, label, image

    subsetid, class_flag, noaug, label, image = getBIN()
    # Keep reading while the record misses any of the three criteria.
    cond = lambda subsetid, class_flag, noaug, label, image: tf.logical_or(
        tf.not_equal(subsetid, tf.constant(no, dtype=tf.int32)),
        tf.logical_or(
            tf.not_equal(class_flag, tf.constant(1, dtype=tf.int32)),
            tf.not_equal(noaug, tf.constant(0, dtype=tf.int32))))
    doRead = lambda subsetid, class_flag, noaug, label, image: getBIN()
    result = tf.while_loop(cond, doRead, [subsetid, class_flag, noaug, label, image])
    return result

def get_train_data(filename_queue_bin, reader_bin, record_bytes,
                   filename_queue_csv, reader_csv, label_nums, image_nums, val_no, test_no):
    """Batch the records that are in neither the validation nor the test subset.

    Args:
        filename_queue_bin: Filename queue for the binary reader.
        reader_bin: Fixed-length record reader for the binary file.
        record_bytes: Forwarded to ``get_data_without_no``.
        filename_queue_csv: Filename queue for the CSV reader.
        reader_csv: Line reader for the CSV file.
        label_nums: Number of label values per record.
        image_nums: Number of image values per record.
        val_no: Subset id excluded as validation data.
        test_no: Subset id excluded as test data.

    Returns:
        ``(sis, labels, images)``: batched subset ids, labels and images,
        each with a leading dimension of ``BATCH_SIZE``.
    """
    subsetid, label, image = get_data_without_no(
        filename_queue_bin, reader_bin, record_bytes,
        filename_queue_csv, reader_csv, label_nums, image_nums, val_no, test_no)

    min_queue_examples = BATCH_SIZE * 20
    # NOTE(review): with num_threads > 1 the binary read and the CSV read of
    # different enqueue threads may interleave, which could pair a record
    # with the wrong CSV row -- confirm, or use a single thread.
    sis, labels, images = tf.train.batch(
        [subsetid, label, image],
        batch_size=BATCH_SIZE,
        num_threads=NUM_PREPROCESS_THREADS,
        capacity=min_queue_examples + 3 * BATCH_SIZE)
    return sis, labels, images

def get_test_data(filename_queue_bin, reader_bin, record_bytes,
                  filename_queue_csv, reader_csv, label_nums, image_nums, no):
    """Batch the records that belong to subset ``no``.

    Args:
        filename_queue_bin: Filename queue for the binary reader.
        reader_bin: Fixed-length record reader for the binary file.
        record_bytes: Forwarded to ``get_data_with_no``.
        filename_queue_csv: Filename queue for the CSV reader.
        reader_csv: Line reader for the CSV file.
        label_nums: Number of label values per record.
        image_nums: Number of image values per record.
        no: Subset id to select.

    Returns:
        ``(sis, labels, images)``: batched subset ids, labels and images,
        each with a leading dimension of ``BATCH_SIZE``.
    """
    subsetid, label, image = get_data_with_no(
        filename_queue_bin, reader_bin, record_bytes,
        filename_queue_csv, reader_csv, label_nums, image_nums, no)

    min_queue_examples = BATCH_SIZE * 20
    # NOTE(review): with num_threads > 1 the binary read and the CSV read of
    # different enqueue threads may interleave, which could pair a record
    # with the wrong CSV row -- confirm, or use a single thread.
    sis, labels, images = tf.train.batch(
        [subsetid, label, image],
        batch_size=BATCH_SIZE,
        num_threads=NUM_PREPROCESS_THREADS,
        capacity=min_queue_examples + 3 * BATCH_SIZE)
    return sis, labels, images

def get_noaug_data(filename_queue_bin, reader_bin, record_bytes,
                   filename_queue_csv, reader_csv, label_nums, image_nums, no):
    """Batch the records of subset ``no`` that have class_flag 1 and noaug 0.

    Args:
        filename_queue_bin: Filename queue for the binary reader.
        reader_bin: Fixed-length record reader for the binary file.
        record_bytes: Forwarded to ``get_noaug_with_no``.
        filename_queue_csv: Filename queue for the CSV reader.
        reader_csv: Line reader for the CSV file.
        label_nums: Number of label values per record.
        image_nums: Number of image values per record.
        no: Subset id to select.

    Returns:
        ``(sis, cfs, noaugs, labels, images)``: batched subset ids, class
        flags, noaug flags, labels and images, each with a leading dimension
        of ``BATCH_SIZE // 10``.
    """
    subsetid, class_flag, noaug, label, image = get_noaug_with_no(
        filename_queue_bin, reader_bin, record_bytes,
        filename_queue_csv, reader_csv, label_nums, image_nums, no)
    min_queue_examples = BATCH_SIZE * 20
    # NOTE(review): with num_threads > 1 the binary read and the CSV read of
    # different enqueue threads may interleave, which could pair a record
    # with the wrong CSV row -- confirm, or use a single thread.
    sis, cfs, noaugs, labels, images = tf.train.batch(
        [subsetid, class_flag, noaug, label, image],
        # Matching records are sparse, so a tenth of the usual batch is used.
        batch_size=BATCH_SIZE // 10,
        num_threads=NUM_PREPROCESS_THREADS,
        capacity=min_queue_examples + 3 * BATCH_SIZE)
    return sis, cfs, noaugs, labels, images

def main(_):
    """Build the filtered input pipeline and print subset ids until exhausted.

    Wires the binary and CSV readers into ``get_noaug_data`` for subset 5,
    starts the queue runners and prints the batched subset ids until the
    single-epoch filename queues raise ``OutOfRangeError``.

    Args:
        _: Unused argument list passed in by ``tf.app.run()``.
    """
    fqb, rb = init_bin_file(BIN_FILE)
    fqc, rc = init_csv_file(CSV_FILE)
    record_bytes = (1 + 40**3) * 4
    sis, cfs, noaugs, labels, images = get_noaug_data(
        fqb, rb, record_bytes, fqc, rc, label_nums, image_nums, 5)
    with tf.Session() as sess:
        # Local variables hold the epoch counters created by num_epochs=1.
        sess.run(tf.initialize_local_variables())
        sess.run(tf.initialize_all_variables())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                sis_, cfs_, noaugs_, labels_, images_ = sess.run(
                    [sis, cfs, noaugs, labels, images])
                print(sis_)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            # Ask the queue-runner threads to stop whatever ended the loop.
            coord.request_stop()
        # Wait for the threads before the session is closed.
        coord.join(threads)

# Script entry point: tf.app.run() parses the flags and then calls main().
if __name__ == '__main__':
    tf.app.run()