PostgreSQL Source Code: src/backend/libpq/auth-oauth.c Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

17

19#include <fcntl.h>

20

32

33

35

/*
 * Forward declarations and file-level constants.
 *
 * NOTE(review): this stretch of the listing lost most of its lines.  The
 * first prototype is oauth_init; the continuation line after it matches the
 * tail of the oauth_exchange signature given in the symbol index.  The bare
 * "{" and "};" fragments close definitions whose names were dropped --
 * presumably pg_be_oauth_mech and struct oauth_ctx among them (both appear
 * in the symbol index); confirm against the real source.
 */
37static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);

39 char **output, int *outputlen, const char **logdetail);

40

43

46

47

52

54};

55

56

58{

62};

63

64

66{

71};

72

77

78

/*
 * Terminator byte for each key/value pair of a client message.  In one
 * exchange state the client must also answer with exactly this single byte
 * (the "kvsep response" checked in oauth_exchange).
 */
79#define KVSEP 0x01

/* Key name of the pair holding the client's auth value (exactly one is required). */
80#define AUTH_KEY "auth"

/*
 * Required scheme prefix of the auth value.  NOTE(review): the comparison is
 * not visible in this listing; pg_strncasecmp in the symbol index suggests a
 * case-insensitive match -- confirm.
 */
81#define BEARER_SCHEME "Bearer "

82

83

84

85

86

87

/*
 * Mechanism-list callback -- presumably oauth_get_mechanisms(Port *port,
 * StringInfo buf) from the symbol index, whose signature matches the
 * get_mechanisms callback slot listed there.
 *
 * NOTE(review): the signature line and the whole body are missing from this
 * listing, so nothing about its behavior can be confirmed here.
 */
88static void

90{

91

94}

95

96

97

98

99

100

/*
 * SASL init callback: oauth_init(Port *port, const char *selected_mech,
 * const char *shadow_pass), per its prototype.
 *
 * Raises a protocol-violation error when the client selected an invalid SASL
 * mechanism.  Otherwise it allocates a zeroed per-exchange context, copies
 * the OAuth issuer and scope from the connection's HBA line into it, and
 * returns it (as void *) as the exchange's opaque state.
 *
 * NOTE(review): the signature, the mechanism test with its ereport() head,
 * and several lines between the allocation and the return are missing from
 * this listing.
 */
101static void *

103{

105

108 errcode(ERRCODE_PROTOCOL_VIOLATION),

109 errmsg("client selected an invalid SASL authentication mechanism"));

110

/* palloc0(): any field not assigned below starts out zeroed. */
111 ctx = palloc0(sizeof(*ctx));

112

115

117 ctx->issuer = port->hba->oauth_issuer;

118 ctx->scope = port->hba->oauth_scope;

119

121

122 return ctx;

123}

124

125

126

127

128

129

130

131

132

/*
 * SASL exchange callback: oauth_exchange(void *opaq, const char *input,
 * int inputlen, char **output, int *outputlen, const char **logdetail).
 *
 * Visible behavior:
 *  - *outputlen starts at -1.  A NULL input sets *outputlen = 0, i.e. an
 *    empty server message; the rest of that branch is missing.
 *  - Empty input is rejected as a malformed message.  So is input whose
 *    length disagrees with strlen(), e.g. because of an embedded NUL byte.
 *  - The exchange is a state machine on ctx->state.  In one state the client
 *    must send exactly one KVSEP byte.  An unknown state is an internal
 *    error (elog).
 *  - The client's first message is then checked field by field:
 *    - a channel-binding flag: 'p' is rejected as unsupported, and 'y'/'n'
 *      must be followed by ',';
 *    - an authorization identity (leading 'a') is rejected as unsupported;
 *      otherwise a ',' is required;
 *    - a key/value separator must follow;
 *    - the key/value pairs must contain an auth value, and nothing may
 *      follow the final terminator.
 *  - Returns the local 'status', which is assigned on lines not visible here.
 *
 * NOTE(review): many lines are missing from this listing.  They include the
 * signature head, the ereport() heads, the state case labels, the
 * copy/parse calls and the contents of the final if/else.
 */
133static int

135 char **output, int *outputlen, const char **logdetail)

136{

137 char *input_copy;

138 char *p;

139 char cbind_flag;

140 char *auth;

141 int status;

142

144

146 *outputlen = -1;

147

148

149

150

151

152

153

154 if (input == NULL)

155 {

157

159 *outputlen = 0;

161 }

162

163

164

165

166 if (inputlen == 0)

168 errcode(ERRCODE_PROTOCOL_VIOLATION),

169 errmsg("malformed OAUTHBEARER message"),

170 errdetail("The message is empty."));

/* A NUL byte inside the message would make strlen() stop short of inputlen. */
171 if (inputlen != strlen(input))

173 errcode(ERRCODE_PROTOCOL_VIOLATION),

174 errmsg("malformed OAUTHBEARER message"),

175 errdetail("Message length does not match input length."));

176

177 switch (ctx->state)

178 {

180

181 break;

182

184

185

186

187

188

/* In this state the only acceptable client message is a single KVSEP byte. */
189 if (inputlen != 1 || *input != KVSEP)

191 errcode(ERRCODE_PROTOCOL_VIOLATION),

192 errmsg("malformed OAUTHBEARER message"),

193 errdetail("Client did not send a kvsep response."));

194

195

198

199 default:

200 elog(ERROR, "invalid OAUTHBEARER exchange state");

202 }

203

204

206

207

208

209

210

211

212

213

214 cbind_flag = *p;

215 switch (cbind_flag)

216 {

217 case 'p':

219 errcode(ERRCODE_PROTOCOL_VIOLATION),

220 errmsg("malformed OAUTHBEARER message"),

221 errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data."));

222 break;

223

224 case 'y':

225 case 'n':

226 p++;

227 if (*p != ',')

229 errcode(ERRCODE_PROTOCOL_VIOLATION),

230 errmsg("malformed OAUTHBEARER message"),

231 errdetail("Comma expected, but found character \"%s\".",

233 p++;

234 break;

235

236 default:

238 errcode(ERRCODE_PROTOCOL_VIOLATION),

239 errmsg("malformed OAUTHBEARER message"),

240 errdetail("Unexpected channel-binding flag \"%s\".",

242 }

243

244

245

246

/* An authorization identity field is not supported; only an empty one is accepted. */
247 if (*p == 'a')

249 errcode(ERRCODE_FEATURE_NOT_SUPPORTED),

250 errmsg("client uses authorization identity, but it is not supported"));

251 if (*p != ',')

253 errcode(ERRCODE_PROTOCOL_VIOLATION),

254 errmsg("malformed OAUTHBEARER message"),

255 errdetail("Unexpected attribute \"%s\" in client-first-message.",

257 p++;

258

259

262 errcode(ERRCODE_PROTOCOL_VIOLATION),

263 errmsg("malformed OAUTHBEARER message"),

264 errdetail("Key-value separator expected, but found character \"%s\".",

266 p++;

267

269 if (!auth)

271 errcode(ERRCODE_PROTOCOL_VIOLATION),

272 errmsg("malformed OAUTHBEARER message"),

273 errdetail("Message does not contain an auth value."));

274

275

276 if (*p)

278 errcode(ERRCODE_PROTOCOL_VIOLATION),

279 errmsg("malformed OAUTHBEARER message"),

280 errdetail("Message contains additional data after the final terminator."));

281

283 {

285

288 }

289 else

290 {

293 }

294

295

297

298 return status;

299}

300

301

302

303

304

305

306

307

308

/*
 * Renders one character for display in an error detail -- sanitize_char(char
 * c), per the symbol index.  Presumably it supplies the "%s" arguments of the
 * "found character" messages in oauth_exchange.
 *
 * Visible printable ASCII (0x21-0x7E) takes one branch and every other byte
 * takes the other.  NOTE(review): the two formatting statements are missing
 * from this listing.
 *
 * The result points into a function-static 5-byte buffer.  It is overwritten
 * by the next call, so the function is not reentrant, and each result must be
 * used before sanitize_char is called again.
 */
309static char *

311{

312 static char buf[5];

313

314 if (c >= 0x21 && c <= 0x7E)

316 else

318 return buf;

319}

320

321

322

323

324

325

326static void

328{

329

330

331

332

333 static const char *key_allowed_set =

334 "abcdefghijklmnopqrstuvwxyz"

335 "ABCDEFGHIJKLMNOPQRSTUVWXYZ";

336

337 size_t span;

338

339 if (key[0])

341 errcode(ERRCODE_PROTOCOL_VIOLATION),

342 errmsg("malformed OAUTHBEARER message"),

343 errdetail("Message contains an empty key name."));

344

345 span = strspn(key, key_allowed_set);

346 if (key[span] != '\0')

348 errcode(ERRCODE_PROTOCOL_VIOLATION),

349 errmsg("malformed OAUTHBEARER message"),

350 errdetail("Message contains an invalid key name."));

351

352

353

354

355

356

357

358

360 {

361 if (0x21 <= *val && *val <= 0x7E)

362 continue;

363

364 switch (*val)

365 {

366 case ' ':

367 case '\t':

368 case '\r':

369 case '\n':

370 continue;

371

372 default:

374 errcode(ERRCODE_PROTOCOL_VIOLATION),

375 errmsg("malformed OAUTHBEARER message"),

376 errdetail("Message contains an invalid value."));

377 }

378 }

379}

380

381

382

383

384

/*
 * Parses the key/value pairs of an OAUTHBEARER client message and returns the
 * auth value -- parse_kvpairs_for_auth(char **input), per the symbol index.
 *
 * Starting at *input, each pair must end in KVSEP and contain '='.  Both
 * separators are overwritten with NUL bytes, so the input buffer is modified
 * in place.  An empty pair (KVSEP directly at the current position) ends the
 * list: *input is advanced past it and the auth value found so far is
 * returned, which may be NULL.
 *
 * The following are protocol-violation errors:
 *  - an unterminated pair;
 *  - a pair without '=';
 *  - a second auth value;
 *  - running out of input before the empty terminator pair.
 *
 * NOTE(review): the lines that compare the key with AUTH_KEY, store the
 * value, and handle other keys are missing from this listing.  The symbol
 * index lists validate_kvpair, which presumably checks each pair -- confirm.
 */
385static char *

387{

388 char *pos = *input;

389 char *auth = NULL;

390

391

392

393

394

395

396

397

398

399

400

401

402

403

404

405 while (*pos)

406 {

407 char *end;

408 char *sep;

409 char *key;

411

412

413

414

415

416 end = strchr(pos, KVSEP);

417 if (!end)

419 errcode(ERRCODE_PROTOCOL_VIOLATION),

420 errmsg("malformed OAUTHBEARER message"),

421 errdetail("Message contains an unterminated key/value pair."));

422 *end = '\0';

423

/* An empty pair is the list terminator: consume it and hand back the result. */
424 if (pos == end)

425 {

426

427 *input = pos + 1;

428 return auth;

429 }

430

431

432

433

434 sep = strchr(pos, '=');

435 if (!sep)

437 errcode(ERRCODE_PROTOCOL_VIOLATION),

438 errmsg("malformed OAUTHBEARER message"),

439 errdetail("Message contains a key without a value."));

440 *sep = '\0';

441

442

443 key = pos;

446

448 {

449 if (auth)

451 errcode(ERRCODE_PROTOCOL_VIOLATION),

452 errmsg("malformed OAUTHBEARER message"),

453 errdetail("Message contains multiple auth values."));

454

456 }

457 else

458 {

459

460

461

462

463

464 }

465

466

467 pos = end + 1;

468 }

469

471 errcode(ERRCODE_PROTOCOL_VIOLATION),

472 errmsg("malformed OAUTHBEARER message"),

473 errdetail("Message did not contain a final terminator."));

474

/* Presumably unreachable if the missing ereport() head above is ERROR level. */
476 return NULL;

477}

478

479

480

481

482

483

/*
 * Builds the server's error response for the client -- generate_error_response(
 * struct oauth_ctx *ctx, char **output, int *outputlen), per the symbol index.
 *
 * Visible behavior:
 *  - It fails with ERRCODE_INTERNAL_ERROR, with the detail sent to the log
 *    only, presumably when the HBA line lacks the issuer or scope settings.
 *    The condition line is missing.
 *  - It tests whether ctx->issuer already contains "/.well-known/".
 *  - It finally reports buf.len through *outputlen.
 *
 * NOTE(review): most of the body is missing from this listing, including how
 * the response is assembled.  The symbol index lists initStringInfo,
 * appendStringInfo* and escape_json, which presumably are used here --
 * confirm.
 */
484static void

486{

489

490

491

492

493

494

495

498 errcode(ERRCODE_INTERNAL_ERROR),

499 errmsg("OAuth is not properly configured for this user"),

500 errdetail_log("The issuer and scope parameters must be set in pg_hba.conf."));

501

502

503

504

505

508 if (strstr(ctx->issuer, "/.well-known/") == NULL)

510

512

513

514

515

516

517

519

523

526

528

530 *outputlen = buf.len;

531}

532

533

534

535

536

537

538

539

540

541

542

543

544

545

546

547

548

549

550

551

552

553

554

/*
 * Checks the syntax of the client's auth value ("header") and extracts the
 * bearer token from it -- validate_token_format(const char *header).
 *
 * Visible behavior:
 *  - An empty header returns NULL without reporting anything.
 *  - A header without the Bearer scheme is reported as a malformed token,
 *    with the detail logged only, and NULL is returned.
 *  - Spaces (presumably those after the scheme) are skipped.  A further
 *    check, whose condition is missing here, can also report a malformed
 *    token and return NULL.
 *  - The token must be a run of b64token characters (letters, digits and
 *    "-._~+/"), optionally followed by '=' padding and nothing else.
 *    Otherwise it is reported and NULL is returned.
 *
 * Every report is followed by "return NULL", so those ereport() calls
 * presumably use a non-throwing level; their level lines are missing.
 *
 * NOTE(review): the final return of the token is missing from this listing.
 */
555static const char *

557{

558 size_t span;

559 const char *token;

560 static const char *const b64token_allowed_set =

561 "abcdefghijklmnopqrstuvwxyz"

562 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

563 "0123456789-._~+/";

564

565

567

568 if (header[0] == '\0')

569 {

570

571

572

573

574

575

576

577

578 return NULL;

579 }

580

582 {

584 errcode(ERRCODE_PROTOCOL_VIOLATION),

585 errmsg("malformed OAuth bearer token"),

586 errdetail_log("Client response indicated a non-Bearer authentication scheme."));

587 return NULL;

588 }

589

590

592

593

594 while (*token == ' ')

596

597

599 {

601 errcode(ERRCODE_PROTOCOL_VIOLATION),

602 errmsg("malformed OAuth bearer token"),

604 return NULL;

605 }

606

607

608

609

610

/* Allowed characters first, then any run of trailing '=' padding. */
611 span = strspn(token, b64token_allowed_set);

612 while (token[span] == '=')

613 span++;

614

615 if (token[span] != '\0')

616 {

617

618

619

620

621

623 errcode(ERRCODE_PROTOCOL_VIOLATION),

624 errmsg("malformed OAuth bearer token"),

625 errdetail_log("Bearer token is not in the correct format."));

626 return NULL;

627 }

628

630}

631

632

633

634

635

636

/*
 * Runs the loaded validator module on the client's token and decides whether
 * the connection is authorized -- validate(Port *port, const char *auth).
 *
 * Visible behavior, in order:
 *  - An early check, whose condition is missing, returns false.
 *  - It errors out if no validator module is loaded.
 *  - It returns false if the module's callback, called with port->user_name,
 *    reports a failure.
 *  - Otherwise status is false when the validator did not authorize the
 *    token, with the reason logged.
 *  - Status is true when the HBA line sets oauth_skip_usermap.
 *  - Status is false when the validator provided no identity.
 *  - Otherwise status is the result of the usermap check (map_status ==
 *    STATUS_OK).
 *
 * NOTE(review): the jumps between these branches and the cleanup before
 * "return status" are missing from this listing.  set_authn_id,
 * check_usermap and explicit_bzero appear in the symbol index and are
 * presumably used there -- confirm.
 */
637static bool

639{

640 int map_status;

642 const char *token;

643 bool status;

644

645

647 return false;

648

649

650

651

652

655 errcode(ERRCODE_INTERNAL_ERROR),

656 errmsg("validation of OAuth token requested without a validator loaded"));

657

658

661 port->user_name, ret))

662 {

664 errcode(ERRCODE_INTERNAL_ERROR),

665 errmsg("internal error in OAuth validator module"));

666 return false;

667 }

668

669

670

671

672

675

677 {

679 errmsg("OAuth bearer authentication failed for user \"%s\"",

680 port->user_name),

681 errdetail_log("Validator failed to authorize the provided token."));

682

683 status = false;

685 }

686

/* The HBA line can opt out of user mapping, trusting the validator's decision. */
687 if (port->hba->oauth_skip_usermap)

688 {

689

690

691

692

693

694

695 status = true;

697 }

698

699

701 {

703 errmsg("OAuth bearer authentication failed for user \"%s\"",

704 port->user_name),

705 errdetail_log("Validator provided no identity."));

706

707 status = false;

709 }

710

711

714 status = (map_status == STATUS_OK);

715

717

718

719

720

721

725

726 return status;

727}

728

729

730

731

732

733

734

735

736

/*
 * Loads the OAuth validator module named by libname and records its callbacks
 * -- load_validator_library(const char *libname), per the symbol index.
 *
 * Visible behavior:
 *  - libname must be non-empty; this is asserted.
 *  - The module must define _PG_oauth_validator_module_init.
 *  - Its magic number must match the server's.
 *  - It must provide a validate_cb callback.
 *  Each failure is reported together with the module name.
 *
 * At the end a zeroed record, mcb, is allocated.  It is presumably a
 * memory-context reset callback that runs shutdown_validator_library; the
 * registration lines are missing here -- confirm.
 */
737static void

739{

742

743

744

745

746

747

748 Assert(libname && *libname);

749

752 false, NULL);

753

754

755

756

757

758 if (validator_init == NULL)

760 errmsg("%s module \"%s\" must define the symbol %s",

761 "OAuth validator", libname, "_PG_oauth_validator_module_init"));

762

765

766

767

768

769

770

773 errmsg("%s module \"%s\": magic number mismatch",

774 "OAuth validator", libname),

775 errdetail("Server has magic number 0x%08X, module has 0x%08X.",

777

778

779

780

781

784 errmsg("%s module \"%s\" must provide a %s callback",

785 "OAuth validator", libname, "validate_cb"));

786

787

790

793

794

795 mcb = palloc0(sizeof(*mcb));

797

799}

800

801

802

803

804

/*
 * Presumably shutdown_validator_library(void *arg) from the symbol index.
 *
 * NOTE(review): the signature line and the entire body are missing from this
 * listing, so its behavior cannot be confirmed here.
 */
805static void

807{

810}

811

812

813

814

815

816

817

818

/*
 * Checks that an "oauth" HBA line can be served by the configured
 * oauth_validator_libraries -- check_oauth_validator(HbaLine *hbaline,
 * int elevel, char **err_msg), per the symbol index.
 *
 * On success it returns true and leaves *err_msg NULL.  On failure it reports
 * the problem (with the HBA file and line as context where visible) and
 * returns false with *err_msg describing one of these cases:
 *  - oauth_validator_libraries is unset;
 *  - the list syntax is invalid;
 *  - several libraries are configured and the line names no "validator";
 *  - the named validator is not permitted by the list.
 *
 * Presumably a single configured library is accepted when the line names
 * none; the lines in that branch are missing.
 *
 * NOTE(review): one failure path stores a string literal in *err_msg, while
 * the others store psprintf() results.  *err_msg is therefore not always
 * palloc'd, and callers should not pfree() it.
 */
819bool

821{

823 const char *file_name = hbaline->sourcefile;

824 char *rawstring;

826

827 *err_msg = NULL;

828

830 {

832 errcode(ERRCODE_CONFIG_FILE_ERROR),

833 errmsg("oauth_validator_libraries must be set for authentication method %s",

834 "oauth"),

835 errcontext("line %d of configuration file \"%s\"",

836 line_num, file_name));

837 *err_msg = psprintf("oauth_validator_libraries must be set for authentication method %s",

838 "oauth");

839 return false;

840 }

841

842

844

846 {

847

849 errcode(ERRCODE_CONFIG_FILE_ERROR),

850 errmsg("invalid list syntax in parameter \"%s\"",

851 "oauth_validator_libraries"));

852 *err_msg = psprintf("invalid list syntax in parameter \"%s\"",

853 "oauth_validator_libraries");

854 goto done;

855 }

856

858 {

859 if (elemlist->length == 1)

860 {

862 goto done;

863 }

864

866 errcode(ERRCODE_CONFIG_FILE_ERROR),

867 errmsg("authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options"),

868 errcontext("line %d of configuration file \"%s\"",

869 line_num, file_name));

870 *err_msg = "authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options";

871 goto done;

872 }

873

875 {

877 goto done;

878 }

879

881 errcode(ERRCODE_INVALID_PARAMETER_VALUE),

882 errmsg("validator \"%s\" is not permitted by %s",

884 errcontext("line %d of configuration file \"%s\"",

885 line_num, file_name));

886 *err_msg = psprintf("validator \"%s\" is not permitted by %s",

888

/* Common exit: release the copied list string; success means no error message was set. */
889done:

891 pfree(rawstring);

892

893 return (*err_msg == NULL);

894}

static void shutdown_validator_library(void *arg)

static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)

static bool validate(Port *port, const char *auth)

char * oauth_validator_libraries_string

static const OAuthValidatorCallbacks * ValidatorCallbacks

static void validate_kvpair(const char *key, const char *val)

bool check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg)

static void * oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)

static const char * validate_token_format(const char *header)

static char * sanitize_char(char c)

static void oauth_get_mechanisms(Port *port, StringInfo buf)

static int oauth_exchange(void *opaq, const char *input, int inputlen, char **output, int *outputlen, const char **logdetail)

static char * parse_kvpairs_for_auth(char **input)

static void load_validator_library(const char *libname)

const pg_be_sasl_mech pg_be_oauth_mech

static ValidatorModuleState * validator_module_state

void set_authn_id(Port *port, const char *id)

#define PG_MAX_AUTH_TOKEN_LENGTH

static void cleanup(void)

void * load_external_function(const char *filename, const char *funcname, bool signalNotFound, void **filehandle)

int errdetail(const char *fmt,...)

int errcode(int sqlerrcode)

int errmsg(const char *fmt,...)

int errdetail_log(const char *fmt,...)

#define ereport(elevel,...)

Assert(PointerIsAligned(start, uint64))

int check_usermap(const char *usermap_name, const char *pg_user, const char *system_user, bool case_insensitive)

void escape_json(StringInfo buf, const char *str)

void list_free_deep(List *list)

char * pstrdup(const char *in)

void MemoryContextRegisterResetCallback(MemoryContext context, MemoryContextCallback *cb)

void pfree(void *pointer)

void * palloc0(Size size)

MemoryContext CurrentMemoryContext

ClientConnectionInfo MyClientConnectionInfo

#define PG_OAUTH_VALIDATOR_MAGIC

const OAuthValidatorCallbacks *(* OAuthValidatorModuleInit)(void)

#define foreach_ptr(type, var, lst)

void explicit_bzero(void *buf, size_t len)

int pg_strncasecmp(const char *s1, const char *s2, size_t n)

char * psprintf(const char *fmt,...)

#define PG_SASL_EXCHANGE_FAILURE

#define PG_SASL_EXCHANGE_CONTINUE

#define PG_SASL_EXCHANGE_SUCCESS

void appendStringInfoString(StringInfo str, const char *s)

void appendStringInfoChar(StringInfo str, char ch)

void initStringInfo(StringInfo str)

MemoryContextCallbackFunction func

ValidatorShutdownCB shutdown_cb

ValidatorValidateCB validate_cb

ValidatorStartupCB startup_cb

void(* get_mechanisms)(Port *port, StringInfo buf)

bool SplitDirectoriesString(char *rawstring, char separator, List **namelist)