Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ def get_features_dict(self):
'steps': tfds.features.Dataset({
'observation': {
k: rlu_common.float_tensor_feature(v)
for k, v in self.builder_config.observation_size.items()
for k, v in self.builder_config.observation_size.items() # pyrefly: ignore[missing-attribute]
},
'action': tfds.features.Tensor(
shape=(self.builder_config.action_size,), dtype=np.float32
shape=(self.builder_config.action_size,), dtype=np.float32 # pyrefly: ignore[missing-attribute]
),
'reward': np.float32,
'is_terminal': np.bool_,
Expand All @@ -206,7 +206,7 @@ def get_citation(self):
return _CITATION

def get_file_prefix(self):
task = self.builder_config.name
task = self.builder_config.name # pyrefly: ignore[missing-attribute]
return f'{self._INPUT_FILE_PREFIX}/{task}/train'

def num_shards(self):
Expand All @@ -215,11 +215,11 @@ def num_shards(self):
def _get_example_specs(self):
obs_features = {
f'observation/{k}': _sequence_feature(v)
for k, v in self.builder_config.observation_size.items()
for k, v in self.builder_config.observation_size.items() # pyrefly: ignore[missing-attribute]
}
return {
**obs_features,
'action': _sequence_feature(self.builder_config.action_size),
'action': _sequence_feature(self.builder_config.action_size), # pyrefly: ignore[missing-attribute]
'discount': _sequence_feature(),
'reward': _sequence_feature(),
'step_type': _sequence_feature(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class RluLocomotion(rlu_common.RLUBuilder):
_INPUT_FILE_PREFIX = 'gs://rl_unplugged/dm_locomotion_episodes/'

def get_features_dict(self):
if 'humanoid' in self.builder_config.name:
if 'humanoid' in self.builder_config.name: # pyrefly: ignore[missing-attribute]
walker_features = {
'joints_vel': rlu_common.float_tensor_feature(56),
'sensors_velocimeter': rlu_common.float_tensor_feature(3),
Expand Down Expand Up @@ -171,7 +171,7 @@ def get_features_dict(self):
action_features = tfds.features.Tensor(shape=(38,), dtype=np.float32)

return tfds.features.FeaturesDict({
'steps': tfds.features.Dataset({
'steps': tfds.features.Dataset({ # pyrefly: ignore[bad-argument-type]
'observation': {
'walker': walker_features,
},
Expand All @@ -193,7 +193,7 @@ def get_citation(self):
return _CITATION

def get_file_prefix(self):
task = self.builder_config.name
task = self.builder_config.name # pyrefly: ignore[missing-attribute]
return f'{self._INPUT_FILE_PREFIX}/{task}/train'

def num_shards(self):
Expand All @@ -203,7 +203,7 @@ def tf_example_to_step_ds(
self, tf_example: tf.train.Example
) -> Dict[str, Any]:
"""Create an episode from a TF example."""
feature_description = _feature_description(self.builder_config.name)
feature_description = _feature_description(self.builder_config.name) # pyrefly: ignore[missing-attribute]

data = tf.io.parse_single_example(tf_example, feature_description)
episode_length = tf.size(data['discount'])
Expand Down
18 changes: 9 additions & 9 deletions tensorflow_datasets/rl_unplugged/rlu_rwrl/rlu_rwrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ def tf_example_to_feature_description(
'tf_example_to_feature_description() only works under eager mode.'
)
example = example.numpy() # pytype: disable=attribute-error
example = tf.train.Example.FromString(example)
example = tf.train.Example.FromString(example) # pyrefly: ignore[bad-argument-type, bad-assignment]

ret = {}
for k, v in example.features.feature.items():
for k, v in example.features.feature.items(): # pyrefly: ignore[missing-attribute]
l = len(v.float_list.value)
if l % num_timesteps:
raise ValueError(
Expand Down Expand Up @@ -336,10 +336,10 @@ def get_citation(self):
return _CITATION

def get_file_prefix(self):
domain = self.builder_config.domain
task = self.builder_config.task
combined_challenge = self.builder_config.combined_challenge
dataset_size = self.builder_config.dataset_size
domain = self.builder_config.domain # pyrefly: ignore[missing-attribute]
task = self.builder_config.task # pyrefly: ignore[missing-attribute]
combined_challenge = self.builder_config.combined_challenge # pyrefly: ignore[missing-attribute]
dataset_size = self.builder_config.dataset_size # pyrefly: ignore[missing-attribute]
return (
f'{self._INPUT_FILE_PREFIX}/'
f'combined_challenge_{str(combined_challenge).lower()}/'
Expand All @@ -351,9 +351,9 @@ def num_shards(self):
return self._SHARDS # For testing. # type: ignore
except AttributeError:
pass
domain = self.builder_config.domain
combined_challenge = self.builder_config.combined_challenge
dataset_size = self.builder_config.dataset_size
domain = self.builder_config.domain # pyrefly: ignore[missing-attribute]
combined_challenge = self.builder_config.combined_challenge # pyrefly: ignore[missing-attribute]
dataset_size = self.builder_config.dataset_size # pyrefly: ignore[missing-attribute]
return SHARDS_MAPPING[(combined_challenge, domain, dataset_size)]

def tf_example_to_step_ds( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
Expand Down
24 changes: 12 additions & 12 deletions tensorflow_datasets/rlds/datasets/locomotion/locomotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class Locomotion(tfds.core.GeneratorBasedBuilder):
name='ant_sac_1M_single_policy_stochastic',
observation_info=tfds.features.Tensor(shape=(111,), dtype=np.float32),
action_info=tfds.features.Tensor(shape=(8,), dtype=np.float32),
reward_info=np.float32,
discount_info=np.float32,
reward_info=np.float32, # pyrefly: ignore[bad-argument-type]
discount_info=np.float32, # pyrefly: ignore[bad-argument-type]
citation=_CITATION,
homepage=_HOMEPAGE,
overall_description=_DESCRIPTION,
Expand All @@ -76,8 +76,8 @@ class Locomotion(tfds.core.GeneratorBasedBuilder):
name='hopper_sac_1M_single_policy_stochastic',
observation_info=tfds.features.Tensor(shape=(11,), dtype=np.float32),
action_info=tfds.features.Tensor(shape=(3,), dtype=np.float32),
reward_info=np.float32,
discount_info=np.float32,
reward_info=np.float32, # pyrefly: ignore[bad-argument-type]
discount_info=np.float32, # pyrefly: ignore[bad-argument-type]
citation=_CITATION,
homepage=_HOMEPAGE,
overall_description=_DESCRIPTION,
Expand All @@ -91,8 +91,8 @@ class Locomotion(tfds.core.GeneratorBasedBuilder):
name='halfcheetah_sac_1M_single_policy_stochastic',
observation_info=tfds.features.Tensor(shape=(17,), dtype=np.float32),
action_info=tfds.features.Tensor(shape=(6,), dtype=np.float32),
reward_info=np.float32,
discount_info=np.float32,
reward_info=np.float32, # pyrefly: ignore[bad-argument-type]
discount_info=np.float32, # pyrefly: ignore[bad-argument-type]
citation=_CITATION,
homepage=_HOMEPAGE,
overall_description=_DESCRIPTION,
Expand All @@ -106,8 +106,8 @@ class Locomotion(tfds.core.GeneratorBasedBuilder):
name='walker2d_sac_1M_single_policy_stochastic',
observation_info=tfds.features.Tensor(shape=(17,), dtype=np.float32),
action_info=tfds.features.Tensor(shape=(6,), dtype=np.float32),
reward_info=np.float32,
discount_info=np.float32,
reward_info=np.float32, # pyrefly: ignore[bad-argument-type]
discount_info=np.float32, # pyrefly: ignore[bad-argument-type]
citation=_CITATION,
homepage=_HOMEPAGE,
overall_description=_DESCRIPTION,
Expand All @@ -121,8 +121,8 @@ class Locomotion(tfds.core.GeneratorBasedBuilder):
name='humanoid_sac_15M_single_policy_stochastic',
observation_info=tfds.features.Tensor(shape=(376,), dtype=np.float32),
action_info=tfds.features.Tensor(shape=(17,), dtype=np.float32),
reward_info=np.float32,
discount_info=np.float32,
reward_info=np.float32, # pyrefly: ignore[bad-argument-type]
discount_info=np.float32, # pyrefly: ignore[bad-argument-type]
citation=_CITATION,
homepage=_HOMEPAGE,
overall_description=_DESCRIPTION,
Expand All @@ -137,13 +137,13 @@ class Locomotion(tfds.core.GeneratorBasedBuilder):

def _info(self) -> tfds.core.DatasetInfo:
"""Returns the dataset metadata."""
return rlds_base.build_info(self.builder_config, self)
return rlds_base.build_info(self.builder_config, self) # pyrefly: ignore[bad-argument-type]

def _split_generators(self, dl_manager: tfds.download.DownloadManager):
"""Returns SplitGenerators."""
path = dl_manager.download_and_extract(
{
'file_path': self._DATA_PATHS[self.builder_config.name],
'file_path': self._DATA_PATHS[self.builder_config.name], # pyrefly: ignore[missing-attribute]
}
)
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,13 @@ class RobosuitePandaPickPlaceCan(tfds.core.GeneratorBasedBuilder):

def _info(self) -> tfds.core.DatasetInfo:
"""Returns the dataset metadata."""
return rlds_base.build_info(self.builder_config, self)
return rlds_base.build_info(self.builder_config, self) # pyrefly: ignore[bad-argument-type]

def _split_generators(self, dl_manager: tfds.download.DownloadManager):
"""Returns SplitGenerators."""
path = dl_manager.download_and_extract(
{
'file_path': self._DATA_PATHS[self.builder_config.name],
'file_path': self._DATA_PATHS[self.builder_config.name], # pyrefly: ignore[missing-attribute]
}
)
return {
Expand Down
8 changes: 4 additions & 4 deletions tensorflow_datasets/rlds/rlds_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def build_info(
**step_metadata,
}
if ds_config.observation_info:
step_info['observation'] = ds_config.observation_info
step_info['observation'] = ds_config.observation_info # pyrefly: ignore[unsupported-operation]
if ds_config.action_info:
step_info['action'] = ds_config.action_info
if ds_config.reward_info:
Expand All @@ -95,13 +95,13 @@ def build_info(
builder=builder,
description=ds_config.overall_description,
features=tfds.features.FeaturesDict({
'steps': tfds.features.Dataset(step_info),
'steps': tfds.features.Dataset(step_info), # pyrefly: ignore[bad-argument-type]
**episode_metadata,
}),
supervised_keys=ds_config.supervised_keys,
supervised_keys=ds_config.supervised_keys, # pyrefly: ignore[bad-argument-type]
homepage=ds_config.homepage,
citation=ds_config.citation,
metadata=ds_metadata,
metadata=ds_metadata, # pyrefly: ignore[bad-argument-type]
)


Expand Down
43 changes: 22 additions & 21 deletions tensorflow_datasets/robomimic/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def make_builder_configs(dataset: DataSource):
"""Creates the PH build configs."""
configs = []
for task, details in TASKS.items():
if dataset in details['datasets']:
if dataset in details['datasets']: # pyrefly: ignore[not-iterable]
for observation_type in [ObservationType.IMAGE, ObservationType.LOW_DIM]:
# pytype: disable=wrong-keyword-args
configs.append(
Expand Down Expand Up @@ -216,7 +216,7 @@ class RobomimicBuilder(tfds.core.GeneratorBasedBuilder, skip_registration=True):
"""DatasetBuilder for robomimic datasets."""

VERSION: tfds.core.Version
RELEASE_NOTES: Dict[str, str]
RELEASE_NOTES: Dict[str, str] # pyrefly: ignore[bad-override]
BUILDER_CONFIGS: List[tfds.core.BuilderConfig]
DATASET_NAME: str
DATASET_FILE_EXTENSION: str = ''
Expand All @@ -231,13 +231,13 @@ def _info(self) -> tfds.core.DatasetInfo:
)

def _get_features(self) -> tfds.features.FeaturesDict:
obs_dim = TASKS[self.builder_config.task]['object']
states_dim = TASKS[self.builder_config.task]['states']
action_size = TASKS[self.builder_config.task]['action_size']
obs_dim = TASKS[self.builder_config.task]['object'] # pyrefly: ignore[missing-attribute]
states_dim = TASKS[self.builder_config.task]['states'] # pyrefly: ignore[missing-attribute]
action_size = TASKS[self.builder_config.task]['action_size'] # pyrefly: ignore[missing-attribute]

observation = {
'object': tensor_feature(
obs_dim,
obs_dim, # pyrefly: ignore[bad-argument-type]
),
'robot0_eef_pos': tensor_feature(3, doc='End-effector position'),
'robot0_eef_quat': tensor_feature(4, doc='End-effector orientation'),
Expand All @@ -254,7 +254,7 @@ def _get_features(self) -> tfds.features.FeaturesDict:
'robot0_joint_pos_sin': tensor_feature(7),
'robot0_joint_vel': tensor_feature(7, doc='7DOF joint velocities'),
}
if self.builder_config.task == Task.TRANSPORT:
if self.builder_config.task == Task.TRANSPORT: # pyrefly: ignore[missing-attribute]
observation['robot1_eef_pos'] = tensor_feature(
3, doc='End-effector position'
)
Expand Down Expand Up @@ -282,34 +282,34 @@ def _get_features(self) -> tfds.features.FeaturesDict:
7, doc='7DOF joint velocities'
)

if self.builder_config.filename == ObservationType.IMAGE:
if self.builder_config.filename == ObservationType.IMAGE: # pyrefly: ignore[missing-attribute]
if self.builder_config.task == Task.TOOL_HANG:
observation['robot0_eye_in_hand_image'] = image_feature(240)
observation['sideview_image'] = image_feature(240)
observation['robot0_eye_in_hand_image'] = image_feature(240) # pyrefly: ignore[bad-assignment]
observation['sideview_image'] = image_feature(240) # pyrefly: ignore[bad-assignment]
elif self.builder_config.task == Task.TRANSPORT:
observation['robot0_eye_in_hand_image'] = image_feature(84)
observation['robot1_eye_in_hand_image'] = image_feature(84)
observation['shouldercamera0_image'] = image_feature(84)
observation['shouldercamera1_image'] = image_feature(84)
observation['robot0_eye_in_hand_image'] = image_feature(84) # pyrefly: ignore[bad-assignment]
observation['robot1_eye_in_hand_image'] = image_feature(84) # pyrefly: ignore[bad-assignment]
observation['shouldercamera0_image'] = image_feature(84) # pyrefly: ignore[bad-assignment]
observation['shouldercamera1_image'] = image_feature(84) # pyrefly: ignore[bad-assignment]
else:
observation['agentview_image'] = image_feature(84)
observation['robot0_eye_in_hand_image'] = image_feature(84)
observation['agentview_image'] = image_feature(84) # pyrefly: ignore[bad-assignment]
observation['robot0_eye_in_hand_image'] = image_feature(84) # pyrefly: ignore[bad-assignment]

# metadata depends on the quality type
metadata = self._get_metadata()

features = tfds.features.FeaturesDict({
'horizon': np.int32,
'episode_id': np.str_,
'steps': tfds.features.Dataset({
'action': tensor_feature(action_size),
'steps': tfds.features.Dataset({ # pyrefly: ignore[bad-argument-type]
'action': tensor_feature(action_size), # pyrefly: ignore[bad-argument-type]
'observation': observation,
'reward': np.float64,
'is_first': np.bool_,
'is_last': np.bool_,
'is_terminal': np.bool_,
'discount': np.int32,
'states': tensor_feature(states_dim),
'states': tensor_feature(states_dim), # pyrefly: ignore[bad-argument-type]
}),
**metadata,
})
Expand All @@ -325,6 +325,7 @@ def _split_generators(self, dl_manager: tfds.download.DownloadManager):
# in the sparse rewards, for consistency with the other datasets.
ext = self.DATASET_FILE_EXTENSION
filepath = (
# pyrefly: ignore[missing-attribute]
'http://downloads.cs.stanford.edu/downloads/rt_benchmark/'
f'{self.builder_config.task}/{self.builder_config.dataset}/'
f'{self.builder_config.filename}{ext}.hdf5'
Expand All @@ -351,14 +352,14 @@ def _generate_examples(self, path):
for key in data:
yield key, {
'steps': build_episode(data[key]),
'horizon': self.builder_config.horizon,
'horizon': self.builder_config.horizon, # pyrefly: ignore[missing-attribute]
'episode_id': key,
**episode_metadata(mask, key),
}
else:
for key in data:
yield key, {
'steps': build_episode(data[key]),
'horizon': self.builder_config.horizon,
'horizon': self.builder_config.horizon, # pyrefly: ignore[missing-attribute]
'episode_id': key,
}
8 changes: 4 additions & 4 deletions tensorflow_datasets/robotics/mt_opt/mt_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,15 @@ def _info(self) -> tfds.core.DatasetInfo:
return tfds.core.DatasetInfo(
builder=self,
description=_DESCRIPTION,
features=_name_to_features(self.builder_config.name),
features=_name_to_features(self.builder_config.name), # pyrefly: ignore[missing-attribute]
supervised_keys=None,
homepage='https://karolhausman.github.io/mt-opt/',
citation=_CITATION,
)

def _split_generators(self, dl_manager: tfds.download.DownloadManager):
"""Returns SplitGenerators."""
ds_name = self.builder_config.name
ds_name = self.builder_config.name # pyrefly: ignore[missing-attribute]
splits = {}
for split, shards in _NAME_TO_SPLITS[ds_name].items():
paths = {
Expand All @@ -184,15 +184,15 @@ def _generate_examples_one_file(
# Dataset of tf.Examples containing full episodes.
example_ds = tf.data.TFRecordDataset(filenames=str(path))

example_features = _name_to_features_encode(self.builder_config.name)
example_features = _name_to_features_encode(self.builder_config.name) # pyrefly: ignore[missing-attribute]
example_specs = example_features.get_serialized_info()
parser = tfds.core.example_parser.ExampleParser(example_specs)

parsed_examples = example_ds.map(parser.parse_example)
decoded_examples = parsed_examples.map(example_features.decode_example)

for index, example in enumerate(tfds.as_numpy(decoded_examples)):
if self.builder_config.name == 'rlds':
if self.builder_config.name == 'rlds': # pyrefly: ignore[missing-attribute]
id_key = 'episode_id'
else:
id_key = 'task_code'
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_datasets/scripts/replace_fake_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ def rewrite_tar(root_dir, tar_filepath):
extension = ''

# Extraction of .tar file
with tarfile.open(tar_filepath, 'r' + extension) as tar:
with tarfile.open(tar_filepath, 'r' + extension) as tar: # pyrefly: ignore[no-matching-overload]
tar.extractall(path=temp_dir)

rewrite_dir(temp_dir) # Recursively compress the archive content

# Convert back into tar file
with tarfile.open(tar_filepath, 'w' + extension) as tar:
with tarfile.open(tar_filepath, 'w' + extension) as tar: # pyrefly: ignore[no-matching-overload]
tar.add(temp_dir, arcname='', recursive=True)


Expand Down
Loading
Loading