Xl python inference (#261) · apple/ml-stable-diffusion@f3a2124 (original) (raw)

`@@ -6,6 +6,7 @@

`

6

6

`import coremltools as ct

`

7

7

``

8

8

`import logging

`

``

9

`+

import json

`

9

10

``

10

11

`logging.basicConfig()

`

11

12

`logger = logging.getLogger(name)

`

`@@ -21,14 +22,47 @@ class CoreMLModel:

`

21

22

`""" Wrapper for running CoreML models using coremltools

`

22

23

` """

`

23

24

``

24

``

`-

def init(self, model_path, compute_unit):

`

25

``

`-

assert os.path.exists(model_path) and model_path.endswith(".mlpackage")

`

``

25

`+

def init(self, model_path, compute_unit, sources='packages'):

`

26

26

``

27

27

`logger.info(f"Loading {model_path}")

`

28

28

``

29

29

`start = time.time()

`

30

``

`-

self.model = ct.models.MLModel(

`

31

``

`-

model_path, compute_units=ct.ComputeUnit[compute_unit])

`

``

30

`+

if sources == 'packages':

`

``

31

`+

assert os.path.exists(model_path) and model_path.endswith(".mlpackage")

`

``

32

+

``

33

`+

self.model = ct.models.MLModel(

`

``

34

`+

model_path, compute_units=ct.ComputeUnit[compute_unit])

`

``

35

`+

DTYPE_MAP = {

`

``

36

`+

65552: np.float16,

`

``

37

`+

65568: np.float32,

`

``

38

`+

131104: np.int32,

`

``

39

`+

}

`

``

40

`+

self.expected_inputs = {

`

``

41

`+

input_tensor.name: {

`

``

42

`+

"shape": tuple(input_tensor.type.multiArrayType.shape),

`

``

43

`+

"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],

`

``

44

`+

}

`

``

45

`+

for input_tensor in self.model._spec.description.input

`

``

46

`+

}

`

``

47

`+

elif sources == 'compiled':

`

``

48

`+

assert os.path.exists(model_path) and model_path.endswith(".mlmodelc")

`

``

49

+

``

50

`+

self.model = ct.models.CompiledMLModel(model_path, ct.ComputeUnit[compute_unit])

`

``

51

+

``

52

`+

Grab expected inputs from metadata.json

`

``

53

`+

with open(os.path.join(model_path, 'metadata.json'), 'r') as f:

`

``

54

`+

config = json.load(f)[0]

`

``

55

+

``

56

`+

self.expected_inputs = {

`

``

57

`+

input_tensor['name']: {

`

``

58

`+

"shape": tuple(eval(input_tensor['shape'])),

`

``

59

`+

"dtype": np.dtype(input_tensor['dataType'].lower()),

`

``

60

`+

}

`

``

61

`+

for input_tensor in config['inputSchema']

`

``

62

`+

}

`

``

63

`+

else:

`

``

64

`` +

raise ValueError(f'Expected packages or compiled for sources, received {sources}')

``

``

65

+

32

66

`load_time = time.time() - start

`

33

67

`logger.info(f"Done. Took {load_time:.1f} seconds.")

`

34

68

``

`@@ -38,21 +72,6 @@ def init(self, model_path, compute_unit):

`

38

72

`"The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."

`

39

73

` )

`

40

74

``

41

``

-

42

``

`-

DTYPE_MAP = {

`

43

``

`-

65552: np.float16,

`

44

``

`-

65568: np.float32,

`

45

``

`-

131104: np.int32,

`

46

``

`-

}

`

47

``

-

48

``

`-

self.expected_inputs = {

`

49

``

`-

input_tensor.name: {

`

50

``

`-

"shape": tuple(input_tensor.type.multiArrayType.shape),

`

51

``

`-

"dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],

`

52

``

`-

}

`

53

``

`-

for input_tensor in self.model._spec.description.input

`

54

``

`-

}

`

55

``

-

56

75

`def _verify_inputs(self, **kwargs):

`

57

76

`for k, v in kwargs.items():

`

58

77

`if k in self.expected_inputs:

`

`@@ -72,7 +91,7 @@ def _verify_inputs(self, **kwargs):

`

72

91

`f"Expected shape {expected_shape}, got {v.shape} for input: {k}"

`

73

92

` )

`

74

93

`else:

`

75

``

`-

raise ValueError("Received unexpected input kwarg: {k}")

`

``

94

`+

raise ValueError(f"Received unexpected input kwarg: {k}")

`

76

95

``

77

96

`def call(self, **kwargs):

`

78

97

`self._verify_inputs(**kwargs)

`

`@@ -82,21 +101,77 @@ def call(self, **kwargs):

`

82

101

`LOAD_TIME_INFO_MSG_TRIGGER = 10 # seconds

`

83

102

``

84

103

``

85

``

`-

def _load_mlpackage(submodule_name, mlpackages_dir, model_version,

`

86

``

`-

compute_unit):

`

87

``

`-

""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)

`

``

104

`+

def get_resource_type(resources_dir: str) -> str:

`

``

105

`+

"""

`

``

106

`+

Detect resource type based on filepath extensions.

`

``

107

`+

returns:

`

``

108

`` +

packages: for .mlpackage resources

``

``

109

`` +

'compiled`: for .mlmodelc resources

``

88

110

` """

`

89

``

`-

logger.info(f"Loading {submodule_name} mlpackage")

`

``

111

`+

directories = [f for f in os.listdir(resources_dir) if os.path.isdir(os.path.join(resources_dir, f))]

`

90

112

``

91

``

`-

fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(

`

92

``

`-

"/", "_")

`

93

``

`-

mlpackage_path = os.path.join(mlpackages_dir, fname)

`

``

113

`+

consider directories ending with extension

`

``

114

`+

extensions = set([os.path.splitext(e)[1] for e in directories if os.path.splitext(e)[1]])

`

94

115

``

95

``

`-

if not os.path.exists(mlpackage_path):

`

96

``

`-

raise FileNotFoundError(

`

97

``

`-

f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")

`

``

116

`+

if one extension present we may be able to infer sources type

`

``

117

`+

if len(set(extensions)) == 1:

`

``

118

`+

extension = extensions.pop()

`

``

119

`+

else:

`

``

120

`+

raise ValueError(f'Multiple file extensions found at {resources_dir}.'

`

``

121

`+

f'Cannot infer resource type from contents.')

`

``

122

+

``

123

`+

if extension == '.mlpackage':

`

``

124

`+

sources = 'packages'

`

``

125

`+

elif extension == '.mlmodelc':

`

``

126

`+

sources = 'compiled'

`

``

127

`+

else:

`

``

128

`+

raise ValueError(f'Did not find .mlpackage or .mlmodelc at {resources_dir}')

`

``

129

+

``

130

`+

return sources

`

``

131

+

``

132

+

``

133

`+

def _load_mlpackage(submodule_name,

`

``

134

`+

mlpackages_dir,

`

``

135

`+

model_version,

`

``

136

`+

compute_unit,

`

``

137

`+

sources=None):

`

``

138

`+

"""

`

``

139

`+

Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)

`

``

140

+

``

141

`+

"""

`

``

142

+

``

143

`` +

if sources not provided, attempt to infer packages or compiled from the

``

``

144

`+

resources directory

`

``

145

`+

if sources is None:

`

``

146

`+

sources = get_resource_type(mlpackages_dir)

`

``

147

+

``

148

`+

if sources == 'packages':

`

``

149

`+

logger.info(f"Loading {submodule_name} mlpackage")

`

``

150

`+

fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace(

`

``

151

`+

"/", "_")

`

``

152

`+

mlpackage_path = os.path.join(mlpackages_dir, fname)

`

``

153

+

``

154

`+

if not os.path.exists(mlpackage_path):

`

``

155

`+

raise FileNotFoundError(

`

``

156

`+

f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")

`

``

157

+

``

158

`+

elif sources == 'compiled':

`

``

159

`+

logger.info(f"Loading {submodule_name} mlmodelc")

`

``

160

+

``

161

`+

FixMe: Submodule names and compiled resources names differ. Can change if names match in the future.

`

``

162

`+

submodule_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder"]

`

``

163

`+

compiled_names = ['TextEncoder', 'TextEncoder2', 'Unet', 'VAEDecoder', 'VAEEncoder']

`

``

164

`+

name_map = dict(zip(submodule_names, compiled_names))

`

``

165

+

``

166

`+

cname = name_map[submodule_name] + '.mlmodelc'

`

``

167

`+

mlpackage_path = os.path.join(mlpackages_dir, cname)

`

``

168

+

``

169

`+

if not os.path.exists(mlpackage_path):

`

``

170

`+

raise FileNotFoundError(

`

``

171

`+

f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")

`

``

172

+

``

173

`+

return CoreMLModel(mlpackage_path, compute_unit, sources=sources)

`

98

174

``

99

``

`-

return CoreMLModel(mlpackage_path, compute_unit)

`

100

175

``

101

176

`def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):

`

102

177

`""" Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py)

`

`@@ -115,5 +190,6 @@ def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):

`

115

190

``

116

191

`return CoreMLModel(mlpackage_path, compute_unit)

`

117

192

``

``

193

+

118

194

`def get_available_compute_units():

`

119

195

`return tuple(cu for cu in ct.ComputeUnit.member_names)

`