Outer Product between a Tensor and a Vector
This example shows how to use the `csdl.outer` function to compute the outer product of a tensor and a vector.
import numpy as np

import csdl
from csdl import Model
from csdl_om import Simulator
class ExampleTensorVector(Model):
    """Compute the outer product of a rank-3 tensor and a vector with csdl.

    Declares a vector input ``vec1`` of shape ``(m,)`` and a tensor input
    ``ten1`` of shape ``(m, n, p)``, then registers their outer product as
    the output ``TenVecOuter``, of shape ``(m, n, p, m)``.
    """

    def define(self):
        m = 3
        n = 4
        p = 5

        # Shape of the tensor
        ten_shape = (m, n, p)

        # Values for the vector: 0, 1, ..., m - 1
        vec1_val = np.arange(m)

        # Number of elements in the tensor
        num_ten_elements = np.prod(ten_shape)

        # Values for the tensor: 0, 1, ..., num_ten_elements - 1, laid out as ten_shape
        ten1_val = np.arange(num_ten_elements).reshape(ten_shape)

        # Add the vector and the tensor to csdl as variables.  The numpy
        # arrays above only supply their initial values.
        vec1 = self.declare_variable('vec1', val=vec1_val)
        ten1 = self.declare_variable('ten1', val=ten1_val)

        # Tensor-vector outer product (no axes are specified):
        # TenVecOuter[i, j, k, l] = ten1[i, j, k] * vec1[l]
        self.register_output('TenVecOuter', csdl.outer(ten1, vec1))
# Build and run the model, then report each input and the computed output.
sim = Simulator(ExampleTensorVector())
sim.run()

# For every variable, print its name with its shape, then its value.
for name in ('vec1', 'ten1', 'TenVecOuter'):
    print(name, sim[name].shape)
    print(sim[name])
vec1 (3,)
[0. 1. 2.]
ten1 (3, 4, 5)
[[[ 0. 1. 2. 3. 4.] [ 5. 6. 7. 8. 9.] [10. 11. 12. 13. 14.] [15. 16. 17. 18. 19.]]
[[20. 21. 22. 23. 24.] [25. 26. 27. 28. 29.] [30. 31. 32. 33. 34.] [35. 36. 37. 38. 39.]]
[[40. 41. 42. 43. 44.] [45. 46. 47. 48. 49.] [50. 51. 52. 53. 54.] [55. 56. 57. 58. 59.]]]
TenVecOuter (3, 4, 5, 3)
[[[[ 0. 0. 0.] [ 0. 1. 2.] [ 0. 2. 4.] [ 0. 3. 6.] [ 0. 4. 8.]]
[[ 0. 5. 10.] [ 0. 6. 12.] [ 0. 7. 14.] [ 0. 8. 16.] [ 0. 9. 18.]]
[[ 0. 10. 20.] [ 0. 11. 22.] [ 0. 12. 24.] [ 0. 13. 26.] [ 0. 14. 28.]]
[[ 0. 15. 30.] [ 0. 16. 32.] [ 0. 17. 34.] [ 0. 18. 36.] [ 0. 19. 38.]]]
[[[ 0. 20. 40.] [ 0. 21. 42.] [ 0. 22. 44.] [ 0. 23. 46.] [ 0. 24. 48.]]
[[ 0. 25. 50.] [ 0. 26. 52.] [ 0. 27. 54.] [ 0. 28. 56.] [ 0. 29. 58.]]
[[ 0. 30. 60.] [ 0. 31. 62.] [ 0. 32. 64.] [ 0. 33. 66.] [ 0. 34. 68.]]
[[ 0. 35. 70.] [ 0. 36. 72.] [ 0. 37. 74.] [ 0. 38. 76.] [ 0. 39. 78.]]]
[[[ 0. 40. 80.] [ 0. 41. 82.] [ 0. 42. 84.] [ 0. 43. 86.] [ 0. 44. 88.]]
[[ 0. 45. 90.] [ 0. 46. 92.] [ 0. 47. 94.] [ 0. 48. 96.] [ 0. 49. 98.]]
[[ 0. 50. 100.] [ 0. 51. 102.] [ 0. 52. 104.] [ 0. 53. 106.] [ 0. 54. 108.]]
[[ 0. 55. 110.] [ 0. 56. 112.] [ 0. 57. 114.] [ 0. 58. 116.] [ 0. 59. 118.]]]]