Loss Functions

Classes:

class jetnet.losses.EMDLoss(*args: Any, **kwargs: Any)

Calculates the energy mover’s distance (EMD) between two batches of jets differentiably by formulating it as a convex optimization problem. This problem is either solved directly as a linear program with the cvxpy library, or converted to a quadratic program and solved with the qpth library. cvxpy is marginally more accurate, but qpth is significantly faster, so qpth is the default.

To use this class, JetNet must be installed with the emdloss extra: pip install jetnet[emdloss].

Note: PyTorch <= 1.9 has a bug that causes this to fail for jets with >= 32 particles. The fix in https://github.com/pytorch/pytorch/pull/61815 should resolve this from PyTorch 1.10 onwards.
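For reference, the optimization problem for a single pair of jets J_1 and J_2 can be written as below. This is a sketch that assumes the standard energy mover's distance of Komiske, Metodiev and Thaler with angular radius R = 1, where f_ij is the pT flow from particle i of the first jet to particle j of the second; the total-pT difference term and the (eta, phi) distance shown here are assumptions, not details confirmed by this page.

```latex
\begin{aligned}
\mathrm{EMD}(J_1, J_2) &= \min_{\{f_{ij}\}} \sum_{i,j} f_{ij}\, d_{ij}
  + \Bigl| \sum_i p_{T,i}^{(1)} - \sum_j p_{T,j}^{(2)} \Bigr|, \\
d_{ij} &= \sqrt{\bigl(\eta_i^{(1)} - \eta_j^{(2)}\bigr)^2 + \bigl(\phi_i^{(1)} - \phi_j^{(2)}\bigr)^2}, \\
\text{subject to}\quad & f_{ij} \ge 0, \quad \sum_j f_{ij} \le p_{T,i}^{(1)}, \quad \sum_i f_{ij} \le p_{T,j}^{(2)}, \\
& \sum_{i,j} f_{ij} = \min\Bigl( \sum_i p_{T,i}^{(1)},\ \sum_j p_{T,j}^{(2)} \Bigr).
\end{aligned}
```

Both the objective and the constraints are linear in f_ij, which is why this can be solved as a linear program; a QP solver such as qpth additionally needs a quadratic term in the objective.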

Parameters
  • method (str) – ‘cvxpy’ or ‘qpth’. Defaults to ‘qpth’.

  • num_particles (int) – number of particles per jet; only needs to be specified if method is ‘cvxpy’.

  • qpth_form (str) – form of the quadratic program passed to qpth: ‘L2’ or ‘QP’. Defaults to ‘L2’.

  • qpth_l2_strength (float) – strength of the L2 regularization term in the ‘L2’ form. Defaults to 0.0001.

  • device (str) – ‘cpu’ or ‘cuda’. Defaults to ‘cpu’.
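The qpth route needs the problem in the form min ½ fᵀQf + pᵀf subject to Gf ≤ h and Af = b. The sketch below builds those arrays with NumPy for one pair of jets. It assumes the ‘L2’ form keeps the pairwise (eta, phi) distances as the linear term p and uses qpth_l2_strength times the identity as Q, so that Q is positive definite; emd_qp_l2_matrices is a hypothetical helper, and JetNet's exact construction may differ.

```python
import numpy as np


def emd_qp_l2_matrices(jet1, jet2, l2_strength=1e-4):
    """Build (Q, p, G, h, A, b) for: min 0.5 f^T Q f + p^T f  s.t.  G f <= h, A f = b.

    Hypothetical helper (assumption): the 'L2' form uses the pairwise distances
    as the linear term and l2_strength * I as the quadratic term.
    jet1, jet2: arrays of shape [num_particles, 3] with features [eta, phi, pt].
    """
    pt1, pt2 = jet1[:, 2], jet2[:, 2]
    n1, n2 = len(pt1), len(pt2)

    # Pairwise (eta, phi) distances (phi wrap-around ignored), flattened
    # row-major so that f[i * n2 + j] corresponds to f_ij
    d = np.linalg.norm(jet1[:, None, :2] - jet2[None, :, :2], axis=-1)
    p = d.ravel()

    # Small positive-definite quadratic term turns the LP into a strictly convex QP
    Q = l2_strength * np.eye(n1 * n2)

    G = np.vstack([
        -np.eye(n1 * n2),                       # -f_ij <= 0, i.e. f_ij >= 0
        np.kron(np.eye(n1), np.ones((1, n2))),  # sum_j f_ij <= pt1_i
        np.kron(np.ones((1, n1)), np.eye(n2)),  # sum_i f_ij <= pt2_j
    ])
    h = np.concatenate([np.zeros(n1 * n2), pt1, pt2])

    # Total flow equals the smaller of the two total pTs
    A = np.ones((1, n1 * n2))
    b = np.array([min(pt1.sum(), pt2.sum())])
    return Q, p, G, h, A, b
```

qpth's QPFunction takes these arrays, converted to batched torch tensors, in the order (Q, p, G, h, A, b). A small regularization strength keeps the solution close to the linear-program optimum while making the problem strictly convex.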

Methods:

forward(jets1: Tensor, jets2: Tensor, return_flows: bool = False) → Tensor | tuple[Tensor, Tensor]

Calculate EMD between jets1 and jets2.

Parameters
  • jets1 (Tensor) – tensor of shape [num_jets, num_particles, num_features], with features in the order [eta, phi, pt].

  • jets2 (Tensor) – tensor of same format as jets1.

  • return_flows (bool) – whether to also return the energy flows between the particles of each corresponding pair of jets in jets1 and jets2. Defaults to False.

Returns

  • Tensor: EMD scores tensor of shape [num_jets].

  • Tensor (only if return_flows is True): energy flows between particles, of shape [num_jets, num_particles, num_particles].

Return type

Union[Tensor, Tuple[Tensor, Tensor]]
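To sanity-check EMDLoss outputs on small inputs, a non-differentiable reference can be computed with SciPy's linear-programming solver. The sketch below assumes the standard EMD formulation (R = 1, including the total-pT difference term) and mirrors the documented return shapes of forward; emd_reference is a hypothetical helper operating on NumPy arrays, not part of JetNet.

```python
import numpy as np
from scipy.optimize import linprog


def emd_reference(jets1, jets2, return_flows=False):
    """Reference EMD for arrays of shape [num_jets, num_particles, 3].

    Hypothetical helper (assumption): features are [eta, phi, pt] and the EMD
    uses R = 1 plus the |sum pt1 - sum pt2| term. Not differentiable.
    """
    num_jets, n, _ = jets1.shape
    # Constraint matrices for the flattened flow vector f[i * n + j] = f_ij
    row_sums = np.kron(np.eye(n), np.ones((1, n)))  # sum_j f_ij <= pt1_i
    col_sums = np.kron(np.ones((1, n)), np.eye(n))  # sum_i f_ij <= pt2_j
    A_ub = np.vstack([row_sums, col_sums])
    A_eq = np.ones((1, n * n))                      # total flow constraint

    emds = np.zeros(num_jets)
    flows = np.zeros((num_jets, n, n))
    for k in range(num_jets):
        pt1, pt2 = jets1[k, :, 2], jets2[k, :, 2]
        # Pairwise (eta, phi) distances; phi wrap-around is ignored here
        d = np.linalg.norm(jets1[k, :, None, :2] - jets2[k, None, :, :2], axis=-1)
        res = linprog(
            d.ravel(),
            A_ub=A_ub,
            b_ub=np.concatenate([pt1, pt2]),
            A_eq=A_eq,
            b_eq=[min(pt1.sum(), pt2.sum())],
            bounds=(0, None),
        )
        if not res.success:
            raise RuntimeError(res.message)
        emds[k] = res.fun + abs(pt1.sum() - pt2.sum())
        flows[k] = res.x.reshape(n, n)
    return (emds, flows) if return_flows else emds
```

Because it solves one linear program per jet, this is only practical for small batches and cannot replace EMDLoss as a training loss; its role is to give independent values to compare against.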