feat: Optimize hub.py download by andi4191 · Pull Request #1022 · pytorch/TensorRT (original) (raw)
Reformatting /workspace/py/torch_tensorrt/logging.py Reformatting /workspace/py/torch_tensorrt/_Input.py Reformatting /workspace/py/torch_tensorrt/_Device.py Reformatting /workspace/py/torch_tensorrt/_enums.py Reformatting /workspace/py/torch_tensorrt/ptq.py Reformatting /workspace/py/torch_tensorrt/_util.py Reformatting /workspace/py/torch_tensorrt/_compile.py Reformatting /workspace/py/torch_tensorrt/init.py Reformatting /workspace/py/torch_tensorrt/ts/_compile_spec.py Reformatting /workspace/py/torch_tensorrt/ts/_compiler.py Reformatting /workspace/py/torch_tensorrt/ts/init.py Reformatting /workspace/py/setup.py --- /workspace/tests/modules/hub.py (original) +++ /workspace/tests/modules/hub.py (reformatted) @@ -88,6 +88,7 @@ def forward(self, x): return F.adaptive_avg_pool2d(x, (5, 5))
Sample Nested Module (for module-level fallback testing)
class ModuleFallbackSub(nn.Module):
@@ -98,6 +99,7 @@
def forward(self, x):
return self.relu(self.conv(x))
class ModuleFallbackMain(nn.Module):
@@ -110,6 +112,7 @@ def forward(self, x): return self.relu(self.conv(self.layer1(x)))
Sample Looping Modules (for loop fallback testing)
class LoopFallbackEval(nn.Module):
@@ -122,6 +125,7 @@ add_list = torch.cat((add_list, torch.tensor([x.shape[1]]).to(x.device)), 0) return x + add_list
class LoopFallbackNoEval(nn.Module):
def init(self):
@@ -131,6 +135,7 @@ for _ in range(x.shape[1]): x = x + torch.ones_like(x) return x +
Sample Conditional Model (for testing partitioning and fallback in conditionals)
class FallbackIf(torch.nn.Module): @@ -156,21 +161,23 @@ x = self.conv1(x) return x
- class ModelManifest:
- def init(self): self.version_matches = False if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0: self.manifest = {}
self.manifest.update({'version' : torch_version})
self.manifest.update({'version': torch_version}) else: with open(MANIFEST_FILE, 'r') as f: self.manifest = json.load(f) if self.manifest['version'] == torch_version: self.version_matches = True else:
print("Torch version: {} mismatches with manifest's version: {}. Re-downloading all models".format(torch_version, self.manifest['version']))
print("Torch version: {} mismatches with manifest's version: {}. Re-downloading all models".format(
torch_version, self.manifest['version'])) self.manifest["version"] = torch_version
def download(self, models): if self.version_matches:
@@ -194,13 +201,13 @@ record = json.dumps(manifest_record) f.write(record) f.truncate()
- def get_manifest(self): return self.manifest
- def if_version_matches(self): return self.version_matches
- def get(self, n, m): print("Downloading {}".format(n)) m["model"] = m["model"].eval().cuda()
@@ -214,8 +221,9 @@ if m["path"] == "both" or m["path"] == "script": script_model = torch.jit.script(m["model"]) torch.jit.save(script_model, script_filename)
self.manifest.update({n : [traced_filename, script_filename]})
self.manifest.update({n: [traced_filename, script_filename]})
def export_model(model, model_name, version_matches): if version_matches and os.path.exists(model_name): @@ -225,7 +233,7 @@ torch.jit.save(model, model_name)
-def generate_custom_models(manifest, matches = False): +def generate_custom_models(manifest, matches=False): # Pool model = Pool().eval().cuda() x = torch.ones([1, 3, 10, 10]).cuda() @@ -252,7 +260,8 @@ loop_fallback_no_eval_script_model = torch.jit.script(loop_fallback_no_eval_model) scripted_loop_fallback_no_eval_name = "loop_fallback_no_eval_scripted.jit.pt" export_model(loop_fallback_no_eval_script_model, scripted_loop_fallback_no_eval_name, matches)
- manifest.update({"torchtrt_loop_fallback_no_eval": [scripted_loop_fallback_name, scripted_loop_fallback_no_eval_name]})
- manifest.update(
{"torchtrt_loop_fallback_no_eval": [scripted_loop_fallback_name, scripted_loop_fallback_no_eval_name]})
conditional_model = FallbackIf().eval().cuda() Conditional
@@ -289,7 +298,7 @@ traced_bert_uncased_name = "bert_case_uncased_traced.jit.pt" traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) export_model(traced_model, traced_bert_uncased_name, matches)
- manifest.update({"torchtrt_bert_case_uncased" : [traced_bert_uncased_name]})
- manifest.update({"torchtrt_bert_case_uncased": [traced_bert_uncased_name]})
manifest = ModelManifest() Reformatting /workspace/tests/py/test_api_dla.py Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py Reformatting /workspace/tests/py/test_multi_gpu.py Reformatting /workspace/tests/py/test_trt_intercompatibility.py Reformatting /workspace/tests/py/model_test_case.py Reformatting /workspace/tests/py/test_qat_trt_accuracy.py Reformatting /workspace/tests/py/test_to_backend_api.py Reformatting /workspace/tests/modules/hub.py Reformatting /workspace/tests/py/test_api.py Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py Reformatting /workspace/tests/py/test_ptq_to_backend.py ERROR: Some files do not conform to style guidelines