# Load CIFAR-10 and carve a stratified validation split out of the training set.
from keras.datasets import cifar10

(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# The label arrays come back 2-D, e.g. (50000, 1). Keras can cope with that,
# but flattening them to 1-D keeps them simple to work with downstream.
y_train, y_test = y_train.reshape(-1), y_test.reshape(-1)

from sklearn.model_selection import train_test_split

# Hold out 30% of the training images for validation; stratifying on the
# labels keeps the per-class proportions equal in both parts.
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train,
    y_train,
    test_size=0.3,
    random_state=42,
    stratify=y_train,
)

for label, array in (('X_train', X_train), ('X_test', X_test),
                     ('y_train', y_train), ('y_test', y_test)):
    print(label + ' shape = ', array.shape)

# NOTE: CIFAR-10 is a 10-class problem, which differs from the number of
# classes in both ImageNet and the German Traffic Sign dataset, so the
# downstream code has to be adapted accordingly.
def load_bottleneck_data(training_file, validation_file):
    """Load pickled bottleneck features and labels for both data splits.

    Each file must be a pickle containing a dict with 'features' and
    'labels' entries.

    Security: ``pickle`` can execute arbitrary code while loading, so only
    pass files you produced yourself or otherwise trust.

    Args:
        training_file: Path to the training-split pickle.
        validation_file: Path to the validation-split pickle.

    Returns:
        Tuple ``(X_train, y_train, X_val, y_val)`` taken from the 'features'
        and 'labels' entries of the training and validation pickles.
    """
    print("Training file", training_file)
    print("Validation file", validation_file)

    def _unpickle(path):
        # Binary mode is required for pickle data.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    # Read both files before touching any keys, matching the original order
    # in which failures can surface.
    train_data = _unpickle(training_file)
    validation_data = _unpickle(validation_file)

    return (
        train_data['features'],
        train_data['labels'],
        validation_data['features'],
        validation_data['labels'],
    )
import pickle
import tensorflow as tf
import numpy as np
from keras.layers import Input, Flatten, Dense
from keras.models import Model

flags = tf.app.flags
FLAGS = flags.FLAGS

# command line flags
flags.DEFINE_string('training_file', '', "Bottleneck features training file (.p)")
flags.DEFINE_string('validation_file', '', "Bottleneck features validation file (.p)")
flags.DEFINE_integer('epochs', 50, "The number of epochs.")
flags.DEFINE_integer('batch_size', 256, "The batch size.")


def main(_):
    """Train a single softmax layer on pre-extracted bottleneck features.

    Loads the pickles named by --training_file / --validation_file via
    ``load_bottleneck_data``, builds a Flatten -> Dense(softmax) model, and
    fits it with sparse categorical cross-entropy, validating on the
    validation split.

    Args:
        _: Remaining argv forwarded by ``tf.app.run``; unused.
    """
    # load bottleneck data
    X_train, y_train, X_val, y_val = load_bottleneck_data(FLAGS.training_file,
                                                          FLAGS.validation_file)

    print(X_train.shape, y_train.shape)
    print(X_val.shape, y_val.shape)

    # sparse_categorical_crossentropy requires integer labels in
    # [0, nb_classes). Counting the unique training labels under-sizes the
    # output layer when a class is missing from the training split or the
    # labels are non-contiguous. Size it from the largest label seen in
    # either split instead; for contiguous labels 0..K-1 this is still K.
    nb_classes = int(max(np.max(y_train), np.max(y_val))) + 1

    # define model
    input_shape = X_train.shape[1:]
    inp = Input(shape=input_shape)
    x = Flatten()(inp)
    x = Dense(nb_classes, activation='softmax')(x)
    model = Model(inp, x)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # train model
    model.fit(X_train, y_train, epochs=FLAGS.epochs, batch_size=FLAGS.batch_size,
              validation_data=(X_val, y_val), shuffle=True)


# parses flags and calls the `main` function above
if __name__ == '__main__':
    tf.app.run()

# The command-line parameters above are defined through tf.app flags, so the
# script is run from the command line, passing --training_file and
# --validation_file (plus optionally --epochs and --batch_size), as shown in
# the usage example that follows.