torch_frame.nn.models.Trompt - pytorch-frame documentation


class Trompt(channels: int, out_channels: int, num_prompts: int, num_layers: int, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dicts: list[dict[torch_frame.stype, StypeEncoder]] | None = None)

Bases: Module

The Trompt model introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper.

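Parameters:

channels (int) – Hidden channel dimensionality.

out_channels (int) – Output channel dimensionality.

num_prompts (int) – Number of prompt columns.

num_layers (int) – Number of Trompt layers.

col_stats (dict[str, dict[StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.

col_names_dict (dict[torch_frame.stype, list[str]]) – A dictionary that maps each stype to a list of column names. Available as tensor_frame.col_names_dict.

stype_encoder_dicts (list[dict[torch_frame.stype, StypeEncoder]] | None, optional) – A list of num_layers dictionaries, each of which maps stypes to their stype encoders. (default: None)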

forward(tf: TensorFrame) → Tensor

Transforms a TensorFrame object into a series of output predictions, one per layer. Used during training to compute the layer-wise loss.

Parameters:

tf (torch_frame.TensorFrame) – Input TensorFrame object.

Returns:

Output predictions stacked across layers. The shape is [batch_size, num_layers, out_channels].

Return type:

torch.Tensor
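The layer-wise loss mentioned above can be sketched as follows, assuming a multi-class classification target and cross-entropy. A random tensor stands in for the output of forward() so the sketch runs without a dataset; batch_size, num_layers, and out_channels are arbitrary illustrative values. Averaging over the layer dimension at inference time is an assumption of this sketch, not something this page specifies.

```python
import torch
import torch.nn.functional as F

batch_size, num_layers, out_channels = 4, 3, 5

# Stand-in for model(tf): random logits with the documented
# shape [batch_size, num_layers, out_channels].
torch.manual_seed(0)
out = torch.randn(batch_size, num_layers, out_channels)
y = torch.randint(0, out_channels, (batch_size,))  # one label per row

# Training: fold the layer dimension into the batch dimension so that
# every layer's prediction is scored against the same row label.
pred = out.reshape(-1, out_channels)     # [batch_size * num_layers, out_channels]
y_rep = y.repeat_interleave(num_layers)  # row b's label repeated num_layers times
loss = F.cross_entropy(pred, y_rep)

# Sanity check: this equals the mean of the per-layer losses,
# because every layer contributes the same number of rows.
per_layer = torch.stack(
    [F.cross_entropy(out[:, i], y) for i in range(num_layers)]
)
assert torch.allclose(loss, per_layer.mean())

# Inference (assumption): average the per-layer predictions.
final_pred = out.mean(dim=1)             # [batch_size, out_channels]
```

repeat_interleave is used rather than repeat because reshaping a contiguous [batch_size, num_layers, out_channels] tensor places all layers of row 0 first, then all layers of row 1, and so on.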