Improve synchronized shuffle in datasets by ozabluda · Pull Request #8325 · keras-team/keras (original) (raw)
@@ -55,15 +55,15 @@ def load_data(path='imdb.npz', num_words=None, skip_top=0,
x_train, labels_train = f['x_train'], f['y_train']
x_test, labels_test = f['x_test'], f['y_test']
np.random.seed(seed)
np.random.shuffle(x_train)
np.random.seed(seed)
np.random.shuffle(labels_train)
np.random.seed(seed * 2)
np.random.shuffle(x_test)
np.random.seed(seed * 2)
np.random.shuffle(labels_test)
indices = np.arange(len(x_train))
np.random.shuffle(indices)
x_train = x_train[indices]
labels_train = labels_train[indices]
indices = np.arange(len(x_test))
np.random.shuffle(indices)
x_test = x_test[indices]
labels_test = labels_test[indices]
xs = np.concatenate([x_train, x_test])
labels = np.concatenate([labels_train, labels_test])