.. _l-onnx-doc-Where:

=====
Where
=====

.. contents::
    :local:

.. _l-onnx-op-where-16:

Where - 16
==========

**Version**

* **name**: `Where (GitHub) `_
* **domain**: **main**
* **since_version**: **16**
* **function**: False
* **support_level**: SupportType.COMMON
* **shape inference**: True

This version of the operator has been available
**since version 16**.

**Summary**

Return elements, either from X or Y, depending on condition.
Where behaves like
`numpy.where <https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html>`_
with three arguments.

This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for details, see `Broadcasting in ONNX `_.

**History**

- Version 16 adds bfloat16 to the types allowed (for the second and third parameters).

**Inputs**

* **condition** (heterogeneous) - **B**:
  When True (nonzero), yield X; otherwise, yield Y.
* **X** (heterogeneous) - **T**:
  Values selected at indices where condition is True.
* **Y** (heterogeneous) - **T**:
  Values selected at indices where condition is False.

**Outputs**

* **output** (heterogeneous) - **T**:
  Tensor of shape equal to the broadcasted shape of condition, X, and
  Y.

**Type Constraints**

* **B** in (
  tensor(bool)
  ):
  Constrain to boolean tensors.
* **T** in (
  tensor(bfloat16),
  tensor(bool),
  tensor(complex128),
  tensor(complex64),
  tensor(double),
  tensor(float),
  tensor(float16),
  tensor(int16),
  tensor(int32),
  tensor(int64),
  tensor(int8),
  tensor(string),
  tensor(uint16),
  tensor(uint32),
  tensor(uint64),
  tensor(uint8)
  ):
  Constrain input and output types to all tensor types (including
  bfloat16).
**Examples**

**default**

::

    import numpy as np
    import onnx
    from onnx.backend.test.case.node import expect  # ONNX backend test-case helper

    node = onnx.helper.make_node(
        "Where",
        inputs=["condition", "x", "y"],
        outputs=["z"],
    )

    condition = np.array([[1, 0], [1, 1]], dtype=bool)
    x = np.array([[1, 2], [3, 4]], dtype=np.float32)
    y = np.array([[9, 8], [7, 6]], dtype=np.float32)
    z = np.where(condition, x, y)  # expected output [[1, 8], [3, 4]]

    expect(node, inputs=[condition, x, y], outputs=[z], name="test_where_example")

**_long**

::

    import numpy as np
    import onnx
    from onnx.backend.test.case.node import expect  # ONNX backend test-case helper

    node = onnx.helper.make_node(
        "Where",
        inputs=["condition", "x", "y"],
        outputs=["z"],
    )

    condition = np.array([[1, 0], [1, 1]], dtype=bool)
    x = np.array([[1, 2], [3, 4]], dtype=np.int64)
    y = np.array([[9, 8], [7, 6]], dtype=np.int64)
    z = np.where(condition, x, y)  # expected output [[1, 8], [3, 4]]

    expect(
        node, inputs=[condition, x, y], outputs=[z], name="test_where_long_example"
    )

**Differences**

Compared with Where - 9, Where - 16 makes the following changes:

* Adds a **History** section stating that version 16 adds bfloat16 to the types allowed for the second and third inputs.
* Adds ``tensor(bfloat16)`` to the allowed types of **T**.
* Updates the description of **T** to note that bfloat16 is included.

The summary, inputs, outputs, and the **B** constraint are unchanged.
.. _l-onnx-op-where-9:

Where - 9
=========

**Version**

* **name**: `Where (GitHub) `_
* **domain**: **main**
* **since_version**: **9**
* **function**: False
* **support_level**: SupportType.COMMON
* **shape inference**: True

This version of the operator has been available
**since version 9**.

**Summary**

Return elements, either from X or Y, depending on condition.
Where behaves like
`numpy.where <https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html>`_
with three arguments.

This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for details, see `Broadcasting in ONNX `_.

**Inputs**

* **condition** (heterogeneous) - **B**:
  When True (nonzero), yield X; otherwise, yield Y.
* **X** (heterogeneous) - **T**:
  Values selected at indices where condition is True.
* **Y** (heterogeneous) - **T**:
  Values selected at indices where condition is False.

**Outputs**

* **output** (heterogeneous) - **T**:
  Tensor of shape equal to the broadcasted shape of condition, X, and
  Y.

**Type Constraints**

* **B** in (
  tensor(bool)
  ):
  Constrain to boolean tensors.
* **T** in (
  tensor(bool),
  tensor(complex128),
  tensor(complex64),
  tensor(double),
  tensor(float),
  tensor(float16),
  tensor(int16),
  tensor(int32),
  tensor(int64),
  tensor(int8),
  tensor(string),
  tensor(uint16),
  tensor(uint32),
  tensor(uint64),
  tensor(uint8)
  ):
  Constrain input and output types to all tensor types.