LPPool3d — PyTorch 2.7 documentation
class torch.nn.LPPool3d(norm_type, kernel_size, stride=None, ceil_mode=False)
Applies a 3D power-average pooling over an input signal composed of several input planes.
On each window, the function computed is:
$$f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}$$
- In the limit $p \to \infty$, one gets Max Pooling
- At p = 1, one gets Sum Pooling (which is proportional to average pooling)
The parameters kernel_size and stride can either be:
- a single int – in which case the same value is used for the depth, height and width dimensions
- a tuple of three ints – in which case, the first int is used for the depth dimension, the second int for the height dimension and the third int for the width dimension
Note
If $\sum_{x \in X} x^{p}$ (the quantity under the root) is zero, the gradient of this function is not defined. This implementation sets the gradient to zero in this case.
Parameters
- norm_type (float) – the exponent p in the formula above
- kernel_size (Union[int, tuple[int, int, int]]) – the size of the window
- stride (Union[int, tuple[int, int, int]]) – the stride of the window. Default value is kernel_size
- ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
Shape:
- Input: $(N, C, D_{in}, H_{in}, W_{in})$ or $(C, D_{in}, H_{in}, W_{in})$.
- Output: $(N, C, D_{out}, H_{out}, W_{out})$ or $(C, D_{out}, H_{out}, W_{out})$, where
$$D_{out} = \left\lfloor\frac{D_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor$$
$$H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor$$
$$W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor$$
Examples:
import torch
from torch import nn
# power-2 pooling over a cubic window of size 3, stride 2
m = nn.LPPool3d(2, 3, stride=2)
# pooling of power 1.2 over a non-cubic window
m = nn.LPPool3d(1.2, (3, 2, 2), stride=(2, 1, 2))
input = torch.randn(20, 16, 50, 44, 31)
output = m(input)