
Commit bed365c

Merge pull request #51 from jjshoots/pole_env

Pole Waypoints Env

2 parents: be0a846 + 68fcd83

33 files changed: +182 -176 lines

PyFlyt/core/abstractions/boosters.py
+1 -1

@@ -141,7 +141,7 @@ def get_states(self) -> np.ndarray:
         - (b0, b1, ..., bn) represent the remaining fuel ratio
         - (c0, c1, ..., cn) represent the current throttle state
 
-        Returns
+        Returns:
         -------
             np.ndarray: A (3 * num_boosters, ) array

PyFlyt/core/abstractions/camera.py
+2 -2

@@ -100,7 +100,7 @@ def __init__(
     def view_mat(self) -> np.ndarray:
         """Generates the view matrix for the camera depending on the current orientation and implicit parameters.
 
-        Returns
+        Returns:
         -------
             np.ndarray: view matrix.

@@ -161,7 +161,7 @@ def physics_update(self):
     def capture_image(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Captures the 3 relevant images from the camera.
 
-        Returns
+        Returns:
         -------
             tuple[np.ndarray, np.ndarray, np.ndarray]: rgbaImg, depthImg, segImg

PyFlyt/core/abstractions/gimbals.py
+1 -1

@@ -123,7 +123,7 @@ def reset(self):
     def get_states(self) -> np.ndarray:
         """Gets the current state of the components.
 
-        Returns
+        Returns:
         -------
             np.ndarray: a (2 * num_gimbals, ) array where every pair of values represents the current state of the gimbal

PyFlyt/core/abstractions/lifting_surfaces.py
+2 -2

@@ -47,7 +47,7 @@ def reset(self):
     def get_states(self) -> np.ndarray:
         """Gets the current state of the components.
 
-        Returns
+        Returns:
         -------
             np.ndarray: a (num_surfaces, ) array representing the actuation state for each surface

@@ -254,7 +254,7 @@ def reset(self):
     def get_states(self) -> float:
         """Gets the current state of the components.
 
-        Returns
+        Returns:
         -------
             float: the level of deflection of the surface.

PyFlyt/core/abstractions/motors.py
+1 -1

@@ -98,7 +98,7 @@ def reset(self) -> None:
     def get_states(self) -> np.ndarray:
         """Gets the current state of the components.
 
-        Returns
+        Returns:
         -------
             np.ndarray: an (num_motors, ) array for the current throttle level of each motor

PyFlyt/core/aviary.py
+3 -3

@@ -38,7 +38,7 @@ def __init__(self, message: str) -> None:
     def __str__(self) -> str:
         """__str__.
 
-        Returns
+        Returns:
         -------
             str:

@@ -380,7 +380,7 @@ def all_states(self) -> list[np.ndarray]:
 
         This function is not very optimized, if you want the state of a single drone, do `state(i)`.
 
-        Returns
+        Returns:
         -------
             np.ndarray: list of states

@@ -399,7 +399,7 @@ def all_aux_states(self) -> list[np.ndarray]:
 
         This function is not very optimized, if you want the aux state of a single drone, do `aux_state(i)`.
 
-        Returns
+        Returns:
         -------
             np.ndarray: list of auxiliary states
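
Every hunk in the six files above makes the same one-word change: the `Returns` heading in each docstring gains a trailing colon, moving the docstrings toward Google-style section headings while the numpy-style dashed underline is left in place. A minimal sketch of the resulting convention, using an illustrative stand-in class rather than any of the repo's actual components:

import numpy as np

class Motors:
    """Illustrative stand-in for the component classes touched above."""

    def get_states(self) -> np.ndarray:
        """Gets the current state of the components.

        Returns:
        -------
            np.ndarray: an (num_motors, ) array for the current throttle level of each motor

        """
        return np.zeros((4,))  # placeholder body; the real classes return live state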

PyFlyt/core/drones/quadx.py
+3

@@ -32,6 +32,7 @@ def __init__(
         use_gimbal: bool = False,
         camera_angle_degrees: int = 20,
         camera_FOV_degrees: int = 90,
+        camera_position_offset: np.ndarray = np.array([0.0, 0.0, 0.0]),
         camera_resolution: tuple[int, int] = (128, 128),
         camera_fps: None | int = None,
     ):

@@ -51,6 +52,7 @@ def __init__(
             use_gimbal (bool): use_gimbal
             camera_angle_degrees (int): camera_angle_degrees
             camera_FOV_degrees (int): camera_FOV_degrees
+            camera_position_offset (np.ndarray): offset position of the camera
             camera_resolution (tuple[int, int]): camera_resolution
             camera_fps (None | int): camera_fps

@@ -205,6 +207,7 @@ def __init__(
                 camera_FOV_degrees=camera_FOV_degrees,
                 camera_angle_degrees=camera_angle_degrees,
                 camera_resolution=camera_resolution,
+                camera_position_offset=camera_position_offset,
             )
 
             # compute camera fps parameters
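
A hypothetical usage sketch of the new `camera_position_offset` parameter, spawning a QuadX whose onboard camera trails behind and above the body. Only `camera_position_offset` comes from this diff; the surrounding `Aviary` keyword names are assumptions about the core API at this commit:

import numpy as np
from PyFlyt.core import Aviary

# one drone at [0, 0, 1]; camera pulled 3 m back and 1 m up from the body
env = Aviary(
    start_pos=np.array([[0.0, 0.0, 1.0]]),
    start_orn=np.zeros((1, 3)),
    drone_type="quadx",
    drone_options=dict(
        use_camera=True,
        camera_position_offset=np.array([-3.0, 0.0, 1.0]),
    ),
)
env.set_mode(0)  # angular rates + thrust control
env.step()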

PyFlyt/gym_envs/fixedwing_envs/fixedwing_waypoints_env.py
+7 -11

@@ -116,7 +116,6 @@ def reset(
         super().begin_reset(seed, options)
         self.waypoints.reset(self.env, self.np_random)
         self.info["num_targets_reached"] = 0
-        self.distance_to_immediate = np.inf
         super().end_reset()
 
         return self.state, self.info

@@ -165,12 +164,9 @@ def compute_state(self) -> None:
             axis=-1,
         )
 
-        new_state["target_deltas"] = self.waypoints.distance_to_target(
+        new_state["target_deltas"] = self.waypoints.distance_to_targets(
             ang_pos, lin_pos, quaternion
         )
-        self.distance_to_immediate = float(
-            np.linalg.norm(new_state["target_deltas"][0])
-        )
 
         self.state: dict[Literal["attitude", "target_deltas"], np.ndarray] = new_state

@@ -180,17 +176,17 @@ def compute_term_trunc_reward(self) -> None:
 
         # bonus reward if we are not sparse
         if not self.sparse_reward:
-            self.reward += max(3.0 * self.waypoints.progress_to_target(), 0.0)
-            self.reward += 1.0 / self.distance_to_immediate
+            self.reward += max(3.0 * self.waypoints.progress_to_next_target, 0.0)
+            self.reward += 1.0 / self.waypoints.distance_to_next_target
 
         # target reached
-        if self.waypoints.target_reached():
+        if self.waypoints.target_reached:
             self.reward = 100.0
 
             # advance the targets
             self.waypoints.advance_targets()
 
             # update infos and dones
-            self.truncation |= self.waypoints.all_targets_reached()
-            self.info["env_complete"] = self.waypoints.all_targets_reached()
-            self.info["num_targets_reached"] = self.waypoints.num_targets_reached()
+            self.truncation |= self.waypoints.all_targets_reached
+            self.info["env_complete"] = self.waypoints.all_targets_reached
+            self.info["num_targets_reached"] = self.waypoints.num_targets_reached
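
These hunks track an API change in the waypoints handler: `distance_to_target(...)` becomes `distance_to_targets(...)`, and the remaining accessors become properties, which also lets the env drop its cached `distance_to_immediate`. A sketch of what such a property-based handler could look like; the internals here are assumptions for illustration, not the repo's actual implementation:

import numpy as np

class WaypointHandler:
    """Illustrative property-based waypoints handler (hypothetical internals)."""

    def __init__(self, targets: np.ndarray, goal_reach_distance: float):
        self.targets = targets  # (num_targets, 3) array of [x, y, z] positions
        self.goal_reach_distance = goal_reach_distance
        self.index = 0
        self.previous_distance = np.inf
        self.current_distance = np.inf

    def update(self, lin_pos: np.ndarray) -> None:
        """Refreshes distances from the drone's current position."""
        new_distance = float(np.linalg.norm(self.targets[self.index] - lin_pos))
        if np.isinf(self.current_distance):
            self.previous_distance = new_distance  # first step: no progress yet
        else:
            self.previous_distance = self.current_distance
        self.current_distance = new_distance

    @property
    def distance_to_next_target(self) -> float:
        return self.current_distance

    @property
    def progress_to_next_target(self) -> float:
        return self.previous_distance - self.current_distance

    @property
    def target_reached(self) -> bool:
        return self.current_distance < self.goal_reach_distance

    @property
    def all_targets_reached(self) -> bool:
        return self.index >= len(self.targets)

    @property
    def num_targets_reached(self) -> int:
        return self.index

    def advance_targets(self) -> None:
        """Moves on to the next waypoint once the current one is reached."""
        if self.target_reached:
            self.index += 1
            self.previous_distance = np.inf
            self.current_distance = np.inf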

PyFlyt/gym_envs/quadx_envs/quadx_base_env.py
+23 -17

@@ -75,24 +75,30 @@ def __init__(
         self.auxiliary_space = spaces.Box(
             low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
         )
-        angular_rate_limit = np.pi
+
+        # define the action space
+        xyz_limit = np.pi
         thrust_limit = 0.8
-        high = np.array(
-            [
-                angular_rate_limit,
-                angular_rate_limit,
-                angular_rate_limit,
-                thrust_limit,
-            ]
-        )
-        low = np.array(
-            [
-                -angular_rate_limit,
-                -angular_rate_limit,
-                -angular_rate_limit,
-                0.0,
-            ]
-        )
+        if flight_mode == -1:
+            high = np.ones((4,)) * thrust_limit
+            low = np.zeros((4,))
+        else:
+            high = np.array(
+                [
+                    xyz_limit,
+                    xyz_limit,
+                    xyz_limit,
+                    thrust_limit,
+                ]
+            )
+            low = np.array(
+                [
+                    -xyz_limit,
+                    -xyz_limit,
+                    -xyz_limit,
+                    0.0,
+                ]
+            )
         self.action_space = spaces.Box(low=low, high=high, dtype=np.float64)
 
         # the whole implicit state space = attitude + previous action + auxiliary information
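
The new branch gives `flight_mode == -1` (direct motor PWM) a per-motor throttle box of [0, 0.8], while every other mode keeps signed angular-rate limits with thrust in [0, 0.8]. The same logic, factored into a standalone helper purely for illustration (a sketch, assuming gymnasium's `spaces.Box` as imported in this file):

import numpy as np
from gymnasium import spaces

def make_quadx_action_space(flight_mode: int) -> spaces.Box:
    xyz_limit = np.pi  # angular rate limit, rad/s
    thrust_limit = 0.8
    if flight_mode == -1:
        # direct motor PWM: one non-negative throttle per motor
        high = np.ones((4,)) * thrust_limit
        low = np.zeros((4,))
    else:
        # roll, pitch, yaw rates plus collective thrust
        high = np.array([xyz_limit, xyz_limit, xyz_limit, thrust_limit])
        low = np.array([-xyz_limit, -xyz_limit, -xyz_limit, 0.0])
    return spaces.Box(low=low, high=high, dtype=np.float64)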

PyFlyt/gym_envs/quadx_envs/quadx_pole_balance_env.py
+8 -5

@@ -13,7 +13,7 @@
 class QuadXPoleBalanceEnv(QuadXBaseEnv):
     """Simple Hover Environment with the additional goal of keeping a pole upright.
 
-    Actions are vp, vq, vr, T, ie: angular rates and thrust.
+    Actions are direct motor PWM commands because any underlying controller introduces too much control latency.
     The target is to not crash and not let the pole hit the ground for the longest time possible.
 
     Args:

@@ -32,7 +32,7 @@ class QuadXPoleBalanceEnv(QuadXBaseEnv):
     def __init__(
         self,
         sparse_reward: bool = False,
-        flight_mode: int = 0,
+        flight_mode: int = -1,
         flight_dome_size: float = 3.0,
         max_duration_seconds: float = 20.0,
         angle_representation: Literal["euler", "quaternion"] = "quaternion",

@@ -45,12 +45,12 @@ def __init__(
         Args:
         ----
             sparse_reward (bool): whether to use sparse rewards or not.
-            flight_mode (int): the flight mode of the UAV
+            flight_mode (int): the flight mode of the UAV.
             flight_dome_size (float): size of the allowable flying area.
             max_duration_seconds (float): maximum simulation time of the environment.
             angle_representation (Literal["euler", "quaternion"]): can be "euler" or "quaternion".
             agent_hz (int): looprate of the agent to environment interaction.
-            render_mode (None | Literal["human", "rgb_array"]): render_mode
+            render_mode (None | Literal["human", "rgb_array"]): render_mode.
             render_resolution (tuple[int, int]): render_resolution.
 
         """

@@ -94,7 +94,10 @@ def reset(
         super().begin_reset(
             seed,
             options,
-            drone_options={"drone_model": "primitive_drone"},
+            drone_options={
+                "drone_model": "primitive_drone",
+                "camera_position_offset": np.array([-3.0, 0.0, 1.0]),
+            },
         )
         self.pole.reset(p=self.env, start_location=np.array([0.0, 0.0, 1.55]))
         super().end_reset(seed, options)
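
A hypothetical smoke test for the updated environment, which now defaults to direct PWM control (`flight_mode=-1`); the registered env id and its version suffix are assumptions, check `PyFlyt.gym_envs` for the exact string:

import gymnasium
import PyFlyt.gym_envs  # noqa: F401 -- registers the PyFlyt envs

env = gymnasium.make("PyFlyt/QuadX-Pole-Balance-v2")  # id assumed
obs, info = env.reset(seed=42)
action = env.action_space.sample()  # 4 motor PWM values in [0, 0.8]
obs, reward, terminated, truncated, info = env.step(action)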

PyFlyt/gym_envs/quadx_envs/quadx_pole_waypoints_env.py
+24 -23

@@ -14,7 +14,7 @@
 class QuadXPoleWaypointsEnv(QuadXBaseEnv):
     """QuadX Pole Waypoints Environment.
 
-    Actions are vp, vq, vr, T, ie: angular rates and thrust.
+    Actions are direct motor PWM commands because any underlying controller introduces too much control latency.
     The target is to get to a set of `[x, y, z]` waypoints in space without dropping the pole.
 
     Args:

@@ -37,11 +37,11 @@ def __init__(
         sparse_reward: bool = False,
         num_targets: int = 4,
         goal_reach_distance: float = 0.2,
-        flight_mode: int = 0,
+        flight_mode: int = -1,
         flight_dome_size: float = 10.0,
-        max_duration_seconds: float = 60.0,
+        max_duration_seconds: float = 20.0,
         angle_representation: Literal["euler", "quaternion"] = "quaternion",
-        agent_hz: int = 30,
+        agent_hz: int = 40,
         render_mode: None | Literal["human", "rgb_array"] = None,
         render_resolution: tuple[int, int] = (480, 480),
     ):

@@ -57,12 +57,11 @@ def __init__(
             max_duration_seconds (float): maximum simulation time of the environment.
             angle_representation (Literal["euler", "quaternion"]): can be "euler" or "quaternion".
             agent_hz (int): looprate of the agent to environment interaction.
-            render_mode (None | Literal["human", "rgb_array"]): render_mode
+            render_mode (None | Literal["human", "rgb_array"]): render_mode.
             render_resolution (tuple[int, int]): render_resolution.
 
         """
         super().__init__(
-            start_pos=np.array([[0.0, 0.0, 1.0]]),
             flight_mode=flight_mode,
             flight_dome_size=flight_dome_size,
             max_duration_seconds=max_duration_seconds,

@@ -126,7 +125,12 @@ def reset(
         """
         super().begin_reset(
-            seed, options, drone_options={"drone_model": "primitive_drone"}
+            seed,
+            options,
+            drone_options={
+                "drone_model": "primitive_drone",
+                "camera_position_offset": np.array([-3.0, 0.0, 1.0]),
+            },
         )
 
         # spawn in a pole

@@ -135,7 +139,6 @@ def reset(
         # init some other metadata
         self.waypoints.reset(self.env, self.np_random)
         self.info["num_targets_reached"] = 0
-        self.distance_to_immediate = np.inf
 
         super().end_reset()

@@ -162,10 +165,10 @@ def compute_state(self) -> None:
         """
         # compute attitude of self
         ang_vel, ang_pos, lin_vel, lin_pos, quaternion = super().compute_attitude()
+        aux_state = super().compute_auxiliary()
         rotation = (
             np.array(self.env.getMatrixFromQuaternion(quaternion)).reshape(3, 3).T
         )
-        aux_state = super().compute_auxiliary()
 
         # compute the pole's states
         (

@@ -210,36 +213,34 @@ def compute_state(self) -> None:
                 pole_bot_pos,
                 pole_top_vel,
                 pole_bot_vel,
-            ]
+            ],
+            axis=-1,
         )
 
-        new_state["target_deltas"] = self.waypoints.distance_to_target(
+        new_state["target_deltas"] = self.waypoints.distance_to_targets(
             ang_pos, lin_pos, quaternion
         )
-        self.distance_to_immediate = float(
-            np.linalg.norm(new_state["target_deltas"][0])
-        )
 
         self.state: dict[Literal["attitude", "target_deltas"], np.ndarray] = new_state
 
     def compute_term_trunc_reward(self) -> None:
-        """Computes the termination, trunction, and reward of the current timestep."""
+        """Computes the termination, truncation, and reward of the current timestep."""
         super().compute_base_term_trunc_reward()
 
         # bonus reward if we are not sparse
         if not self.sparse_reward:
-            self.reward += max(3.0 * self.waypoints.progress_to_target(), 0.0)
-            self.reward += 0.1 / self.distance_to_immediate
-            self.reward -= self.pole.leaningness
+            self.reward += max(15.0 * self.waypoints.progress_to_next_target, 0.0)
+            self.reward += 0.5 / self.waypoints.distance_to_next_target
+            self.reward += (0.5 - self.pole.leaningness)
 
         # target reached
-        if self.waypoints.target_reached():
-            self.reward = 100.0
+        if self.waypoints.target_reached:
+            self.reward = 300.0
 
             # advance the targets
             self.waypoints.advance_targets()
 
             # update infos and dones
-            self.truncation |= self.waypoints.all_targets_reached()
-            self.info["env_complete"] = self.waypoints.all_targets_reached()
-            self.info["num_targets_reached"] = self.waypoints.num_targets_reached()
+            self.truncation |= self.waypoints.all_targets_reached
+            self.info["env_complete"] = self.waypoints.all_targets_reached
+            self.info["num_targets_reached"] = self.waypoints.num_targets_reached
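
And a matching smoke test for the new pole-waypoints environment, exercising the defaults changed above (`flight_mode=-1`, `agent_hz=40`, 20-second episodes); the env id is again an assumption:

import gymnasium
import PyFlyt.gym_envs  # noqa: F401 -- registers the PyFlyt envs

env = gymnasium.make("PyFlyt/QuadX-Pole-Waypoints-v2", num_targets=4)  # id assumed
obs, info = env.reset(seed=0)
terminated = truncated = False
while not (terminated or truncated):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(info["num_targets_reached"])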
