Added logsumexp to backend. (#6346) · keras-team/keras@7d52af6 (original) (raw)
`@@ -580,6 +580,41 @@ def step_function(x, states):
`
580
580
`assert_allclose(tf_last_output, th_last_output, atol=1e-04)
`
581
581
`assert_allclose(tf_outputs, th_outputs, atol=1e-04)
`
582
582
``
``
583
`+
@pytest.mark.parametrize('x_np,axis,keepdims', [
`
``
584
`+
(np.array([1.1, 0.8, 0.9]), 0, False),
`
``
585
`+
(np.array([[1.1, 0.8, 0.9]]), 0, False),
`
``
586
`+
(np.array([[1.1, 0.8, 0.9]]), 1, False),
`
``
587
`+
(np.array([[1.1, 0.8, 0.9]]), -1, False),
`
``
588
`+
(np.array([[1.1, 0.8, 0.9]]), 1, True),
`
``
589
`+
(np.array([[1.1], [1.2]]), 0, False),
`
``
590
`+
(np.array([[1.1], [1.2]]), 1, False),
`
``
591
`+
(np.array([[1.1], [1.2]]), -1, False),
`
``
592
`+
(np.array([[1.1], [1.2]]), -1, True),
`
``
593
`+
(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), None, False),
`
``
594
`+
(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 0, False),
`
``
595
`+
(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 1, False),
`
``
596
`+
(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), -1, False),
`
``
597
`+
])
`
``
598
`+
@pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])
`
``
599
`+
def test_logsumexp(self, x_np, axis, keepdims, K):
`
``
600
`+
'''
`
``
601
`+
Check if K.logsumexp works properly for values close to one.
`
``
602
`+
'''
`
``
603
`+
x = K.variable(x_np)
`
``
604
`+
assert_allclose(K.eval(K.logsumexp(x, axis=axis, keepdims=keepdims)),
`
``
605
`+
np.log(np.sum(np.exp(x_np), axis=axis, keepdims=keepdims)),
`
``
606
`+
rtol=1e-5)
`
``
607
+
``
608
`+
@pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])
`
``
609
`+
def test_logsumexp_optim(self, K):
`
``
610
`+
'''
`
``
611
`+
Check if optimization works.
`
``
612
`+
'''
`
``
613
`+
x_np = np.array([1e+4, 1e-4])
`
``
614
`+
assert_allclose(K.eval(K.logsumexp(K.variable(x_np), axis=0)),
`
``
615
`+
1e4,
`
``
616
`+
rtol=1e-5)
`
``
617
+
583
618
`def test_switch(self):
`
584
619
`val = np.random.random()
`
585
620
`xth = KTH.variable(val)
`