PostgreSQL Source Code: src/interfaces/libpq/fe-auth-scram.c Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

16

22

23

24

26 const char *sasl_mechanism);

28 char *input, int inputlen,

29 char **output, int *outputlen);

32

38};

39

40

41

42

43

44typedef enum

45{

51

52typedef struct

53{

55

56

60

61

64

65

70

71

77

78

82

88 const char **errstr);

90 const char *client_final_message_without_proof,

91 uint8 *result, const char **errstr);

92

93

94

95

96static void *

99 const char *sasl_mechanism)

100{

102 char *prep_password;

104

105 Assert(sasl_mechanism != NULL);

106

109 return NULL;

115

116 state->sasl_mechanism = strdup(sasl_mechanism);

117 if (!state->sasl_mechanism) /* strdup() returns NULL on out-of-memory; only then abandon initialization */

118 {

120 return NULL;

121 }

122

124 {

125

128 {

131 return NULL;

132 }

134 {

135 prep_password = strdup(password);

136 if (!prep_password)

137 {

140 return NULL;

141 }

142 }

143 state->password = prep_password;

144 }

145

147}

148

149

150

151

152

153

154

155

156

157static bool

159{

161

162

163 if (state == NULL)

164 return false;

165

166

168 return false;

169

170

172 return false;

173

174

175 return true;

176}

177

178

179

180

181static void

183{

185

188

189

191 free(state->client_first_message_bare);

192 free(state->client_final_message_without_proof);

193

194

195 free(state->server_first_message);

198

199

200 free(state->server_final_message);

201

203}

204

205

206

207

210 char *input, int inputlen,

211 char **output, int *outputlen)

212{

215 const char *errstr = NULL;

216

218 *outputlen = 0;

219

220

221

222

223

225 {

226 if (inputlen == 0)

227 {

230 }

231 if (inputlen != strlen(input))

232 {

235 }

236 }

237

238 switch (state->state)

239 {

241

245

246 *outputlen = strlen(*output);

249

251

254

258

259 *outputlen = strlen(*output);

262

264 {

265 bool match;

266

267

270

271

272

273

274

276 {

279 }

280

281 if (!match)

282 {

284 }

286 state->conn->client_finished_auth = true;

288 }

289

290 default:

291

293 break;

294 }

295

297}

298

299

300

301

302

303

304

305

306

307static char *

309{

310 char *begin = *input;

311 char *end;

312

313 if (*begin != attr)

314 {

316 "malformed SCRAM message (attribute \"%c\" expected)",

317 attr);

318 return NULL;

319 }

320 begin++;

321

322 if (*begin != '=')

323 {

325 "malformed SCRAM message (expected character \"=\" for attribute \"%c\")",

326 attr);

327 return NULL;

328 }

329 begin++;

330

331 end = begin;

332 while (*end && *end != ',')

333 end++;

334

335 if (*end)

336 {

337 *end = '\0';

338 *input = end + 1;

339 }

340 else

342

343 return begin;

344}

345

346

347

348

349static char *

351{

354 char *result;

355 int channel_info_len;

356 int encoded_len;

358

359

360

361

362

364 {

366 return NULL;

367 }

368

370

371 state->client_nonce = malloc(encoded_len + 1);

372 if (state->client_nonce == NULL)

373 {

375 return NULL;

376 }

378 state->client_nonce, encoded_len);

379 if (encoded_len < 0)

380 {

382 return NULL;

383 }

384 state->client_nonce[encoded_len] = '\0';

385

386

387

388

389

390

391

392

394

395

396

397

399 {

402 }

403#ifdef USE_SSL

406 {

407

408

409

411 }

412#endif

413 else

414 {

415

416

417

419 }

420

422 goto oom_error;

423

424 channel_info_len = buf.len;

425

428 goto oom_error;

429

430

431

432

433

434 state->client_first_message_bare = strdup(buf.data + channel_info_len + 2);

435 if (!state->client_first_message_bare) /* NULL means the copy could not be allocated; a non-NULL result is success */

436 goto oom_error;

437

438 result = strdup(buf.data);

439 if (result == NULL)

440 goto oom_error;

441

443 return result;

444

445oom_error:

448 return NULL;

449}

450

451

452

453

454static char *

456{

460 char *result;

461 int encoded_len;

462 const char *errstr = NULL;

463

465

466

467

468

469

470

471

472

473

475 {

476#ifdef USE_SSL

477 char *cbind_data = NULL;

478 size_t cbind_data_len = 0;

479 size_t cbind_header_len;

480 char *cbind_input;

481 size_t cbind_input_len;

482 int encoded_cbind_len;

483

484

485 cbind_data =

487 &cbind_data_len);

488 if (cbind_data == NULL)

489 {

490

492 return NULL;

493 }

494

496

497

498 cbind_header_len = strlen("p=tls-server-end-point,,");

499 cbind_input_len = cbind_header_len + cbind_data_len;

500 cbind_input = malloc(cbind_input_len);

501 if (!cbind_input)

502 {

503 free(cbind_data);

504 goto oom_error;

505 }

506 memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);

507 memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);

508

511 {

512 free(cbind_data);

513 free(cbind_input);

514 goto oom_error;

515 }

516 encoded_cbind_len = pg_b64_encode((uint8 *) cbind_input, cbind_input_len,

518 encoded_cbind_len);

519 if (encoded_cbind_len < 0)

520 {

521 free(cbind_data);

522 free(cbind_input);

525 "could not encode cbind data for channel binding\n");

526 return NULL;

527 }

528 buf.len += encoded_cbind_len;

529 buf.data[buf.len] = '\0';

530

531 free(cbind_data);

532 free(cbind_input);

533#else

534

535

536

537

540 "channel binding not supported by this build\n");

541 return NULL;

542#endif

543 }

544#ifdef USE_SSL

548#endif

549 else

551

553 goto oom_error;

554

557 goto oom_error;

558

559 state->client_final_message_without_proof = strdup(buf.data);

560 if (state->client_final_message_without_proof == NULL)

561 goto oom_error;

562

563

565 state->client_final_message_without_proof,

566 client_proof, &errstr))

567 {

570 return NULL;

571 }

572

576 goto oom_error;

578 state->key_length,

580 encoded_len);

581 if (encoded_len < 0)

582 {

585 return NULL;

586 }

587 buf.len += encoded_len;

588 buf.data[buf.len] = '\0';

589

590 result = strdup(buf.data);

591 if (result == NULL)

592 goto oom_error;

593

595 return result;

596

597oom_error:

600 return NULL;

601}

602

603

604

605

606static bool

608{

610 char *iterations_str;

611 char *endptr;

612 char *encoded_salt;

613 char *nonce;

614 int decoded_salt_len;

615

616 state->server_first_message = strdup(input);

617 if (state->server_first_message == NULL)

618 {

620 return false;

621 }

622

623

626 if (nonce == NULL)

627 {

628

629 return false;

630 }

631

632

633 if (strlen(nonce) < strlen(state->client_nonce) ||

634 memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0)

635 {

637 return false;

638 }

639

640 state->nonce = strdup(nonce);

641 if (state->nonce == NULL)

642 {

644 return false;

645 }

646

648 if (encoded_salt == NULL)

649 {

650

651 return false;

652 }

653 decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));

655 if (state->salt == NULL)

656 {

658 return false;

659 }

661 strlen(encoded_salt),

663 decoded_salt_len);

664 if (state->saltlen < 0)

665 {

667 return false;

668 }

669

671 if (iterations_str == NULL)

672 {

673

674 return false;

675 }

676 state->iterations = strtol(iterations_str, &endptr, 10);

677 if (*endptr != '\0' || state->iterations < 1)

678 {

680 return false;

681 }

682

683 if (*input != '\0')

685

686 return true;

687}

688

689

690

691

692static bool

694{

696 char *encoded_server_signature;

697 uint8 *decoded_server_signature;

698 int server_signature_len;

699

700 state->server_final_message = strdup(input);

701 if (!state->server_final_message) /* fail only when strdup() could not allocate the copy */

702 {

704 return false;

705 }

706

707

708 if (*input == 'e')

709 {

712

714 {

715

716 return false;

717 }

720 return false;

721 }

722

723

726 if (encoded_server_signature == NULL)

727 {

728

729 return false;

730 }

731

732 if (*input != '\0')

734

735 server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));

736 decoded_server_signature = malloc(server_signature_len);

737 if (!decoded_server_signature)

738 {

740 return false;

741 }

742

743 server_signature_len = pg_b64_decode(encoded_server_signature,

744 strlen(encoded_server_signature),

745 decoded_server_signature,

746 server_signature_len);

747 if (server_signature_len != state->key_length)

748 {

749 free(decoded_server_signature);

751 return false;

752 }

753 memcpy(state->ServerSignature, decoded_server_signature,

754 state->key_length);

755 free(decoded_server_signature);

756

757 return true;

758}

759

760

761

762

763

764

765static bool

767 const char *client_final_message_without_proof,

768 uint8 *result, const char **errstr)

769{

773 int i;

775

777 if (ctx == NULL)

778 {

780 return false;

781 }

782

783 if (state->conn->scram_client_key_binary)

784 {

786 }

787 else

788 {

789

790

791

792

795 state->iterations, state->SaltedPassword,

796 errstr) < 0 ||

798 state->key_length, ClientKey, errstr) < 0)

799 {

800

802 return false;

803 }

804 }

805

806 if (scram_H(ClientKey, state->hash_type, state->key_length, StoredKey, errstr) < 0)

807 {

809 return false;

810 }

811

814 (uint8 *) state->client_first_message_bare,

815 strlen(state->client_first_message_bare)) < 0 ||

818 (uint8 *) state->server_first_message,

819 strlen(state->server_first_message)) < 0 ||

822 (uint8 *) client_final_message_without_proof,

823 strlen(client_final_message_without_proof)) < 0 ||

825 {

828 return false;

829 }

830

831 for (i = 0; i < state->key_length; i++)

832 result[i] = ClientKey[i] ^ ClientSignature[i];

833

835 return true;

836}

837

838

839

840

841

842

843

844

845static bool

847 const char **errstr)

848{

852

854 if (ctx == NULL)

855 {

857 return false;

858 }

859

860 if (state->conn->scram_server_key_binary)

861 {

863 }

864 else

865 {

867 state->key_length, ServerKey, errstr) < 0)

868 {

869

871 return false;

872 }

873 }

874

875

878 (uint8 *) state->client_first_message_bare,

879 strlen(state->client_first_message_bare)) < 0 ||

882 (uint8 *) state->server_first_message,

883 strlen(state->server_first_message)) < 0 ||

886 (uint8 *) state->client_final_message_without_proof,

887 strlen(state->client_final_message_without_proof)) < 0 ||

889 state->key_length) < 0)

890 {

893 return false;

894 }

895

897

898

899 if (memcmp(expected_ServerSignature, state->ServerSignature,

900 state->key_length) != 0)

901 *match = false;

902 else

903 *match = true;

904

905 return true;

906}

907

908

909

910

911

912

913

914char *

916{

917 char *prep_password;

920 char *result;

921

922

923

924

925

926

927

930 {

932 return NULL;

933 }

935 password = (const char *) prep_password;

936

937

939 {

940 *errstr = libpq_gettext("could not generate random salt");

941 free(prep_password);

942 return NULL;

943 }

944

948 errstr);

949

950 free(prep_password);

951

952 return result;

953}

int pg_b64_enc_len(int srclen)

int pg_b64_encode(const uint8 *src, int len, char *dst, int dstlen)

int pg_b64_dec_len(int srclen)

int pg_b64_decode(const char *src, int len, uint8 *dst, int dstlen)

int errmsg(const char *fmt,...)

static char * build_client_first_message(fe_scram_state *state)

const pg_fe_sasl_mech pg_scram_mech

static bool verify_server_signature(fe_scram_state *state, bool *match, const char **errstr)

static void * scram_init(PGconn *conn, const char *password, const char *sasl_mechanism)

static char * read_attr_value(char **input, char attr, PQExpBuffer errorMessage)

static void scram_free(void *opaq)

static bool read_server_first_message(fe_scram_state *state, char *input)

static bool scram_channel_bound(void *opaq)

char * pg_fe_scram_build_secret(const char *password, int iterations, const char **errstr)

static bool calculate_client_proof(fe_scram_state *state, const char *client_final_message_without_proof, uint8 *result, const char **errstr)

static SASLStatus scram_exchange(void *opaq, bool final, char *input, int inputlen, char **output, int *outputlen)

static char * build_client_final_message(fe_scram_state *state)

static bool read_server_final_message(fe_scram_state *state, char *input)

void libpq_append_error(PQExpBuffer errorMessage, const char *fmt,...)

char * pgtls_get_peer_certificate_hash(PGconn *conn, size_t *len)

Assert(PointerIsAligned(start, uint64))

pg_hmac_ctx * pg_hmac_create(pg_cryptohash_type type)

void pg_hmac_free(pg_hmac_ctx *ctx)

const char * pg_hmac_error(pg_hmac_ctx *ctx)

int pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)

int pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)

int pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)

void libpq_append_conn_error(PGconn *conn, const char *fmt,...)

bool pg_strong_random(void *buf, size_t len)

void initPQExpBuffer(PQExpBuffer str)

int enlargePQExpBuffer(PQExpBuffer str, size_t needed)

void appendPQExpBuffer(PQExpBuffer str, const char *fmt,...)

void appendPQExpBufferChar(PQExpBuffer str, char ch)

void appendPQExpBufferStr(PQExpBuffer str, const char *data)

void termPQExpBuffer(PQExpBuffer str)

#define PQExpBufferDataBroken(buf)

pg_saslprep_rc pg_saslprep(const char *input, char **output)

int scram_ServerKey(const uint8 *salted_password, pg_cryptohash_type hash_type, int key_length, uint8 *result, const char **errstr)

int scram_SaltedPassword(const char *password, pg_cryptohash_type hash_type, int key_length, const uint8 *salt, int saltlen, int iterations, uint8 *result, const char **errstr)

char * scram_build_secret(pg_cryptohash_type hash_type, int key_length, const uint8 *salt, int saltlen, int iterations, const char *password, const char **errstr)

int scram_ClientKey(const uint8 *salted_password, pg_cryptohash_type hash_type, int key_length, uint8 *result, const char **errstr)

int scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length, uint8 *result, const char **errstr)

#define SCRAM_SHA_256_PLUS_NAME

#define SCRAM_RAW_NONCE_LEN

#define SCRAM_DEFAULT_SALT_LEN

#define SCRAM_MAX_KEY_LEN

#define SCRAM_SHA_256_KEY_LEN

char * client_final_message_without_proof

fe_scram_state_enum state

char * client_first_message_bare

char * server_final_message

char * server_first_message

pg_cryptohash_type hash_type

PQExpBufferData errorMessage