Cross Product between Two Tensors
This is an example of how to use the csdl.cross() function to compute the cross product between two tensors.
from csdl_om import Simulatorfrom csdl import Modelimport csdlimport numpy as np
class ExampleTensorTensor(Model):
def define(self): # Creating two tensors shape = (2, 5, 4, 3) num_elements = np.prod(shape)
tenval1 = np.arange(num_elements).reshape(shape) tenval2 = np.arange(num_elements).reshape(shape) + 6
ten1 = self.declare_variable('ten1', val=tenval1) ten2 = self.declare_variable('ten2', val=tenval2)
# Tensor-Tensor Dot Product specifying the last axis self.register_output('TenTenCross', csdl.cross(ten1, ten2, axis=3))
sim = Simulator(ExampleTensorTensor())sim.run()
print('ten1', sim['ten1'].shape)print(sim['ten1'])print('ten2', sim['ten2'].shape)print(sim['ten2'])print('TenTenCross', sim['TenTenCross'].shape)print(sim['TenTenCross'])
[[[[ 0. 1. 2.] [ 3. 4. 5.] [ 6. 7. 8.] [ 9. 10. 11.]]
[[ 12. 13. 14.] [ 15. 16. 17.] [ 18. 19. 20.] [ 21. 22. 23.]]
[[ 24. 25. 26.] [ 27. 28. 29.] [ 30. 31. 32.] [ 33. 34. 35.]]
[[ 36. 37. 38.] [ 39. 40. 41.] [ 42. 43. 44.] [ 45. 46. 47.]]
[[ 48. 49. 50.] [ 51. 52. 53.] [ 54. 55. 56.] [ 57. 58. 59.]]]
[[[ 60. 61. 62.] [ 63. 64. 65.] [ 66. 67. 68.] [ 69. 70. 71.]]
[[ 72. 73. 74.] [ 75. 76. 77.] [ 78. 79. 80.] [ 81. 82. 83.]]
[[ 84. 85. 86.] [ 87. 88. 89.] [ 90. 91. 92.] [ 93. 94. 95.]]
[[ 96. 97. 98.] [ 99. 100. 101.] [102. 103. 104.] [105. 106. 107.]]
[[108. 109. 110.] [111. 112. 113.] [114. 115. 116.] [117. 118. 119.]]]]ten2 (2, 5, 4, 3)[[[[ 6. 7. 8.] [ 9. 10. 11.] [ 12. 13. 14.] [ 15. 16. 17.]]
[[ 18. 19. 20.] [ 21. 22. 23.] [ 24. 25. 26.] [ 27. 28. 29.]]
[[ 30. 31. 32.] [ 33. 34. 35.] [ 36. 37. 38.] [ 39. 40. 41.]]
[[ 42. 43. 44.] [ 45. 46. 47.] [ 48. 49. 50.] [ 51. 52. 53.]]
[[ 54. 55. 56.] [ 57. 58. 59.] [ 60. 61. 62.] [ 63. 64. 65.]]]
[[[ 66. 67. 68.] [ 69. 70. 71.] [ 72. 73. 74.] [ 75. 76. 77.]]
[[ 78. 79. 80.] [ 81. 82. 83.] [ 84. 85. 86.] [ 87. 88. 89.]]
[[ 90. 91. 92.] [ 93. 94. 95.] [ 96. 97. 98.] [ 99. 100. 101.]]
[[102. 103. 104.] [105. 106. 107.] [108. 109. 110.] [111. 112. 113.]]
[[114. 115. 116.] [117. 118. 119.] [120. 121. 122.] [123. 124. 125.]]]]TenTenCross (2, 5, 4, 3)[[[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]
[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]
[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]
[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]
[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]]
[[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]
[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]
[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]
[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]
[[-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.] [-6. 12. -6.]]]]