torch_frame.nn.models.Trompt — pytorch-frame documentation
class Trompt(channels: int, out_channels: int, num_prompts: int, num_layers: int, col_stats: dict[str, dict[StatType, Any]], col_names_dict: dict[torch_frame.stype, list[str]], stype_encoder_dicts: list[dict[torch_frame.stype, StypeEncoder]] | None = None)
Bases: Module
The Trompt model introduced in the “Trompt: Towards a Better Deep Neural Network for Tabular Data” paper.
Parameters:
- channels (int) – Hidden channel dimensionality.
- out_channels (int) – Output channel dimensionality.
- num_prompts (int) – Number of prompt columns.
- num_layers (int) – Number of TromptConv layers.
- col_stats (Dict[str, Dict[torch_frame.data.stats.StatType, Any]]) – A dictionary that maps each column name to its stats. Available as dataset.col_stats.
- col_names_dict (Dict[torch_frame.stype, List[str]]) – A dictionary that maps each stype to a list of column names. The column names are sorted in the order in which they appear in tensor_frame.feat_dict. Available as tensor_frame.col_names_dict.
- stype_encoder_dicts (list[dict[torch_frame.stype, torch_frame.nn.encoder.StypeEncoder]], optional) – A list of num_layers dictionaries, each mapping stypes to their stype encoders. (default: None, which uses EmbeddingEncoder() for categorical features and LinearEncoder() for numerical features)
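Because stype_encoder_dicts expects one dictionary per layer, the way the list is built matters if each layer is meant to have its own encoder parameters. The sketch below uses plain PyTorch only: the string keys and the nn.Embedding / nn.Linear modules are hypothetical stand-ins for torch_frame.stype members and StypeEncoder instances, not the library's actual encoders. It contrasts a list comprehension, which creates fresh modules for every layer, with list multiplication, which repeats one dictionary object.

```python
from torch import nn

# Hypothetical stand-ins: the real API keys these dicts by torch_frame.stype
# and maps them to StypeEncoder instances; plain strings and basic nn modules
# are used here only so the sketch runs with PyTorch alone.
channels = 32
num_layers = 6


def make_encoder_dict() -> dict[str, nn.Module]:
    # Build fresh encoder modules, one per (stand-in) stype.
    return {
        "categorical": nn.Embedding(num_embeddings=10, embedding_dim=channels),
        "numerical": nn.Linear(1, channels),
    }


# One dictionary per layer, each holding its own module instances.
stype_encoder_dicts = [make_encoder_dict() for _ in range(num_layers)]

# List multiplication repeats the *same* dict object, so every layer
# would reuse identical modules and therefore share their parameters.
shared = [make_encoder_dict()] * num_layers

print(len(stype_encoder_dicts))  # prints 6
print(stype_encoder_dicts[0]["numerical"] is stype_encoder_dicts[1]["numerical"])  # prints False
print(shared[0]["numerical"] is shared[1]["numerical"])  # prints True
```

The same reasoning carries over to real encoders: Python's `[d] * n` never copies `d`, so a comprehension is the safer way to get independent per-layer dictionaries.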
forward(tf: TensorFrame) → Tensor
Transforms a TensorFrame object into a series of output predictions, one per layer. Used during training to compute a layer-wise loss.
Parameters:
tf (torch_frame.TensorFrame) – Input TensorFrame object.
Returns:
Output predictions stacked across layers. The shape is [batch_size, num_layers, out_channels].
Return type:
Tensor
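Since forward returns one prediction per layer, a layer-wise loss can be formed by folding the layer dimension into the batch dimension and repeating each label num_layers times. The sketch below assumes a multiclass classification task with cross-entropy loss and uses a random tensor as a stand-in for model(tf); all sizes are hypothetical. It also checks that the flattened loss equals the mean of the per-layer losses.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes; in practice `out` would come from model(tf).
batch_size, num_layers, num_classes = 4, 6, 3
torch.manual_seed(0)
out = torch.randn(batch_size, num_layers, num_classes)  # stand-in for model(tf)
y = torch.randint(0, num_classes, (batch_size,))         # one label per row

# Fold layers into the batch dimension: row b * num_layers + layer holds
# the prediction of that layer for sample b.
pred = out.reshape(-1, num_classes)        # [batch_size * num_layers, num_classes]
target = y.repeat_interleave(num_layers)   # each label repeated num_layers times
loss = F.cross_entropy(pred, target)       # mean over all layer-wise predictions

# Equivalent formulation: average the per-layer losses.
per_layer = torch.stack(
    [F.cross_entropy(out[:, layer], y) for layer in range(num_layers)]
)
assert torch.allclose(loss, per_layer.mean())

# For inference the layer dimension must be reduced somehow; averaging is one
# simple choice (an assumption, not necessarily what the library does).
final_pred = out.mean(dim=1)               # [batch_size, num_classes]
```

repeat_interleave (rather than repeat) is what keeps each label aligned with its sample's block of num_layers rows after the reshape.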