Adding FID statistics calculation as an option (can now do "train", "eval", or "fid_stats") by AlexiaJM · Pull Request #5 · yang-song/score_sde (original) (raw)

Hey,
thanks for this code. I am trying to run --mode fid_stats on my score_sde_pytorch model but it does not seem to work. Is there some way to calculate the fid stats for a custom dataset that is trained with the score_sde_pytorch version?

Thanks!

I am using @AlexiaJM's fork with my custom config. I get the following error:

  File "main.py", line 71, in <module>
    app.run(main)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "main.py", line 65, in main
    run_lib.fid_stats(FLAGS.config, FLAGS.fid_folder)
  File "/content/score_sde/run_lib.py", line 609, in fid_stats
    for batch_id in range(len(train_ds)):
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 454, in __len__
    raise TypeError("dataset length is unknown.")
TypeError: dataset length is unknown.

Update
It seems to be a dataset problem. I am using this script to convert my images to tfrecords. I fixed the problem by catching the StopIteration exception once the generator has no next items.

  batch_id = 0

  while True:
    try:
      batch = next(bpd_iter)
    except:
      break

    if jax.host_id() == 0:
      logging.info("Making FID stats -- step: %d" % (batch_id))

    batch_ = jax.tree_map(lambda x: x._numpy(), batch)
    batch_ = (batch_['image']*255).astype(np.uint8).reshape((-1, config.data.image_size, config.data.image_size, 3))

    # Force garbage collection before calling TensorFlow code for Inception network
    gc.collect()
    latents = evaluation.run_inception_distributed(batch_, inception_model,
                                                   inceptionv3=inceptionv3)
    all_pools.append(latents["pool_3"])
    # Force garbage collection again before returning to JAX code
    gc.collect()
    batch_id += 1

Feel free to correct me if I am wrong. I am not sure if this is the best/right solution.