bpo-42345: Fix three issues with typing.Literal parameters (GH-23294) · python/cpython@f03d318 (original) (raw)

`@@ -202,6 +202,20 @@ def _check_generic(cls, parameters, elen):

`

202

202

`f" actual {alen}, expected {elen}")

`

203

203

``

204

204

``

``

205

`+

def _deduplicate(params):

`

``

206

`+

Weed out strict duplicates, preserving the first of each occurrence.

`

``

207

`+

all_params = set(params)

`

``

208

`+

if len(all_params) < len(params):

`

``

209

`+

new_params = []

`

``

210

`+

for t in params:

`

``

211

`+

if t in all_params:

`

``

212

`+

new_params.append(t)

`

``

213

`+

all_params.remove(t)

`

``

214

`+

params = new_params

`

``

215

`+

assert not all_params, all_params

`

``

216

`+

return params

`

``

217

+

``

218

+

205

219

`def _remove_dups_flatten(parameters):

`

206

220

`"""An internal helper for Union creation and substitution: flatten Unions

`

207

221

` among parameters, then remove duplicates.

`

`@@ -215,38 +229,45 @@ def _remove_dups_flatten(parameters):

`

215

229

`params.extend(p[1:])

`

216

230

`else:

`

217

231

`params.append(p)

`

218

``

`-

Weed out strict duplicates, preserving the first of each occurrence.

`

219

``

`-

all_params = set(params)

`

220

``

`-

if len(all_params) < len(params):

`

221

``

`-

new_params = []

`

222

``

`-

for t in params:

`

223

``

`-

if t in all_params:

`

224

``

`-

new_params.append(t)

`

225

``

`-

all_params.remove(t)

`

226

``

`-

params = new_params

`

227

``

`-

assert not all_params, all_params

`

``

232

+

``

233

`+

return tuple(_deduplicate(params))

`

``

234

+

``

235

+

``

236

`+

def _flatten_literal_params(parameters):

`

``

237

`+

"""An internal helper for Literal creation: flatten Literals among parameters"""

`

``

238

`+

params = []

`

``

239

`+

for p in parameters:

`

``

240

`+

if isinstance(p, _LiteralGenericAlias):

`

``

241

`+

params.extend(p.args)

`

``

242

`+

else:

`

``

243

`+

params.append(p)

`

228

244

`return tuple(params)

`

229

245

``

230

246

``

231

247

`_cleanups = []

`

232

248

``

233

249

``

234

``

`-

def _tp_cache(func):

`

``

250

`+

def _tp_cache(func=None, /, *, typed=False):

`

235

251

`"""Internal wrapper caching getitem of generic types with a fallback to

`

236

252

` original function for non-hashable arguments.

`

237

253

` """

`

238

``

`-

cached = functools.lru_cache()(func)

`

239

``

`-

_cleanups.append(cached.cache_clear)

`

``

254

`+

def decorator(func):

`

``

255

`+

cached = functools.lru_cache(typed=typed)(func)

`

``

256

`+

_cleanups.append(cached.cache_clear)

`

240

257

``

241

``

`-

@functools.wraps(func)

`

242

``

`-

def inner(*args, **kwds):

`

243

``

`-

try:

`

244

``

`-

return cached(*args, **kwds)

`

245

``

`-

except TypeError:

`

246

``

`-

pass # All real errors (not unhashable args) are raised below.

`

247

``

`-

return func(*args, **kwds)

`

248

``

`-

return inner

`

``

258

`+

@functools.wraps(func)

`

``

259

`+

def inner(*args, **kwds):

`

``

260

`+

try:

`

``

261

`+

return cached(*args, **kwds)

`

``

262

`+

except TypeError:

`

``

263

`+

pass # All real errors (not unhashable args) are raised below.

`

``

264

`+

return func(*args, **kwds)

`

``

265

`+

return inner

`

249

266

``

``

267

`+

if func is not None:

`

``

268

`+

return decorator(func)

`

``

269

+

``

270

`+

return decorator

`

250

271

``

251

272

`def _eval_type(t, globalns, localns, recursive_guard=frozenset()):

`

252

273

`"""Evaluate all forward references in the given type t.

`

`@@ -319,6 +340,13 @@ def subclasscheck(self, cls):

`

319

340

`def getitem(self, parameters):

`

320

341

`return self._getitem(self, parameters)

`

321

342

``

``

343

+

``

344

`+

class _LiteralSpecialForm(_SpecialForm, _root=True):

`

``

345

`+

@_tp_cache(typed=True)

`

``

346

`+

def getitem(self, parameters):

`

``

347

`+

return self._getitem(self, parameters)

`

``

348

+

``

349

+

322

350

`@_SpecialForm

`

323

351

`def Any(self, parameters):

`

324

352

`"""Special type indicating an unconstrained type.

`

`@@ -436,7 +464,7 @@ def Optional(self, parameters):

`

436

464

`arg = _type_check(parameters, f"{self} requires a single type.")

`

437

465

`return Union[arg, type(None)]

`

438

466

``

439

``

`-

@_SpecialForm

`

``

467

`+

@_LiteralSpecialForm

`

440

468

`def Literal(self, parameters):

`

441

469

`"""Special typing form to define literal types (a.k.a. value types).

`

442

470

``

`@@ -460,7 +488,17 @@ def open_helper(file: str, mode: MODE) -> str:

`

460

488

` """

`

461

489

`# There is no '_type_check' call because arguments to Literal[...] are

`

462

490

`# values, not types.

`

463

``

`-

return _GenericAlias(self, parameters)

`

``

491

`+

if not isinstance(parameters, tuple):

`

``

492

`+

parameters = (parameters,)

`

``

493

+

``

494

`+

parameters = _flatten_literal_params(parameters)

`

``

495

+

``

496

`+

try:

`

``

497

`+

parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))

`

``

498

`+

except TypeError: # unhashable parameters

`

``

499

`+

pass

`

``

500

+

``

501

`+

return _LiteralGenericAlias(self, parameters)

`

464

502

``

465

503

``

466

504

`@_SpecialForm

`

`@@ -930,6 +968,21 @@ def subclasscheck(self, cls):

`

930

968

`return True

`

931

969

``

932

970

``

``

971

`+

def _value_and_type_iter(parameters):

`

``

972

`+

return ((p, type(p)) for p in parameters)

`

``

973

+

``

974

+

``

975

`+

class _LiteralGenericAlias(_GenericAlias, _root=True):

`

``

976

+

``

977

`+

def eq(self, other):

`

``

978

`+

if not isinstance(other, _LiteralGenericAlias):

`

``

979

`+

return NotImplemented

`

``

980

+

``

981

`+

return set(_value_and_type_iter(self.args)) == set(_value_and_type_iter(other.args))

`

``

982

+

``

983

`+

def hash(self):

`

``

984

`+

return hash(tuple(_value_and_type_iter(self.args)))

`

``

985

+

933

986

``

934

987

`class Generic:

`

935

988

`"""Abstract base class for generic types.

`