Added logsumexp to backend. (#6346) · keras-team/keras@7d52af6 (original) (raw)

`@@ -580,6 +580,41 @@ def step_function(x, states):

`

580

580

`assert_allclose(tf_last_output, th_last_output, atol=1e-04)

`

581

581

`assert_allclose(tf_outputs, th_outputs, atol=1e-04)

`

582

582

``

``

583

`+

@pytest.mark.parametrize('x_np,axis,keepdims', [

`

``

584

`+

(np.array([1.1, 0.8, 0.9]), 0, False),

`

``

585

`+

(np.array([[1.1, 0.8, 0.9]]), 0, False),

`

``

586

`+

(np.array([[1.1, 0.8, 0.9]]), 1, False),

`

``

587

`+

(np.array([[1.1, 0.8, 0.9]]), -1, False),

`

``

588

`+

(np.array([[1.1, 0.8, 0.9]]), 1, True),

`

``

589

`+

(np.array([[1.1], [1.2]]), 0, False),

`

``

590

`+

(np.array([[1.1], [1.2]]), 1, False),

`

``

591

`+

(np.array([[1.1], [1.2]]), -1, False),

`

``

592

`+

(np.array([[1.1], [1.2]]), -1, True),

`

``

593

`+

(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), None, False),

`

``

594

`+

(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 0, False),

`

``

595

`+

(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 1, False),

`

``

596

`+

(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), -1, False),

`

``

597

`+

])

`

``

598

`+

@pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])

`

``

599

`+

def test_logsumexp(self, x_np, axis, keepdims, K):

`

``

600

`+

'''

`

``

601

`+

Check if K.logsumexp works properly for values close to one.

`

``

602

`+

'''

`

``

603

`+

x = K.variable(x_np)

`

``

604

`+

assert_allclose(K.eval(K.logsumexp(x, axis=axis, keepdims=keepdims)),

`

``

605

`+

np.log(np.sum(np.exp(x_np), axis=axis, keepdims=keepdims)),

`

``

606

`+

rtol=1e-5)

`

``

607

+

``

608

`+

@pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])

`

``

609

`+

def test_logsumexp_optim(self, K):

`

``

610

`+

'''

`

``

611

`+

Check if optimization works.

`

``

612

`+

'''

`

``

613

`+

x_np = np.array([1e+4, 1e-4])

`

``

614

`+

assert_allclose(K.eval(K.logsumexp(K.variable(x_np), axis=0)),

`

``

615

`+

1e4,

`

``

616

`+

rtol=1e-5)

`

``

617

+

583

618

`def test_switch(self):

`

584

619

`val = np.random.random()

`

585

620

`xth = KTH.variable(val)

`