Xl python inference (#261) · apple/ml-stable-diffusion@f3a2124 (original) (raw)
`@@ -6,6 +6,7 @@
`
6
6
`import coremltools as ct
`
7
7
``
8
8
`import logging
`
``
9
`+
import json
`
9
10
``
10
11
`logging.basicConfig()
`
11
12
`logger = logging.getLogger(__name__)
`
`@@ -21,14 +22,47 @@ class CoreMLModel:
`
21
22
`""" Wrapper for running CoreML models using coremltools
`
22
23
` """
`
23
24
``
24
``
`-
def __init__(self, model_path, compute_unit):
`
25
``
`-
assert os.path.exists(model_path) and model_path.endswith(".mlpackage")
`
``
25
`+
def __init__(self, model_path, compute_unit, sources='packages'):
`
26
26
``
27
27
`logger.info(f"Loading {model_path}")
`
28
28
``
29
29
`start = time.time()
`
30
``
`-
self.model = ct.models.MLModel(
`
31
``
`-
model_path, compute_units=ct.ComputeUnit[compute_unit])
`
``
30
`+
if sources == 'packages':
`
``
31
`+
assert os.path.exists(model_path) and model_path.endswith(".mlpackage")
`
``
32
+
``
33
`+
self.model = ct.models.MLModel(
`
``
34
`+
model_path, compute_units=ct.ComputeUnit[compute_unit])
`
``
35
`+
DTYPE_MAP = {
`
``
36
`+
65552: np.float16,
`
``
37
`+
65568: np.float32,
`
``
38
`+
131104: np.int32,
`
``
39
`+
}
`
``
40
`+
self.expected_inputs = {
`
``
41
`+
input_tensor.name: {
`
``
42
`+
"shape": tuple(input_tensor.type.multiArrayType.shape),
`
``
43
`+
"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
`
``
44
`+
}
`
``
45
`+
for input_tensor in self.model._spec.description.input
`
``
46
`+
}
`
``
47
`+
elif sources == 'compiled':
`
``
48
`+
assert os.path.exists(model_path) and model_path.endswith(".mlmodelc")
`
``
49
+
``
50
`+
self.model = ct.models.CompiledMLModel(model_path, ct.ComputeUnit[compute_unit])
`
``
51
+
``
52
`+
# Grab expected inputs from metadata.json
`
``
53
`+
with open(os.path.join(model_path, 'metadata.json'), 'r') as f:
`
``
54
`+
config = json.load(f)[0]
`
``
55
+
``
56
`+
self.expected_inputs = {
`
``
57
`+
input_tensor['name']: {
`
``
58
`+
"shape": tuple(eval(input_tensor['shape'])),
`
``
59
`+
"dtype": np.dtype(input_tensor['dataType'].lower()),
`
``
60
`+
}
`
``
61
`+
for input_tensor in config['inputSchema']
`
``
62
`+
}
`
``
63
`+
else:
`
``
64
`` +
raise ValueError(f'Expected `packages` or `compiled` for sources, received {sources}')
``
``
65
+
32
66
`load_time = time.time() - start
`
33
67
`logger.info(f"Done. Took {load_time:.1f} seconds.")
`
34
68
``
`@@ -38,21 +72,6 @@ def __init__(self, model_path, compute_unit):
`
38
72
`"The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."
`
39
73
` )
`
40
74
``
41
``
-
42
``
`-
DTYPE_MAP = {
`
43
``
`-
65552: np.float16,
`
44
``
`-
65568: np.float32,
`
45
``
`-
131104: np.int32,
`
46
``
`-
}
`
47
``
-
48
``
`-
self.expected_inputs = {
`
49
``
`-
input_tensor.name: {
`
50
``
`-
"shape": tuple(input_tensor.type.multiArrayType.shape),
`
51
``
`-
"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
`
52
``
`-
}
`
53
``
`-
for input_tensor in self.model._spec.description.input
`
54
``
`-
}
`
55
``
-
56
75
`def _verify_inputs(self, **kwargs):
`
57
76
`for k, v in kwargs.items():
`
58
77
`if k in self.expected_inputs:
`
`@@ -72,7 +91,7 @@ def _verify_inputs(self, **kwargs):
`
72
91
`f"Expected shape {expected_shape}, got {v.shape} for input: {k}"
`
73
92
` )
`
74
93
`else:
`
75
``
`-
raise ValueError("Received unexpected input kwarg: {k}")
`
``
94
`+
raise ValueError(f"Received unexpected input kwarg: {k}")
`
76
95
``
77
96
`def __call__(self, **kwargs):
`
78
97
`self._verify_inputs(**kwargs)
`
`@@ -82,21 +101,77 @@ def __call__(self, **kwargs):
`
82
101
`LOAD_TIME_INFO_MSG_TRIGGER = 10 # seconds
`
83
102
``
84
103
``
85
``
`-
def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
`
86
``
`-
compute_unit):
`
87
``
`-
""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
`
``
104
`+
def get_resource_type(resources_dir: str) -> str:
`
``
105
`+
"""
`
``
106
`+
Detect resource type based on filepath extensions.
`
``
107
`+
returns:
`
``
108
`` +
packages
: for .mlpackage resources
``
``
109
`` +
`compiled`: for .mlmodelc resources
``
88
110
` """
`
89
``
`-
logger.info(f"Loading {submodule_name} mlpackage")
`
``
111
`+
directories = [f for f in os.listdir(resources_dir) if os.path.isdir(os.path.join(resources_dir, f))]
`
90
112
``
91
``
`-
fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(
`
92
``
`-
"/", "_")
`
93
``
`-
mlpackage_path = os.path.join(mlpackages_dir, fname)
`
``
113
`+
# consider directories ending with extension
`
``
114
`+
extensions = set([os.path.splitext(e)[1] for e in directories if os.path.splitext(e)[1]])
`
94
115
``
95
``
`-
if not os.path.exists(mlpackage_path):
`
96
``
`-
raise FileNotFoundError(
`
97
``
`-
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
`
``
116
`+
# if one extension present we may be able to infer sources type
`
``
117
`+
if len(set(extensions)) == 1:
`
``
118
`+
extension = extensions.pop()
`
``
119
`+
else:
`
``
120
`+
raise ValueError(f'Multiple file extensions found at {resources_dir}.'
`
``
121
`+
f'Cannot infer resource type from contents.')
`
``
122
+
``
123
`+
if extension == '.mlpackage':
`
``
124
`+
sources = 'packages'
`
``
125
`+
elif extension == '.mlmodelc':
`
``
126
`+
sources = 'compiled'
`
``
127
`+
else:
`
``
128
`+
raise ValueError(f'Did not find .mlpackage or .mlmodelc at {resources_dir}')
`
``
129
+
``
130
`+
return sources
`
``
131
+
``
132
+
``
133
`+
def _load_mlpackage(submodule_name,
`
``
134
`+
mlpackages_dir,
`
``
135
`+
model_version,
`
``
136
`+
compute_unit,
`
``
137
`+
sources=None):
`
``
138
`+
"""
`
``
139
`+
Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
`
``
140
+
``
141
`+
"""
`
``
142
+
``
143
`` +
# if sources not provided, attempt to infer `packages` or `compiled` from the
``
``
144
`+
# resources directory
`
``
145
`+
if sources is None:
`
``
146
`+
sources = get_resource_type(mlpackages_dir)
`
``
147
+
``
148
`+
if sources == 'packages':
`
``
149
`+
logger.info(f"Loading {submodule_name} mlpackage")
`
``
150
`+
fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(
`
``
151
`+
"/", "_")
`
``
152
`+
mlpackage_path = os.path.join(mlpackages_dir, fname)
`
``
153
+
``
154
`+
if not os.path.exists(mlpackage_path):
`
``
155
`+
raise FileNotFoundError(
`
``
156
`+
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
`
``
157
+
``
158
`+
elif sources == 'compiled':
`
``
159
`+
logger.info(f"Loading {submodule_name} mlmodelc")
`
``
160
+
``
161
`+
# FixMe: Submodule names and compiled resources names differ. Can change if names match in the future.
`
``
162
`+
submodule_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder"]
`
``
163
`+
compiled_names = ['TextEncoder', 'TextEncoder2', 'Unet', 'VAEDecoder', 'VAEEncoder']
`
``
164
`+
name_map = dict(zip(submodule_names, compiled_names))
`
``
165
+
``
166
`+
cname = name_map[submodule_name] + '.mlmodelc'
`
``
167
`+
mlpackage_path = os.path.join(mlpackages_dir, cname)
`
``
168
+
``
169
`+
if not os.path.exists(mlpackage_path):
`
``
170
`+
raise FileNotFoundError(
`
``
171
`+
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
`
``
172
+
``
173
`+
return CoreMLModel(mlpackage_path, compute_unit, sources=sources)
`
98
174
``
99
``
`-
return CoreMLModel(mlpackage_path, compute_unit)
`
100
175
``
101
176
`def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
`
102
177
`""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)
`
`@@ -115,5 +190,6 @@ def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):
`
115
190
``
116
191
`return CoreMLModel(mlpackage_path, compute_unit)
`
117
192
``
``
193
+
118
194
`def get_available_compute_units():
`
119
195
`return tuple(cu for cu in ct.ComputeUnit.member_names)
`