PostgreSQL Source Code: src/interfaces/libpq/fe-auth-scram.c Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
16
22
23
24
26 const char *sasl_mechanism);
28 char *input, int inputlen,
29 char **output, int *outputlen);
32
38};
39
40
41
42
43
44typedef enum
45{
51
52typedef struct
53{
55
56
60
61
64
65
70
71
77
78
82
88 const char **errstr);
90 const char *client_final_message_without_proof,
91 uint8 *result, const char **errstr);
92
93
94
95
96static void *
99 const char *sasl_mechanism)
100{
102 char *prep_password;
104
105 Assert(sasl_mechanism != NULL);
106
109 return NULL;
115
116 state->sasl_mechanism = strdup(sasl_mechanism);
117 if (->sasl_mechanism)
118 {
120 return NULL;
121 }
122
124 {
125
128 {
131 return NULL;
132 }
134 {
135 prep_password = strdup(password);
136 if (!prep_password)
137 {
140 return NULL;
141 }
142 }
143 state->password = prep_password;
144 }
145
147}
148
149
150
151
152
153
154
155
156
157static bool
159{
161
162
163 if (state == NULL)
164 return false;
165
166
168 return false;
169
170
172 return false;
173
174
175 return true;
176}
177
178
179
180
181static void
183{
185
188
189
191 free(state->client_first_message_bare);
192 free(state->client_final_message_without_proof);
193
194
195 free(state->server_first_message);
198
199
200 free(state->server_final_message);
201
203}
204
205
206
207
210 char *input, int inputlen,
211 char **output, int *outputlen)
212{
215 const char *errstr = NULL;
216
218 *outputlen = 0;
219
220
221
222
223
225 {
226 if (inputlen == 0)
227 {
230 }
231 if (inputlen != strlen(input))
232 {
235 }
236 }
237
238 switch (state->state)
239 {
241
245
246 *outputlen = strlen(*output);
249
251
254
258
259 *outputlen = strlen(*output);
262
264 {
265 bool match;
266
267
270
271
272
273
274
276 {
279 }
280
281 if (!match)
282 {
284 }
286 state->conn->client_finished_auth = true;
288 }
289
290 default:
291
293 break;
294 }
295
297}
298
299
300
301
302
303
304
305
306
307static char *
309{
310 char *begin = *input;
311 char *end;
312
313 if (*begin != attr)
314 {
316 "malformed SCRAM message (attribute \"%c\" expected)",
317 attr);
318 return NULL;
319 }
320 begin++;
321
322 if (*begin != '=')
323 {
325 "malformed SCRAM message (expected character \"=\" for attribute \"%c\")",
326 attr);
327 return NULL;
328 }
329 begin++;
330
331 end = begin;
332 while (*end && *end != ',')
333 end++;
334
335 if (*end)
336 {
337 *end = '\0';
338 *input = end + 1;
339 }
340 else
342
343 return begin;
344}
345
346
347
348
349static char *
351{
354 char *result;
355 int channel_info_len;
356 int encoded_len;
358
359
360
361
362
364 {
366 return NULL;
367 }
368
370
371 state->client_nonce = malloc(encoded_len + 1);
372 if (state->client_nonce == NULL)
373 {
375 return NULL;
376 }
378 state->client_nonce, encoded_len);
379 if (encoded_len < 0)
380 {
382 return NULL;
383 }
384 state->client_nonce[encoded_len] = '\0';
385
386
387
388
389
390
391
392
394
395
396
397
399 {
402 }
403#ifdef USE_SSL
406 {
407
408
409
411 }
412#endif
413 else
414 {
415
416
417
419 }
420
422 goto oom_error;
423
424 channel_info_len = buf.len;
425
428 goto oom_error;
429
430
431
432
433
434 state->client_first_message_bare = strdup(buf.data + channel_info_len + 2);
435 if (->client_first_message_bare)
436 goto oom_error;
437
438 result = strdup(buf.data);
439 if (result == NULL)
440 goto oom_error;
441
443 return result;
444
445oom_error:
448 return NULL;
449}
450
451
452
453
454static char *
456{
460 char *result;
461 int encoded_len;
462 const char *errstr = NULL;
463
465
466
467
468
469
470
471
472
473
475 {
476#ifdef USE_SSL
477 char *cbind_data = NULL;
478 size_t cbind_data_len = 0;
479 size_t cbind_header_len;
480 char *cbind_input;
481 size_t cbind_input_len;
482 int encoded_cbind_len;
483
484
485 cbind_data =
487 &cbind_data_len);
488 if (cbind_data == NULL)
489 {
490
492 return NULL;
493 }
494
496
497
498 cbind_header_len = strlen("p=tls-server-end-point,,");
499 cbind_input_len = cbind_header_len + cbind_data_len;
500 cbind_input = malloc(cbind_input_len);
501 if (!cbind_input)
502 {
503 free(cbind_data);
504 goto oom_error;
505 }
506 memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);
507 memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
508
511 {
512 free(cbind_data);
513 free(cbind_input);
514 goto oom_error;
515 }
516 encoded_cbind_len = pg_b64_encode((uint8 *) cbind_input, cbind_input_len,
518 encoded_cbind_len);
519 if (encoded_cbind_len < 0)
520 {
521 free(cbind_data);
522 free(cbind_input);
525 "could not encode cbind data for channel binding\n");
526 return NULL;
527 }
528 buf.len += encoded_cbind_len;
530
531 free(cbind_data);
532 free(cbind_input);
533#else
534
535
536
537
540 "channel binding not supported by this build\n");
541 return NULL;
542#endif
543 }
544#ifdef USE_SSL
548#endif
549 else
551
553 goto oom_error;
554
557 goto oom_error;
558
559 state->client_final_message_without_proof = strdup(buf.data);
560 if (state->client_final_message_without_proof == NULL)
561 goto oom_error;
562
563
565 state->client_final_message_without_proof,
566 client_proof, &errstr))
567 {
570 return NULL;
571 }
572
576 goto oom_error;
578 state->key_length,
580 encoded_len);
581 if (encoded_len < 0)
582 {
585 return NULL;
586 }
587 buf.len += encoded_len;
589
590 result = strdup(buf.data);
591 if (result == NULL)
592 goto oom_error;
593
595 return result;
596
597oom_error:
600 return NULL;
601}
602
603
604
605
606static bool
608{
610 char *iterations_str;
611 char *endptr;
612 char *encoded_salt;
613 char *nonce;
614 int decoded_salt_len;
615
616 state->server_first_message = strdup(input);
617 if (state->server_first_message == NULL)
618 {
620 return false;
621 }
622
623
626 if (nonce == NULL)
627 {
628
629 return false;
630 }
631
632
633 if (strlen(nonce) < strlen(state->client_nonce) ||
634 memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0)
635 {
637 return false;
638 }
639
640 state->nonce = strdup(nonce);
641 if (state->nonce == NULL)
642 {
644 return false;
645 }
646
648 if (encoded_salt == NULL)
649 {
650
651 return false;
652 }
653 decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));
655 if (state->salt == NULL)
656 {
658 return false;
659 }
661 strlen(encoded_salt),
663 decoded_salt_len);
664 if (state->saltlen < 0)
665 {
667 return false;
668 }
669
671 if (iterations_str == NULL)
672 {
673
674 return false;
675 }
676 state->iterations = strtol(iterations_str, &endptr, 10);
677 if (*endptr != '\0' || state->iterations < 1)
678 {
680 return false;
681 }
682
683 if (*input != '\0')
685
686 return true;
687}
688
689
690
691
692static bool
694{
696 char *encoded_server_signature;
697 uint8 *decoded_server_signature;
698 int server_signature_len;
699
700 state->server_final_message = strdup(input);
701 if (->server_final_message)
702 {
704 return false;
705 }
706
707
708 if (*input == 'e')
709 {
712
714 {
715
716 return false;
717 }
720 return false;
721 }
722
723
726 if (encoded_server_signature == NULL)
727 {
728
729 return false;
730 }
731
732 if (*input != '\0')
734
735 server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));
736 decoded_server_signature = malloc(server_signature_len);
737 if (!decoded_server_signature)
738 {
740 return false;
741 }
742
743 server_signature_len = pg_b64_decode(encoded_server_signature,
744 strlen(encoded_server_signature),
745 decoded_server_signature,
746 server_signature_len);
747 if (server_signature_len != state->key_length)
748 {
749 free(decoded_server_signature);
751 return false;
752 }
753 memcpy(state->ServerSignature, decoded_server_signature,
754 state->key_length);
755 free(decoded_server_signature);
756
757 return true;
758}
759
760
761
762
763
764
765static bool
767 const char *client_final_message_without_proof,
768 uint8 *result, const char **errstr)
769{
773 int i;
775
777 if (ctx == NULL)
778 {
780 return false;
781 }
782
783 if (state->conn->scram_client_key_binary)
784 {
786 }
787 else
788 {
789
790
791
792
795 state->iterations, state->SaltedPassword,
796 errstr) < 0 ||
798 state->key_length, ClientKey, errstr) < 0)
799 {
800
802 return false;
803 }
804 }
805
806 if (scram_H(ClientKey, state->hash_type, state->key_length, StoredKey, errstr) < 0)
807 {
809 return false;
810 }
811
814 (uint8 *) state->client_first_message_bare,
815 strlen(state->client_first_message_bare)) < 0 ||
818 (uint8 *) state->server_first_message,
819 strlen(state->server_first_message)) < 0 ||
822 (uint8 *) client_final_message_without_proof,
823 strlen(client_final_message_without_proof)) < 0 ||
825 {
828 return false;
829 }
830
831 for (i = 0; i < state->key_length; i++)
832 result[i] = ClientKey[i] ^ ClientSignature[i];
833
835 return true;
836}
837
838
839
840
841
842
843
844
845static bool
847 const char **errstr)
848{
852
854 if (ctx == NULL)
855 {
857 return false;
858 }
859
860 if (state->conn->scram_server_key_binary)
861 {
863 }
864 else
865 {
867 state->key_length, ServerKey, errstr) < 0)
868 {
869
871 return false;
872 }
873 }
874
875
878 (uint8 *) state->client_first_message_bare,
879 strlen(state->client_first_message_bare)) < 0 ||
882 (uint8 *) state->server_first_message,
883 strlen(state->server_first_message)) < 0 ||
886 (uint8 *) state->client_final_message_without_proof,
887 strlen(state->client_final_message_without_proof)) < 0 ||
889 state->key_length) < 0)
890 {
893 return false;
894 }
895
897
898
899 if (memcmp(expected_ServerSignature, state->ServerSignature,
900 state->key_length) != 0)
901 *match = false;
902 else
903 *match = true;
904
905 return true;
906}
907
908
909
910
911
912
913
914char *
916{
917 char *prep_password;
920 char *result;
921
922
923
924
925
926
927
930 {
932 return NULL;
933 }
935 password = (const char *) prep_password;
936
937
939 {
940 *errstr = libpq_gettext("could not generate random salt");
941 free(prep_password);
942 return NULL;
943 }
944
948 errstr);
949
950 free(prep_password);
951
952 return result;
953}
int pg_b64_enc_len(int srclen)
int pg_b64_encode(const uint8 *src, int len, char *dst, int dstlen)
int pg_b64_dec_len(int srclen)
int pg_b64_decode(const char *src, int len, uint8 *dst, int dstlen)
int errmsg(const char *fmt,...)
static char * build_client_first_message(fe_scram_state *state)
const pg_fe_sasl_mech pg_scram_mech
static bool verify_server_signature(fe_scram_state *state, bool *match, const char **errstr)
static void * scram_init(PGconn *conn, const char *password, const char *sasl_mechanism)
static char * read_attr_value(char **input, char attr, PQExpBuffer errorMessage)
static void scram_free(void *opaq)
static bool read_server_first_message(fe_scram_state *state, char *input)
static bool scram_channel_bound(void *opaq)
char * pg_fe_scram_build_secret(const char *password, int iterations, const char **errstr)
static bool calculate_client_proof(fe_scram_state *state, const char *client_final_message_without_proof, uint8 *result, const char **errstr)
static SASLStatus scram_exchange(void *opaq, bool final, char *input, int inputlen, char **output, int *outputlen)
static char * build_client_final_message(fe_scram_state *state)
static bool read_server_final_message(fe_scram_state *state, char *input)
void libpq_append_error(PQExpBuffer errorMessage, const char *fmt,...)
char * pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len)
Assert(PointerIsAligned(start, uint64))
pg_hmac_ctx * pg_hmac_create(pg_cryptohash_type type)
void pg_hmac_free(pg_hmac_ctx *ctx)
const char * pg_hmac_error(pg_hmac_ctx *ctx)
int pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)
int pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
int pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
void libpq_append_conn_error(PGconn *conn, const char *fmt,...)
bool pg_strong_random(void *buf, size_t len)
void initPQExpBuffer(PQExpBuffer str)
int enlargePQExpBuffer(PQExpBuffer str, size_t needed)
void appendPQExpBuffer(PQExpBuffer str, const char *fmt,...)
void appendPQExpBufferChar(PQExpBuffer str, char ch)
void appendPQExpBufferStr(PQExpBuffer str, const char *data)
void termPQExpBuffer(PQExpBuffer str)
#define PQExpBufferDataBroken(buf)
pg_saslprep_rc pg_saslprep(const char *input, char **output)
int scram_ServerKey(const uint8 *salted_password, pg_cryptohash_type hash_type, int key_length, uint8 *result, const char **errstr)
int scram_SaltedPassword(const char *password, pg_cryptohash_type hash_type, int key_length, const uint8 *salt, int saltlen, int iterations, uint8 *result, const char **errstr)
char * scram_build_secret(pg_cryptohash_type hash_type, int key_length, const uint8 *salt, int saltlen, int iterations, const char *password, const char **errstr)
int scram_ClientKey(const uint8 *salted_password, pg_cryptohash_type hash_type, int key_length, uint8 *result, const char **errstr)
int scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length, uint8 *result, const char **errstr)
#define SCRAM_SHA_256_PLUS_NAME
#define SCRAM_RAW_NONCE_LEN
#define SCRAM_DEFAULT_SALT_LEN
#define SCRAM_MAX_KEY_LEN
#define SCRAM_SHA_256_KEY_LEN
char * client_final_message_without_proof
fe_scram_state_enum state
char * client_first_message_bare
char * server_final_message
char * server_first_message
pg_cryptohash_type hash_type
PQExpBufferData errorMessage