Adding FID statistics calculation as an option (can now do "train", "eval", or "fid_stats") by AlexiaJM · Pull Request #5 · yang-song/score_sde
Hey,
thanks for this code. I am trying to run `--mode fid_stats` on my score_sde_pytorch model, but it does not seem to work. Is there a way to calculate the FID stats for a custom dataset trained with the score_sde_pytorch version?
Thanks!
I am using @AlexiaJM's fork with my custom config. I get the following error:
File "main.py", line 71, in <module>
    app.run(main)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "main.py", line 65, in main
    run_lib.fid_stats(FLAGS.config, FLAGS.fid_folder)
  File "/content/score_sde/run_lib.py", line 609, in fid_stats
    for batch_id in range(len(train_ds)):
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 454, in __len__
    raise TypeError("dataset length is unknown.")
TypeError: dataset length is unknown.
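For context: a `tf.data` dataset built from a generator (as with custom tfrecords pipelines) has unknown cardinality, so `len()` raises exactly this `TypeError`. The same situation can be reproduced with a plain Python iterator, which also has no length and must instead be consumed until `StopIteration` — a minimal sketch (the generator here is a hypothetical stand-in for the dataset iterator):

```python
def batches():
    # Hypothetical generator standing in for the tfrecords dataset iterator;
    # like a tf.data dataset built from a generator, it has no known length.
    for i in range(3):
        yield {"image": i}

it = iter(batches())
# len(it) would raise TypeError: object of type 'generator' has no len()

count = 0
while True:
    try:
        batch = next(it)  # consume until the iterator is exhausted
    except StopIteration:
        break
    count += 1

print(count)  # 3
```

This is why `for batch_id in range(len(train_ds))` fails while a `while True` / `next()` loop works.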
Update
It seems to be a dataset problem. I am using this script to convert my images to tfrecords. I fixed the problem by catching the StopIteration exception once the generator has no more items.
batch_id = 0
while True:
  try:
    batch = next(bpd_iter)
  except StopIteration:
    # Dataset length is unknown, so stop once the iterator is exhausted.
    break
  if jax.host_id() == 0:
    logging.info("Making FID stats -- step: %d" % batch_id)
  batch_ = jax.tree_map(lambda x: x._numpy(), batch)
  batch_ = (batch_['image'] * 255).astype(np.uint8).reshape(
      (-1, config.data.image_size, config.data.image_size, 3))
  # Force garbage collection before calling TensorFlow code for Inception network
  gc.collect()
  latents = evaluation.run_inception_distributed(batch_, inception_model,
                                                 inceptionv3=inceptionv3)
  all_pools.append(latents["pool_3"])
  # Force garbage collection again before returning to JAX code
  gc.collect()
  batch_id += 1
Feel free to correct me if I am wrong. I am not sure if this is the best/right solution.
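For anyone following along: once the loop above has collected the `pool_3` activations into `all_pools`, the FID statistics are just the mean and covariance of those features. A minimal NumPy sketch of that final step — the random features, the 2048 pool size, and the output filename are assumptions for illustration, not the exact code in this fork:

```python
import numpy as np

# Hypothetical pooled Inception features collected across batches,
# standing in for the "pool_3" activations (2048 is the Inception-v3 pool size).
rng = np.random.default_rng(0)
all_pools = [rng.standard_normal((16, 2048)) for _ in range(4)]

pools = np.concatenate(all_pools, axis=0)   # (num_samples, 2048)
mu = np.mean(pools, axis=0)                 # per-feature mean
sigma = np.cov(pools, rowvar=False)         # feature covariance matrix

# Save in the (mu, sigma) format FID evaluation code typically loads.
np.savez("fid_stats.npz", mu=mu, sigma=sigma)

print(mu.shape, sigma.shape)  # (2048,) (2048, 2048)
```

These saved `mu`/`sigma` statistics are what get compared against the generated samples' statistics when computing the FID score.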