Replies: 1 comment
-
This question looks pretty specific to OTT; you may have better luck asking at https://github.com/ott-jax/ott/.
-
Hi there, I am trying to reproduce Figure 1 in Klein et al. (2024). I have trouble implementing the fourth cost function, where $h(z) = \frac{1}{2} \|z\|_2^2 + \gamma \frac{1}{2} \|b^\perp z\|_2^2$. I am using `orthogonal_regularizer = regularizers.Quadratic(A, is_factor=True, is_complement=True)` as the regularizer, which I then use for the cost function `l2_b_cost = costs.RegTICost(orthogonal_regularizer, lam=1000)`, which I then pass to `map_l2_b = entropic_map(xs, y, cost_fn=l2_b_cost)`. The `entropic_map` function is taken from a tutorial.

The resulting transport looks like this (red arrow is the direction of $b$, orange points are the transported points), which does not penalize the orthogonal transport as expected, i.e. for $b=[0,1]$ the points should be transported parallel to the y-axis, but they are not. I also tried using `orthogonal_regularizer = regularizers.Orthogonal(f=l2_regularizer, A=P)`, but that did not work either. I was wondering where my mistake is and what I can do for a proper implementation.
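To make the expected behaviour concrete, here is a minimal sketch of the cost $h$ in plain Python, independent of OTT. It assumes that $b^\perp z$ means the component of $z$ orthogonal to $b$, and that `lam=1000` plays the role of $\gamma$; the helper names `orthogonal_component` and `h` are illustrative only and are not part of any library.

```python
import math


def orthogonal_component(z, b):
    """Return the part of z that is orthogonal to the direction b."""
    norm_b = math.sqrt(sum(bi * bi for bi in b))
    unit_b = [bi / norm_b for bi in b]
    # Scalar projection of z onto the unit vector along b.
    along = sum(zi * ui for zi, ui in zip(z, unit_b))
    # Subtract the parallel part to keep only the orthogonal part.
    return [zi - along * ui for zi, ui in zip(z, unit_b)]


def h(z, b, gamma):
    """h(z) = 1/2 ||z||^2 + gamma/2 ||b_perp z||^2 (assumed reading of the cost)."""
    sq_norm = sum(zi * zi for zi in z)
    z_perp = orthogonal_component(z, b)
    sq_norm_perp = sum(zi * zi for zi in z_perp)
    return 0.5 * sq_norm + 0.5 * gamma * sq_norm_perp


b = [0.0, 1.0]
gamma = 1000.0  # assumption: mirrors lam=1000 above

cost_along = h([0.0, 1.0], b, gamma)   # unit displacement parallel to b (y-axis)
cost_across = h([1.0, 0.0], b, gamma)  # unit displacement orthogonal to b (x-axis)

print(cost_along)   # 0.5
print(cost_across)  # 500.5
```

Under this reading, a unit displacement along the x-axis costs about a thousand times more than one along the y-axis, which is why I expect the transported points to move parallel to the y-axis.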
Code to reproduce: