PostgreSQL Source Code: src/backend/libpq/auth-oauth.c Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
17
19#include <fcntl.h>
20
32
33
35
/*
 * Forward declarations of the OAUTHBEARER SASL callbacks, followed by the
 * closing lines of several file-level definitions.
 * NOTE(review): this extract is missing many source lines here, including the
 * first line of the oauth_exchange prototype and the bodies of the
 * initializers/structs closed by the "};" lines below (the symbol index at the
 * end lists pg_be_oauth_mech, so one of them is presumably that mechanism
 * table — TODO confirm against the full source).
 */
37static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
39 char **output, int *outputlen, const char **logdetail);
40
43
46
47
52
54};
55
56
58{
62};
63
64
66{
71};
72
77
78
/* Byte (0x01) that terminates each key/value pair in an OAUTHBEARER message. */
79#define KVSEP 0x01
/* Key name whose value carries the client's credentials ("auth" pair). */
80#define AUTH_KEY "auth"
/* Authentication scheme prefix presumably expected at the start of the auth value. */
81#define BEARER_SCHEME "Bearer "
82
83
84
85
86
87
/*
 * oauth_get_mechanisms(Port *port, StringInfo buf): per the symbol index at
 * the end of this extract, this is the mechanism's get_mechanisms callback.
 * NOTE(review): the signature line and the body are missing from this
 * extract; presumably it appends the supported mechanism name(s) to buf —
 * TODO confirm.
 */
88static void
90{
91
94}
95
96
97
98
99
100
/*
 * oauth_init(Port *port, const char *selected_mech, const char *shadow_pass):
 * SASL init callback for OAUTHBEARER.
 *
 * Rejects an invalid mechanism selection with ERRCODE_PROTOCOL_VIOLATION
 * (the condition and report-level lines are missing from this extract).
 * It then allocates a zero-filled context and points ctx->issuer and
 * ctx->scope at the HBA line's oauth_issuer/oauth_scope. These are pointer
 * assignments, not copies. The context is returned as the opaque
 * exchange state.
 */
101static void *
103{
105
108 errcode(ERRCODE_PROTOCOL_VIOLATION),
109 errmsg("client selected an invalid SASL authentication mechanism"));
110
111 ctx = palloc0(sizeof(*ctx));
112
115
117 ctx->issuer = port->hba->oauth_issuer;
118 ctx->scope = port->hba->oauth_scope;
119
121
122 return ctx;
123}
124
125
126
127
128
129
130
131
132
/*
 * oauth_exchange: SASL exchange callback for OAUTHBEARER. Per the symbol
 * index at the end of this extract, its full signature is
 *   static int oauth_exchange(void *opaq, const char *input, int inputlen,
 *                             char **output, int *outputlen,
 *                             const char **logdetail)
 * The first line of the parameter list is missing here.
 *
 * Visible behavior:
 *  - *outputlen starts at -1.
 *  - A NULL input sets *outputlen to 0 (the rest of that branch is missing).
 *  - Empty input, or input whose strlen() differs from inputlen, is
 *    rejected as a protocol violation.
 *  - In one exchange state (its case label is missing), the only acceptable
 *    message is a single KVSEP byte.
 *  - The gs2 header is parsed next: the channel-binding flag, the authzid
 *    field, then a KVSEP separator. The key/value pairs and the auth value
 *    follow.
 *  - The auth value is then handled by a two-way branch whose condition is
 *    missing from this extract.
 * Returns status (presumably one of the PG_SASL_EXCHANGE_* codes in the
 * symbol index — TODO confirm).
 */
133static int
135 char **output, int *outputlen, const char **logdetail)
136{
137 char *input_copy;
138 char *p;
139 char cbind_flag;
140 char *auth;
141 int status;
142
144
146 *outputlen = -1;
147
148
149
150
151
152
153
154 if (input == NULL)
155 {
157
159 *outputlen = 0;
161 }
162
163
164
165
166 if (inputlen == 0)
168 errcode(ERRCODE_PROTOCOL_VIOLATION),
169 errmsg("malformed OAUTHBEARER message"),
170 errdetail("The message is empty."));
/*
 * strlen() stops at the first NUL, so a mismatch typically means the buffer
 * contains an embedded NUL byte.
 * NOTE(review): this compares an int with a size_t; a negative inputlen
 * converts to a very large unsigned value and is still rejected here.
 */
171 if (inputlen != strlen(input))
173 errcode(ERRCODE_PROTOCOL_VIOLATION),
174 errmsg("malformed OAUTHBEARER message"),
175 errdetail("Message length does not match input length."));
176
177 switch (ctx->state)
178 {
180
181 break;
182
184
185
186
187
188
/* In this state (case label missing) the client must reply with exactly one KVSEP byte. */
189 if (inputlen != 1 || *input != KVSEP)
191 errcode(ERRCODE_PROTOCOL_VIOLATION),
192 errmsg("malformed OAUTHBEARER message"),
193 errdetail("Client did not send a kvsep response."));
194
195
198
199 default:
200 elog(ERROR, "invalid OAUTHBEARER exchange state");
202 }
203
204
206
207
208
209
210
211
212
213
/*
 * gs2 channel-binding flag. 'p' (the client sent channel-binding data) is
 * refused; 'y' and 'n' must be followed by a comma; any other byte is an
 * error. NOTE(review): p's initialization is on a missing line; presumably
 * it points into input_copy — TODO confirm.
 */
214 cbind_flag = *p;
215 switch (cbind_flag)
216 {
217 case 'p':
219 errcode(ERRCODE_PROTOCOL_VIOLATION),
220 errmsg("malformed OAUTHBEARER message"),
221 errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data."));
222 break;
223
224 case 'y':
225 case 'n':
226 p++;
227 if (*p != ',')
229 errcode(ERRCODE_PROTOCOL_VIOLATION),
230 errmsg("malformed OAUTHBEARER message"),
231 errdetail("Comma expected, but found character \"%s\".",
233 p++;
234 break;
235
236 default:
238 errcode(ERRCODE_PROTOCOL_VIOLATION),
239 errmsg("malformed OAUTHBEARER message"),
240 errdetail("Unexpected channel-binding flag \"%s\".",
242 }
243
244
245
246
/*
 * authzid field: one starting with 'a' is rejected as unsupported.
 * Otherwise the field must be empty, i.e. the next byte must be ','.
 */
247 if (*p == 'a')
249 errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
250 errmsg("client uses authorization identity, but it is not supported"));
251 if (*p != ',')
253 errcode(ERRCODE_PROTOCOL_VIOLATION),
254 errmsg("malformed OAUTHBEARER message"),
255 errdetail("Unexpected attribute \"%s\" in client-first-message.",
257 p++;
258
259
/* The gs2 header must be followed by a KVSEP byte (the condition line is missing here). */
262 errcode(ERRCODE_PROTOCOL_VIOLATION),
263 errmsg("malformed OAUTHBEARER message"),
264 errdetail("Key-value separator expected, but found character \"%s\".",
266 p++;
267
/* auth is assigned on a missing line; presumably by parse_kvpairs_for_auth(&p) — TODO confirm. */
269 if (!auth)
271 errcode(ERRCODE_PROTOCOL_VIOLATION),
272 errmsg("malformed OAUTHBEARER message"),
273 errdetail("Message does not contain an auth value."));
274
275
/* Nothing may follow the final terminator. */
276 if (*p)
278 errcode(ERRCODE_PROTOCOL_VIOLATION),
279 errmsg("malformed OAUTHBEARER message"),
280 errdetail("Message contains additional data after the final terminator."));
281
/* NOTE(review): the branch condition and both branch bodies are missing from this extract. */
283 {
285
288 }
289 else
290 {
293 }
294
295
297
298 return status;
299}
300
301
302
303
304
305
306
307
308
/*
 * sanitize_char(char c): per the symbol index, returns a string form of a
 * single character. It is presumably used for the %s in the "found
 * character" errdetail() messages — TODO confirm.
 *
 * Characters in 0x21..0x7E take one branch and all others take the other.
 * The statements that fill buf are missing from this extract.
 *
 * NOTE(review): the result points at a function-level static buffer, so
 * each call overwrites the previous result and the function is not
 * reentrant.
 */
309static char *
311{
312 static char buf[5];
313
314 if (c >= 0x21 && c <= 0x7E)
316 else
318 return buf;
319}
320
321
322
323
324
325
/*
 * validate_kvpair(const char *key, const char *val): rejects a key/value
 * pair that does not use the allowed characters. Violations are reported
 * with ERRCODE_PROTOCOL_VIOLATION; the report-level lines are missing from
 * this extract.
 *
 *  - key must be non-empty and consist only of ASCII letters.
 *  - every byte of val must be in 0x21..0x7E or be one of space, tab,
 *    CR or LF.
 */
326static void
328{
329
330
331
332
333 static const char *key_allowed_set =
334 "abcdefghijklmnopqrstuvwxyz"
335 "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
336
337 size_t span;
338
/*
 * NOTE(review): the identifier in this condition was lost in extraction.
 * Judging by the error text it presumably tests for an empty key
 * (e.g. !key[0]) — TODO confirm.
 */
339 if ([0])
341 errcode(ERRCODE_PROTOCOL_VIOLATION),
342 errmsg("malformed OAUTHBEARER message"),
343 errdetail("Message contains an empty key name."));
344
/* Every byte of key must come from key_allowed_set. */
345 span = strspn(key, key_allowed_set);
346 if (key[span] != '\0')
348 errcode(ERRCODE_PROTOCOL_VIOLATION),
349 errmsg("malformed OAUTHBEARER message"),
350 errdetail("Message contains an invalid key name."));
351
352
353
354
355
356
357
358
/* The loop header is missing here; the body examines *val, so it presumably walks val byte by byte. */
360 {
361 if (0x21 <= *val && *val <= 0x7E)
362 continue;
363
364 switch (*val)
365 {
366 case ' ':
367 case '\t':
368 case '\r':
369 case '\n':
370 continue;
371
372 default:
374 errcode(ERRCODE_PROTOCOL_VIOLATION),
375 errmsg("malformed OAUTHBEARER message"),
376 errdetail("Message contains an invalid value."));
377 }
378 }
379}
380
381
382
383
384
/*
 * parse_kvpairs_for_auth(char **input): parses the KVSEP-terminated
 * "key=value" pairs starting at *input and returns the auth value.
 *
 * The buffer is modified in place: each KVSEP terminator and each '=' is
 * overwritten with '\0'. Parsing stops at an empty pair, i.e. a KVSEP with
 * nothing before it. At that point *input is advanced past it and the auth
 * value found so far is returned; it may be NULL.
 *
 * These cases are reported with ERRCODE_PROTOCOL_VIOLATION:
 *  - a pair with no KVSEP terminator;
 *  - a pair with no '=';
 *  - a second auth value (the branch condition is missing here; presumably
 *    a comparison of key against AUTH_KEY);
 *  - input that ends without the final empty pair.
 */
385static char *
387{
388 char *pos = *input;
389 char *auth = NULL;
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405 while (*pos)
406 {
407 char *end;
408 char *sep;
409 char *key;
411
412
413
414
415
416 end = strchr(pos, KVSEP);
417 if (!end)
419 errcode(ERRCODE_PROTOCOL_VIOLATION),
420 errmsg("malformed OAUTHBEARER message"),
421 errdetail("Message contains an unterminated key/value pair."));
422 *end = '\0';
423
/* An empty pair marks the end of the key/value list. */
424 if (pos == end)
425 {
426
427 *input = pos + 1;
428 return auth;
429 }
430
431
432
433
/* The pair was NUL-terminated at end above, so this search cannot run past it. */
434 sep = strchr(pos, '=');
435 if (!sep)
437 errcode(ERRCODE_PROTOCOL_VIOLATION),
438 errmsg("malformed OAUTHBEARER message"),
439 errdetail("Message contains a key without a value."));
440 *sep = '\0';
441
442
443 key = pos;
446
448 {
449 if (auth)
451 errcode(ERRCODE_PROTOCOL_VIOLATION),
452 errmsg("malformed OAUTHBEARER message"),
453 errdetail("Message contains multiple auth values."));
454
456 }
457 else
458 {
/* No statements are visible in this branch of the extract. */
459
460
461
462
463
464 }
465
466
467 pos = end + 1;
468 }
469
471 errcode(ERRCODE_PROTOCOL_VIOLATION),
472 errmsg("malformed OAUTHBEARER message"),
473 errdetail("Message did not contain a final terminator."));
474
476 return NULL;
477}
478
479
480
481
482
483
/*
 * generate_error_response(struct oauth_ctx *ctx, char **output,
 * int *outputlen): presumably builds the server's error reply to the
 * client. Only fragments are visible in this extract.
 *
 * Visible behavior:
 *  - If the issuer and scope are not configured (condition line missing),
 *    an ERRCODE_INTERNAL_ERROR report points the administrator at
 *    pg_hba.conf.
 *  - ctx->issuer is checked for a "/.well-known/" component; the effect of
 *    that check is on missing lines.
 *  - The length of buf is stored in *outputlen.
 *
 * NOTE(review): the symbol index lists initStringInfo, escape_json and the
 * appendStringInfo* helpers, so the reply is presumably a JSON document
 * built in a StringInfo — TODO confirm against the full source.
 */
484static void
486{
489
490
491
492
493
494
495
498 errcode(ERRCODE_INTERNAL_ERROR),
499 errmsg("OAuth is not properly configured for this user"),
500 errdetail_log("The issuer and scope parameters must be set in pg_hba.conf."));
501
502
503
504
505
508 if (strstr(ctx->issuer, "/.well-known/") == NULL)
510
512
513
514
515
516
517
519
523
526
528
530 *outputlen = buf.len;
531}
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
/*
 * validate_token_format(const char *header): checks the client's auth value
 * and returns NULL when it is not an acceptable Bearer credential. On
 * success it presumably returns a pointer to the token inside header; the
 * final return statement is missing from this extract.
 *
 *  - An empty header returns NULL. No report is visible for that case.
 *  - A non-Bearer scheme, a failed check after the leading spaces are
 *    skipped (condition missing), or a token containing characters outside
 *    the b64token set is reported as "malformed OAuth bearer token". NULL is
 *    then returned.
 *  - Execution continues to "return NULL" after each report, so the report
 *    level is presumably below ERROR; the level lines are missing here.
 *  - The token must match RFC 6750's b64token: letters, digits and
 *    "-._~+/", followed by optional '=' padding at the end only.
 */
555static const char *
557{
558 size_t span;
559 const char *token;
560 static const char *const b64token_allowed_set =
561 "abcdefghijklmnopqrstuvwxyz"
562 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
563 "0123456789-._~+/";
564
565
567
568 if (header[0] == '\0')
569 {
570
571
572
573
574
575
576
577
578 return NULL;
579 }
580
582 {
584 errcode(ERRCODE_PROTOCOL_VIOLATION),
585 errmsg("malformed OAuth bearer token"),
586 errdetail_log("Client response indicated a non-Bearer authentication scheme."));
587 return NULL;
588 }
589
590
592
593
/* Skip any spaces between the scheme and the token. */
594 while (*token == ' ')
596
597
599 {
601 errcode(ERRCODE_PROTOCOL_VIOLATION),
602 errmsg("malformed OAuth bearer token"),
604 return NULL;
605 }
606
607
608
609
610
/* b64token characters first, then any number of trailing '=' padding bytes. */
611 span = strspn(token, b64token_allowed_set);
612 while (token[span] == '=')
613 span++;
614
615 if (token[span] != '\0')
616 {
617
618
619
620
621
623 errcode(ERRCODE_PROTOCOL_VIOLATION),
624 errmsg("malformed OAuth bearer token"),
625 errdetail_log("Bearer token is not in the correct format."));
626 return NULL;
627 }
628
630}
632
633
634
635
636
/*
 * validate(Port *port, const char *auth): decides whether the client's auth
 * value authenticates port->user_name. Returns true on success.
 *
 * Visible flow, in order:
 *  1. An early "return false". The condition is missing; presumably it is a
 *     failed token-format check — TODO confirm.
 *  2. An ERRCODE_INTERNAL_ERROR report if no validator is loaded.
 *  3. The validator callback (presumably validate_cb) is called with
 *     port->user_name and ret. If it reports failure, an "internal error in
 *     OAuth validator module" is raised and false is returned.
 *  4. If the validator did not authorize the token, authentication fails.
 *  5. With port->hba->oauth_skip_usermap set, status becomes true without a
 *     usermap check.
 *  6. If the validator provided no identity, authentication fails.
 *  7. Otherwise status is true only when map_status == STATUS_OK.
 *
 * NOTE(review): the statements after each "status = ..." that leave the
 * function early are missing; presumably they jump to shared cleanup before
 * the final return — TODO confirm.
 */
637static bool
639{
640 int map_status;
642 const char *token;
643 bool status;
644
645
647 return false;
648
649
650
651
652
655 errcode(ERRCODE_INTERNAL_ERROR),
656 errmsg("validation of OAuth token requested without a validator loaded"));
657
658
661 port->user_name, ret))
662 {
664 errcode(ERRCODE_INTERNAL_ERROR),
665 errmsg("internal error in OAuth validator module"));
666 return false;
667 }
668
669
670
671
672
675
677 {
679 errmsg("OAuth bearer authentication failed for user \"%s\"",
680 port->user_name),
681 errdetail_log("Validator failed to authorize the provided token."));
682
683 status = false;
685 }
686
687 if (port->hba->oauth_skip_usermap)
688 {
689
690
691
692
693
694
695 status = true;
697 }
698
699
701 {
703 errmsg("OAuth bearer authentication failed for user \"%s\"",
704 port->user_name),
705 errdetail_log("Validator provided no identity."));
706
707 status = false;
709 }
710
711
714 status = (map_status == STATUS_OK);
715
717
718
719
720
721
725
726 return status;
727}
728
729
730
731
732
733
734
735
736
/*
 * load_validator_library(const char *libname): loads the OAuth validator
 * module named libname and checks what it exports.
 *
 *  - libname must be non-empty (Assert).
 *  - The init symbol "_PG_oauth_validator_module_init" is looked up with
 *    signalNotFound = false. Per load_external_function's signature in the
 *    symbol index, a missing symbol therefore comes back as NULL, and it is
 *    reported explicitly below.
 *  - The module's magic number must match the server's, and the module must
 *    provide a validate_cb callback.
 *  - mcb is allocated with palloc0. Given the symbol index, it is presumably
 *    a MemoryContextCallback registered so that shutdown_validator_library()
 *    runs on context reset — TODO confirm.
 */
737static void
739{
742
743
744
745
746
747
748 Assert(libname && *libname);
749
752 false, NULL);
753
754
755
756
757
758 if (validator_init == NULL)
760 errmsg("%s module \"%s\" must define the symbol %s",
761 "OAuth validator", libname, "_PG_oauth_validator_module_init"));
762
765
766
767
768
769
770
773 errmsg("%s module \"%s\": magic number mismatch",
774 "OAuth validator", libname),
775 errdetail("Server has magic number 0x%08X, module has 0x%08X.",
777
778
779
780
781
784 errmsg("%s module \"%s\" must provide a %s callback",
785 "OAuth validator", libname, "validate_cb"));
786
787
790
793
794
795 mcb = palloc0(sizeof(*mcb));
797
799}
800
801
802
803
804
/*
 * shutdown_validator_library(void *arg): per the symbol index, this is the
 * function's signature. Its signature line and body are missing from this
 * extract. Presumably it invokes the module's shutdown_cb, if one was
 * provided — TODO confirm.
 */
805static void
807{
810}
811
812
813
814
815
816
817
818
/*
 * check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg):
 * checks the validator configuration for an "oauth" pg_hba.conf line.
 *
 * Returns true with *err_msg left NULL when the configuration is
 * acceptable. Otherwise it reports the problem and sets *err_msg; the
 * report level is on missing lines, presumably elevel.
 *
 *  - oauth_validator_libraries must be set (condition line missing).
 *  - The setting is split into a list; invalid list syntax is an error.
 *  - If the HBA line names no validator (condition missing), a single
 *    configured library is accepted and multiple libraries are an error.
 *  - Otherwise the named validator must be permitted by the list.
 *
 * NOTE(review): *err_msg is a psprintf() result on most paths but a string
 * literal on the "multiple options" path, so callers cannot safely pfree it.
 */
819bool
821{
823 const char *file_name = hbaline->sourcefile;
824 char *rawstring;
826
827 *err_msg = NULL;
828
830 {
832 errcode(ERRCODE_CONFIG_FILE_ERROR),
833 errmsg("oauth_validator_libraries must be set for authentication method %s",
834 "oauth"),
835 errcontext("line %d of configuration file \"%s\"",
836 line_num, file_name));
837 *err_msg = psprintf("oauth_validator_libraries must be set for authentication method %s",
838 "oauth");
839 return false;
840 }
841
842
844
846 {
847
849 errcode(ERRCODE_CONFIG_FILE_ERROR),
850 errmsg("invalid list syntax in parameter \"%s\"",
851 "oauth_validator_libraries"));
852 *err_msg = psprintf("invalid list syntax in parameter \"%s\"",
853 "oauth_validator_libraries");
854 goto done;
855 }
856
858 {
859 if (elemlist->length == 1)
860 {
862 goto done;
863 }
864
866 errcode(ERRCODE_CONFIG_FILE_ERROR),
867 errmsg("authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options"),
868 errcontext("line %d of configuration file \"%s\"",
869 line_num, file_name));
870 *err_msg = "authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options";
871 goto done;
872 }
873
875 {
877 goto done;
878 }
879
881 errcode(ERRCODE_INVALID_PARAMETER_VALUE),
882 errmsg("validator \"%s\" is not permitted by %s",
884 errcontext("line %d of configuration file \"%s\"",
885 line_num, file_name));
886 *err_msg = psprintf("validator \"%s\" is not permitted by %s",
888
/* Shared exit for every path after the list was split: rawstring is freed here. */
889done:
891 pfree(rawstring);
892
893 return (*err_msg == NULL);
894}
/*
 * NOTE(review): the lines below are not compilable C. They appear to be a
 * symbol cross-reference index carried over from the page this file was
 * extracted from: declarations of functions, macros and variables used
 * above, with no bodies or terminating semicolons.
 */
static void shutdown_validator_library(void *arg)
static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
static bool validate(Port *port, const char *auth)
char * oauth_validator_libraries_string
static const OAuthValidatorCallbacks * ValidatorCallbacks
static void validate_kvpair(const char *key, const char *val)
bool check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg)
static void * oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
static const char * validate_token_format(const char *header)
static char * sanitize_char(char c)
static void oauth_get_mechanisms(Port *port, StringInfo buf)
static int oauth_exchange(void *opaq, const char *input, int inputlen, char **output, int *outputlen, const char **logdetail)
static char * parse_kvpairs_for_auth(char **input)
static void load_validator_library(const char *libname)
const pg_be_sasl_mech pg_be_oauth_mech
static ValidatorModuleState * validator_module_state
void set_authn_id(Port *port, const char *id)
#define PG_MAX_AUTH_TOKEN_LENGTH
static void cleanup(void)
void * load_external_function(const char *filename, const char *funcname, bool signalNotFound, void **filehandle)
int errdetail(const char *fmt,...)
int errcode(int sqlerrcode)
int errmsg(const char *fmt,...)
int errdetail_log(const char *fmt,...)
#define ereport(elevel,...)
Assert(PointerIsAligned(start, uint64))
int check_usermap(const char *usermap_name, const char *pg_user, const char *system_user, bool case_insensitive)
void escape_json(StringInfo buf, const char *str)
void list_free_deep(List *list)
char * pstrdup(const char *in)
void MemoryContextRegisterResetCallback(MemoryContext context, MemoryContextCallback *cb)
void pfree(void *pointer)
void * palloc0(Size size)
MemoryContext CurrentMemoryContext
ClientConnectionInfo MyClientConnectionInfo
#define PG_OAUTH_VALIDATOR_MAGIC
const OAuthValidatorCallbacks *(* OAuthValidatorModuleInit)(void)
#define foreach_ptr(type, var, lst)
void explicit_bzero(void *buf, size_t len)
int pg_strncasecmp(const char *s1, const char *s2, size_t n)
char * psprintf(const char *fmt,...)
#define PG_SASL_EXCHANGE_FAILURE
#define PG_SASL_EXCHANGE_CONTINUE
#define PG_SASL_EXCHANGE_SUCCESS
void appendStringInfoString(StringInfo str, const char *s)
void appendStringInfoChar(StringInfo str, char ch)
void initStringInfo(StringInfo str)
MemoryContextCallbackFunction func
ValidatorShutdownCB shutdown_cb
ValidatorValidateCB validate_cb
ValidatorStartupCB startup_cb
void(* get_mechanisms)(Port *port, StringInfo buf)
bool SplitDirectoriesString(char *rawstring, char separator, List **namelist)