GLPN

PyTorch

This is a recently introduced model, so the API hasn't been tested extensively. There may be some bugs or slight breaking changes to fix in the future. If you see something strange, file a GitHub Issue.

Overview

The GLPN model was proposed in Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun and Junmo Kim. GLPN combines SegFormer's hierarchical mix-Transformer with a lightweight decoder for monocular depth estimation. The proposed decoder shows better performance than the previously proposed decoders, with considerably less computational complexity.
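For a quick start, GLPN checkpoints can be loaded through the depth-estimation pipeline. The snippet below is a minimal sketch; the image path is a placeholder and any RGB image will do:

```python
from transformers import pipeline

# load a GLPN checkpoint into the depth estimation pipeline
depth_estimator = pipeline("depth-estimation", model="vinvino02/glpn-kitti")

# "cat.jpg" is a placeholder path; pass any local image or URL
result = depth_estimator("cat.jpg")
result["predicted_depth"]  # raw depth tensor
result["depth"]            # PIL image visualization of the depth map
```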

The abstract from the paper is the following:

Depth estimation from a single image is an important task that can be applied to various fields in computer vision, and has grown rapidly with the development of convolutional neural networks. In this paper, we propose a novel structure and training strategy for monocular depth estimation to further improve the prediction accuracy of the network. We deploy a hierarchical transformer encoder to capture and convey the global context, and design a lightweight yet powerful decoder to generate an estimated depth map while considering local connectivity. By constructing connected paths between multi-scale local features and the global decoding stream with our proposed selective feature fusion module, the network can integrate both representations and recover fine details. In addition, the proposed decoder shows better performance than the previously proposed decoders, with considerably less computational complexity. Furthermore, we improve the depth-specific augmentation method by utilizing an important observation in depth estimation to enhance the model. Our network achieves state-of-the-art performance over the challenging depth dataset NYU Depth V2. Extensive experiments have been conducted to validate and show the effectiveness of the proposed approach. Finally, our model shows better generalisation ability and robustness than other comparative models.

Summary of the approach. Taken from the original paper.

This model was contributed by nielsr. The original code can be found here.

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GLPN.

GLPNConfig

class transformers.GLPNConfig

( num_channels = 3 num_encoder_blocks = 4 depths = [2, 2, 2, 2] sr_ratios = [8, 4, 2, 1] hidden_sizes = [32, 64, 160, 256] patch_sizes = [7, 3, 3, 3] strides = [4, 2, 2, 2] num_attention_heads = [1, 2, 5, 8] mlp_ratios = [4, 4, 4, 4] hidden_act = 'gelu' hidden_dropout_prob = 0.0 attention_probs_dropout_prob = 0.0 initializer_range = 0.02 drop_path_rate = 0.1 layer_norm_eps = 1e-06 decoder_hidden_size = 64 max_depth = 10 head_in_index = -1 **kwargs )

Parameters

This is the configuration class to store the configuration of a GLPNModel. It is used to instantiate a GLPN model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the GLPN vinvino02/glpn-kitti architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

```python
from transformers import GLPNModel, GLPNConfig

# Initializing a GLPN vinvino02/glpn-kitti style configuration
configuration = GLPNConfig()

# Initializing a model from the vinvino02/glpn-kitti style configuration
model = GLPNModel(configuration)

# Accessing the model configuration
configuration = model.config
```
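The list-valued arguments control the hierarchical encoder stage by stage. The following is a sketch of a custom, smaller encoder; this two-stage layout is made up for illustration and does not correspond to any published checkpoint:

```python
from transformers import GLPNConfig, GLPNModel

# hypothetical two-stage encoder; every list argument needs
# exactly num_encoder_blocks entries
small_config = GLPNConfig(
    num_encoder_blocks=2,
    depths=[2, 2],
    sr_ratios=[8, 4],
    hidden_sizes=[32, 64],
    patch_sizes=[7, 3],
    strides=[4, 2],
    num_attention_heads=[1, 2],
    mlp_ratios=[4, 4],
)

model = GLPNModel(small_config)  # randomly initialized
```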

GLPNFeatureExtractor

Preprocess an image or a batch of images.

GLPNImageProcessor

class transformers.GLPNImageProcessor

( do_resize: bool = True size_divisor: int = 32 resample = <Resampling.BILINEAR: 2> do_rescale: bool = True **kwargs )

Parameters

Constructs a GLPN image processor.

preprocess

( images: typing.Union[ForwardRef('PIL.Image.Image'), transformers.utils.generic.TensorType, typing.List[ForwardRef('PIL.Image.Image')], typing.List[transformers.utils.generic.TensorType]] do_resize: typing.Optional[bool] = None size_divisor: typing.Optional[int] = None resample = None do_rescale: typing.Optional[bool] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None data_format: ChannelDimension = <ChannelDimension.FIRST: 'channels_first'> input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None )

Parameters

Preprocess the given images.
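As a rough illustration of the size_divisor behavior: with the defaults, heights and widths are rounded down to the nearest multiple of 32 before the tensors are returned. The random image below is just a stand-in:

```python
import numpy as np
from PIL import Image
from transformers import GLPNImageProcessor

processor = GLPNImageProcessor()  # do_resize=True, size_divisor=32

# stand-in 500x700 RGB image; neither dimension is a multiple of 32
image = Image.fromarray(np.random.randint(0, 256, (500, 700, 3), dtype=np.uint8))

inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 480, 672])
```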

GLPNModel

class transformers.GLPNModel

( config )

Parameters

The bare GLPN Model outputting raw hidden-states without any specific head on top.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: FloatTensor output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) β†’ transformers.modeling_outputs.BaseModelOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.BaseModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (GLPNConfig) and inputs.

The GLPNModel forward method, overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
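A short sketch of running the bare encoder on a stand-in image (the checkpoint name matches the examples elsewhere on this page; the exact shapes depend on the input resolution):

```python
import numpy as np
import torch
from PIL import Image
from transformers import GLPNImageProcessor, GLPNModel

processor = GLPNImageProcessor()
model = GLPNModel.from_pretrained("vinvino02/glpn-kitti")

# stand-in 480x640 RGB image
image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# final-stage feature map, downsampled spatially by a factor of 32,
# e.g. torch.Size([1, 256, 15, 20]) for a 480x640 input
print(outputs.last_hidden_state.shape)
print(len(outputs.hidden_states))  # one feature map per encoder stage
```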

GLPNForDepthEstimation

class transformers.GLPNForDepthEstimation

( config )

Parameters

GLPN Model transformer with a lightweight depth estimation head on top, e.g. for KITTI or NYUv2.

This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.

forward

( pixel_values: FloatTensor labels: typing.Optional[torch.FloatTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) β†’ transformers.modeling_outputs.DepthEstimatorOutput or tuple(torch.FloatTensor)

Parameters

A transformers.modeling_outputs.DepthEstimatorOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (GLPNConfig) and inputs.

The GLPNForDepthEstimation forward method, overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

```python
from transformers import AutoImageProcessor, GLPNForDepthEstimation
import torch
import numpy as np
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-kitti")
model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-kitti")

# prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# interpolate to original size and visualize the prediction
post_processed_output = image_processor.post_process_depth_estimation(
    outputs,
    target_sizes=[(image.height, image.width)],
)

predicted_depth = post_processed_output[0]["predicted_depth"]
depth = predicted_depth * 255 / predicted_depth.max()
depth = depth.detach().cpu().numpy()
depth = Image.fromarray(depth.astype("uint8"))
```
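When ground-truth depth maps are available, passing them as labels makes the model return a scale-invariant log (SiLog) loss. Below is a hedged training sketch that reuses model and inputs from the example above, with a random stand-in depth map sized from the model's own output:

```python
import torch

# forward pass without labels to discover the predicted depth resolution
with torch.no_grad():
    predicted = model(**inputs).predicted_depth

# random stand-in ground-truth depth map, shape (batch, height, width)
labels = torch.rand_like(predicted)

outputs = model(**inputs, labels=labels)
outputs.loss.backward()  # scale-invariant log loss over the valid (positive) pixels
```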
