Skip to content

Commit

Permalink
[feat] add right and left jacobian for So3 (#2509)
Browse files Browse the repository at this point in the history
* add right and left jacobian for So3

* revert wrong test, update tests

* review changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add noqa

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
cjpurackal and pre-commit-ci[bot] committed Aug 16, 2023
1 parent e83f048 commit daf6ce7
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/source/geometry.conversions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ Quaternion
.. autofunction:: quaternion_exp_to_log
.. autofunction:: normalize_quaternion

Vector
------

.. autofunction:: vector_to_skew_symmetric_matrix

Rotation Matrix
---------------

Expand Down
37 changes: 36 additions & 1 deletion kornia/geometry/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn.functional as F

from kornia.constants import pi
from kornia.core import Tensor, concatenate, pad, stack, tensor, where
from kornia.core import Tensor, concatenate, pad, stack, tensor, where, zeros_like
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
from kornia.utils import deprecated
from kornia.utils.helpers import _torch_inverse_cast
Expand Down Expand Up @@ -54,6 +54,7 @@
"camtoworld_graphics_to_vision_Rt",
"camtoworld_vision_to_graphics_Rt",
"ARKitQTVecs_to_ColmapQTVecs",
"vector_to_skew_symmetric_matrix",
]


Expand Down Expand Up @@ -1478,3 +1479,37 @@ def ARKitQTVecs_to_ColmapQTVecs(qvec: Tensor, tvec: Tensor) -> tuple[Tensor, Ten
t_colmap = t_colmap.reshape(-1, 3, 1)
q_colmap = rotation_matrix_to_quaternion(R_colmap.contiguous())
return q_colmap, t_colmap


def vector_to_skew_symmetric_matrix(vec: Tensor) -> Tensor:
r"""Converts a vector to a skew symmetric matrix.
A vector :math:`(v1, v2, v3)` has a corresponding skew-symmetric matrix, which is of the form:
.. math::
\begin{bmatrix} 0 & -v3 & v2 \\
v3 & 0 & -v1 \\
-v2 & v1 & 0\end{bmatrix}
Args:
x: tensor of shape :math:`(B, 3)`.
Returns:
tensor of shape :math:`(B, 3, 3)`.
Example:
>>> vec = torch.tensor([1.0, 2.0, 3.0])
>>> vector_to_skew_symmetric_matrix(vec)
tensor([[ 0., -3., 2.],
[ 3., 0., -1.],
[-2., 1., 0.]])
"""
# KORNIA_CHECK_SHAPE(vec, ["B", "3"])
if vec.shape[-1] != 3 or len(vec.shape) > 2:
raise ValueError(f"Input vector must be of shape (B, 3) or (3,). " f"Got {vec.shape}")
v1, v2, v3 = vec[..., 0], vec[..., 1], vec[..., 2]
zeros = zeros_like(v1)
skew_symmetric_matrix = stack(
[stack([zeros, -v3, v2], dim=-1), stack([v3, zeros, -v1], dim=-1), stack([-v2, v1, zeros], dim=-1)], dim=-2
)
return skew_symmetric_matrix
65 changes: 63 additions & 2 deletions kornia/geometry/liegroup/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# https://github.com/strasdat/Sophus/blob/master/sympy/sophus/so3.py
from __future__ import annotations

from kornia.core import Device, Dtype, Module, Tensor, concatenate, stack, tensor, where, zeros, zeros_like
from kornia.core import Device, Dtype, Module, Tensor, concatenate, eye, stack, tensor, where, zeros, zeros_like
from kornia.core.check import KORNIA_CHECK_TYPE
from kornia.geometry.conversions import vector_to_skew_symmetric_matrix
from kornia.geometry.linalg import batched_dot_product
from kornia.geometry.quaternion import Quaternion
from kornia.geometry.vector import Vector3
Expand Down Expand Up @@ -92,7 +93,7 @@ def exp(v: Tensor) -> So3:
Example:
>>> v = torch.zeros((2, 3))
>>> s = So3.identity().exp(v)
>>> s = So3.exp(v)
>>> s
Parameter containing:
tensor([[1., 0., 0., 0.],
Expand Down Expand Up @@ -331,3 +332,63 @@ def adjoint(self) -> Tensor:
[0., 0., 1.]], grad_fn=<StackBackward0>)
"""
return self.matrix()

@staticmethod
def right_jacobian(vec: Tensor) -> Tensor:
"""Computes the right Jacobian of So3.
Args:
vec: the input point of shape :math:`(B, 3)`.
Example:
>>> vec = torch.tensor([1., 2., 3.])
>>> So3.right_jacobian(vec)
tensor([[-0.0687, 0.5556, -0.0141],
[-0.2267, 0.1779, 0.6236],
[ 0.5074, 0.3629, 0.5890]])
"""
# KORNIA_CHECK_SHAPE(vec, ["B", "3"]) # FIXME: resolve shape bugs. @edgarriba
R_skew = vector_to_skew_symmetric_matrix(vec)
theta = vec.norm(dim=-1, keepdim=True)[..., None]
I = eye(3, device=vec.device, dtype=vec.dtype) # noqa: E741
Jr = I - ((1 - theta.cos()) / theta**2) * R_skew + ((theta - theta.sin()) / theta**3) * (R_skew @ R_skew)
return Jr

@staticmethod
def Jr(vec: Tensor) -> Tensor:
"""Alias for right jacobian.
Args:
vec: the input point of shape :math:`(B, 3)`.
"""
return So3.right_jacobian(vec)

@staticmethod
def left_jacobian(vec: Tensor) -> Tensor:
"""Computes the left Jacobian of So3.
Args:
vec: the input point of shape :math:`(B, 3)`.
Example:
>>> vec = torch.tensor([1., 2., 3.])
>>> So3.left_jacobian(vec)
tensor([[-0.0687, -0.2267, 0.5074],
[ 0.5556, 0.1779, 0.3629],
[-0.0141, 0.6236, 0.5890]])
"""
# KORNIA_CHECK_SHAPE(vec, ["B", "3"]) # FIXME: resolve shape bugs. @edgarriba
R_skew = vector_to_skew_symmetric_matrix(vec)
theta = vec.norm(dim=-1, keepdim=True)[..., None]
I = eye(3, device=vec.device, dtype=vec.dtype) # noqa: E741
Jl = I + ((1 - theta.cos()) / theta**2) * R_skew + ((theta - theta.sin()) / theta**3) * (R_skew @ R_skew)
return Jl

@staticmethod
def Jl(vec: Tensor) -> Tensor:
"""Alias for left jacobian.
Args:
vec: the input point of shape :math:`(B, 3)`.
"""
return So3.left_jacobian(vec)
23 changes: 23 additions & 0 deletions test/geometry/liegroup/test_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,26 @@ def test_random(self, device, dtype, batch_size):
s_in_s = s.inverse() * s
i = So3.identity(batch_size=batch_size, device=device, dtype=dtype)
self.assert_close(s_in_s.q.data, i.q.data)

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_right_jacobian(self, device, dtype, batch_size):
vec = self._make_rand_data(device, dtype, batch_size, dims=3)
Jr = So3.right_jacobian(vec)
I = torch.eye(3, device=device, dtype=dtype).expand_as(Jr) # noqa: E741
self.assert_close(vec[..., None], Jr @ vec[..., None])
self.assert_close(Jr.transpose(-1, -2) @ Jr, I, atol=0.1, rtol=0.1)

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_left_jacobian(self, device, dtype, batch_size):
vec = self._make_rand_data(device, dtype, batch_size, dims=3)
Jl = So3.left_jacobian(vec)
I = torch.eye(3, device=device, dtype=dtype).expand_as(Jl) # noqa: E741
self.assert_close(vec[..., None], Jl @ vec[..., None])
self.assert_close(Jl.transpose(-1, -2) @ Jl, I, atol=0.1, rtol=0.1)

@pytest.mark.parametrize("batch_size", (None, 1, 2, 5))
def test_right_left_jacobian(self, device, dtype, batch_size):
vec = self._make_rand_data(device, dtype, batch_size, dims=3)
Jr = So3.right_jacobian(vec)
Jl = So3.left_jacobian(vec)
self.assert_close(Jl, Jr.transpose(-1, -2))
21 changes: 21 additions & 0 deletions test/geometry/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,3 +1338,24 @@ def test_values(self, device, dtype):

# out = [tf3.euler.quat2euler((qw[i], qx[i], qy[i], qz[i])) for i in range(num_samples)]
# out = torch.tensor(out, device=device, dtype=dtype)


@pytest.mark.parametrize('batch_size', (None, 1, 2, 5))
def test_vector_to_skew_symmetric_matrix(batch_size, device, dtype):
if batch_size is None:
vector = torch.rand(3, device=device, dtype=dtype)
else:
vector = torch.rand((batch_size, 3), device=device, dtype=dtype)
skew_symmetric_matrix = kornia.geometry.conversions.vector_to_skew_symmetric_matrix(vector)
assert skew_symmetric_matrix.shape[-1] == 3
assert skew_symmetric_matrix.shape[-2] == 3
z = torch.zeros_like(vector[..., 0])
assert_close(skew_symmetric_matrix[..., 0, 0], z)
assert_close(skew_symmetric_matrix[..., 1, 1], z)
assert_close(skew_symmetric_matrix[..., 2, 2], z)
assert_close(skew_symmetric_matrix[..., 0, 1], -vector[..., 2])
assert_close(skew_symmetric_matrix[..., 1, 0], vector[..., 2])
assert_close(skew_symmetric_matrix[..., 0, 2], vector[..., 1])
assert_close(skew_symmetric_matrix[..., 2, 0], -vector[..., 1])
assert_close(skew_symmetric_matrix[..., 1, 2], -vector[..., 0])
assert_close(skew_symmetric_matrix[..., 2, 1], vector[..., 0])

0 comments on commit daf6ce7

Please sign in to comment.