一.引言

开发时同时用到了TF1与TF2,使用中发现 split 函数在V1和V2中有一些区别,记录一下。

二.TF1.x tf.string_split

1.使用

Input: 对字符串数组进行分割,默认分隔符为" ",skip_empty代表是否忽略空字符创

Output: 返回一个SparseTensor

def string_split(source, delimiter=" ", skip_empty=True):
  delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)

  indices, values, shape = gen_string_ops.string_split(
      source, delimiter=delimiter, skip_empty=skip_empty)
  indices.set_shape([None, 2])
  values.set_shape([None])
  shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, shape)

2.示例

输入 source=[['hello world], ['a b c']] 是二维数组

输出 SparseTensor(indices = [0,0; 0,1;1,0;1,1;,1,2], values=['hello', 'world', 'a', 'b', 'c'], shape=[2,3])

关于SparseTensor相关知识可以参考: SparseTensor 与 Lookup

  For example:
  N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output
  will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

三.TF2.x tf.strings.split

1.使用

Input: 对字符串,字符串数组进行分割,增加了 maxsplit 参数

Output: 返回 RaggedTensor

def string_split_v2(input, sep=None, maxsplit=-1, name=None):
  with ops.name_scope(name, "StringSplit", [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, dtype=dtypes.string, name="input")
    if isinstance(input, ragged_tensor.RaggedTensor):
      return input.with_flat_values(
          string_split_v2(input.flat_values, sep, maxsplit))

    rank = input.shape.ndims
    if rank == 0:
      return string_split_v2(array_ops.stack([input]), sep, maxsplit)[0]
    elif rank == 1 or rank is None:
      sparse_result = string_ops.string_split_v2(
          input, sep=sep, maxsplit=maxsplit)
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=sparse_result.values,
          value_rowids=sparse_result.indices[:, 0],
          nrows=sparse_result.dense_shape[0],
          validate=False)
    else:
      return string_split_v2(
          ragged_tensor.RaggedTensor.from_tensor(input), sep, maxsplit)

2.示例

A. 分割字符

输入  string = ["a b c"]

输出 常规 Tensor

tf.Tensor([b'a' b'b' b'c'], shape=(3,), dtype=string)

B. 分割字符数组

输入  text = ["123/bbbb/cba", "456/cccc/abc"]

返回 tf.RaggedTensor,shape 为(2, None),values 为常规 Tensor,可以视作Flatten

<tf.RaggedTensor [[b'123', b'bbbb', b'cba'], [b'456', b'cccc', b'abc']]>
shape=(2, None)
values=tf.Tensor([b'123' b'bbbb' b'cba' b'456' b'cccc' b'abc'], shape=(6,), dtype=string)

四.tf.string_split 与 tf.strings.split 异同

1.相同点

对字符进行分割

2.不同点

V1-tf.string_split V2-tf.strings.split
输入 字符串数组,单字符也需要[string] 字符串 or 字符串数组
返回类型 SparseTensor 稀疏张量 RaggedTensor 不规则张量
返回形状 按分割最长的字符补充                 只保留数组长度(len, None)

更多推荐算法相关深度学习:深度学习导读专栏 

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐