A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
Documentation | TensorDict | Features | Examples, tutorials and demos | Citation | Installation | Asking a question | Contributing
TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.
TorchRL now includes a comprehensive LLM API for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
- `History` class for multi-turn dialogue with automatic chat template detection

The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the complete documentation and GRPO implementation example to get started!
```python
from torchrl.envs.llm import ChatEnv
from torchrl.envs.llm.transforms import PythonInterpreter
from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm import GRPOLoss
from torchrl.collectors.llm import LLMCollector

# `tokenizer`, `model`, `critic` and `optimizer` are assumed to be defined beforehand.

# Create environment with Python tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You are an assistant that can execute Python code.",
    batch_size=[1]
).append_transform(PythonInterpreter())

# Wrap your language model
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history"
)

# Set up GRPO training
loss_fn = GRPOLoss(llm, critic, gamma=0.99)
collector = LLMCollector(env, llm, frames_per_batch=100)

# Training loop
for data in collector:
    loss = loss_fn(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()  # reset gradients before the next batch
```
Read the full paper for a more curated description of the library.
Check our Getting Started tutorials to quickly ramp up with the basic features of the library!
The TorchRL documentation can be found here. It contains tutorials and the API reference.
TorchRL also provides a RL knowledge base to help you debug your code, or simply learn the basics of RL. Check it out here.
We have some introductory videos for you to get to know the library better, check them out:
TorchRL being domain-agnostic, you can use it across many different fields. Here are a few examples:
TensorDict
RL algorithms are very heterogeneous, and it can be hard to recycle a codebase across settings (e.g. from online to offline, from state-based to pixel-based learning).
TorchRL solves this problem through TensorDict, a convenient data structure(1) that can be used to streamline one's RL codebase.
With this tool, one can write a complete PPO training script in less than 100 lines of code!
```python
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

env = GymEnv("Pendulum-v1")
model = TensorDictModule(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 2),
        NormalParamExtractor()
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"]
)
critic = ValueOperator(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 1),
    ),
    in_keys=["observation"],
)
actor = ProbabilisticActor(
    model,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={"low": -1.0, "high": 1.0},
    return_log_prob=True
)
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    batch_size=50,
)
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=1000,
    total_frames=1_000_000,
)
loss_fn = ClipPPOLoss(actor, critic)
adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)

for data in collector:  # collect data
    for epoch in range(10):
        adv_fn(data)  # compute advantage
        buffer.extend(data)
        for sample in buffer:  # consume data
            loss_vals = loss_fn(sample)
            loss_val = sum(
                value for key, value in loss_vals.items() if
                key.startswith("loss")
            )
            loss_val.backward()
            optim.step()
            optim.zero_grad()
    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
```
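The `GAE` module in the script above computes generalized advantage estimates. As a rough, dependency-free illustration of that recursion, here is a plain-Python sketch. It is not TorchRL's vectorized implementation: the function name `gae_advantages` and the list-based inputs are assumptions made for this example, and the advantage normalization enabled by `average_gae=True` is left out.

```python
def gae_advantages(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Generalized advantage estimation over one trajectory (lists of floats)."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk the trajectory backwards so each step can reuse the next step's estimate.
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    return advantages


adv = gae_advantages(
    rewards=[1.0, 1.0, 1.0],
    values=[0.5, 0.5, 0.5],
    next_values=[0.5, 0.5, 0.0],
    dones=[False, False, True],
)
print(adv)
```

With `lmbda=0` the estimate reduces to the one-step TD error, and with `lmbda=1` it accumulates the discounted TD errors up to the end of the trajectory.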
Here is an example of how the environment API relies on tensordict to carry data from one function to another during a rollout execution:
TensorDict makes it easy to re-use pieces of code across environments, models and algorithms.
For instance, here's how to code a rollout in TorchRL:
```diff
- obs, done = env.reset()
+ tensordict = env.reset()
policy = SafeModule(
    model,
    in_keys=["observation_pixels", "observation_vector"],
    out_keys=["action"],
)
out = []
for i in range(n_steps):
-     action, log_prob = policy(obs)
-     next_obs, reward, done, info = env.step(action)
-     out.append((obs, next_obs, action, log_prob, reward, done))
-     obs = next_obs
+     tensordict = policy(tensordict)
+     tensordict = env.step(tensordict)
+     out.append(tensordict)
+     tensordict = step_mdp(tensordict)  # renames next_observation_* keys to observation_*
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
+ out = torch.stack(out, 0)  # TensorDict supports multiple tensor operations
```
Using this, TorchRL abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing all primitives to be easily recycled across settings.
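To make the data flow of such a rollout concrete, here is a minimal sketch that uses plain dictionaries instead of TensorDict. Everything in it is hypothetical (`toy_policy`, `toy_env_step`, `step_mdp_sketch`), and it assumes that the result of a step is stored under a nested `"next"` entry, as the `data['next', 'reward']` lookup in the PPO script suggests. The real `step_mdp` offers more options than this simplified version.

```python
def toy_policy(td):
    """Hypothetical policy: reads the observation, writes an action."""
    return {**td, "action": 1}


def toy_env_step(td):
    """Hypothetical environment: the observation is incremented by the action."""
    obs = td["observation"] + td["action"]
    return {**td, "next": {"observation": obs, "reward": float(obs), "done": obs >= 3}}


def step_mdp_sketch(td):
    """Simplified step: the content of "next" becomes the root for the next iteration."""
    return dict(td["next"])


td = {"observation": 0}
trajectory = []
for _ in range(3):
    td = toy_policy(td)        # adds "action"
    td = toy_env_step(td)      # adds "next"
    trajectory.append(td)
    td = step_mdp_sketch(td)   # prepares the input of the following step

print([step["next"]["observation"] for step in trajectory])
```

Because every function takes and returns the same container, the loop body stays identical when the policy or the environment reads or writes additional entries.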
Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):
```diff
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
+ for i, tensordict in enumerate(collector):
-     replay_buffer.add((obs, next_obs, action, hidden_state, reward, done))
+     replay_buffer.add(tensordict)
    for j in range(num_optim_steps):
-         obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
-         loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+         tensordict = replay_buffer.sample(batch_size)
+         loss = loss_fn(tensordict)
        loss.backward()
        optim.step()
        optim.zero_grad()
```
This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
TensorDict supports multiple tensor operations on its device and shape (the shape of TensorDict, or its batch size, is the common arbitrary N first dimensions of all its contained tensors):
```python
# stack and cat
tensordict = torch.stack(list_of_tensordicts, 0)
tensordict = torch.cat(list_of_tensordicts, 0)
# reshape
tensordict = tensordict.view(-1)
tensordict = tensordict.permute(0, 2, 1)
tensordict = tensordict.unsqueeze(-1)
tensordict = tensordict.squeeze(-1)
# indexing
tensordict = tensordict[:2]
tensordict[:, 2] = sub_tensordict
# device and memory location
tensordict.cuda()
tensordict.to("cuda:1")
tensordict.share_memory_()
```
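The notion of batch size described above can be illustrated without any tensor library. The sketch below works on shape tuples only, and `common_batch_size` is a hypothetical helper that returns the longest leading prefix shared by all entries. A TensorDict's batch size is declared by the user, so this prefix is best read as the largest batch size such a set of entries could accept (any shorter prefix would be valid too).

```python
def common_batch_size(shapes):
    """Return the longest leading shape prefix shared by every tensor shape."""
    prefix = []
    for dims in zip(*shapes):
        if len(set(dims)) != 1:
            break
        prefix.append(dims[0])
    return tuple(prefix)


shapes = {
    "observation": (4, 20, 3),
    "action": (4, 20, 1),
    "reward": (4, 20, 1),
}
print(common_batch_size(shapes.values()))
```

Entries whose first dimensions differ share no leading prefix, so they can only be stored together with an empty batch size.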
TensorDict comes with a dedicated `tensordict.nn` module that contains everything you might need to write your model with it. And it is `functorch` and `torch.compile` compatible!
```diff
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
src = torch.rand((20, 32, 512))  # same leading dims as tgt so both fit the batch size
tgt = torch.rand((20, 32, 512))
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
- out = transformer_model(src, tgt)
+ td_module(tensordict)
+ out = tensordict["out"]
```
The `TensorDictSequential` class allows branching sequences of `nn.Module` instances in a highly modular way.
For instance, here is an implementation of a transformer using the encoder and decoder blocks:
```python
encoder_module = TransformerEncoder(...)
# a single module is wrapped with TensorDictModule, which carries the in/out keys
encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
decoder_module = TransformerDecoder(...)
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
transformer = TensorDictSequential(encoder, decoder)
assert transformer.in_keys == ["src", "src_mask", "tgt"]
assert transformer.out_keys == ["memory", "output"]
```
`TensorDictSequential` allows isolating subgraphs by querying a set of desired input / output keys:
```python
transformer.select_subsequence(out_keys=["memory"])  # returns the encoder
transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
```
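The key bookkeeping behind this can be sketched in a few lines of plain Python. The classes below (`KeyedModule`, `KeyedSequence`) are hypothetical stand-ins that only track key names and run no computation. The selection logic covers the `out_keys` case only and is an assumption about how such a query could be resolved, not a description of the tensordict implementation.

```python
class KeyedModule:
    """Hypothetical stand-in for a module that reads and writes named entries."""

    def __init__(self, in_keys, out_keys):
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)


class KeyedSequence:
    """Chains KeyedModule objects and derives the keys of the whole chain."""

    def __init__(self, *modules):
        self.modules = list(modules)

    @property
    def in_keys(self):
        # An input is external if no earlier module in the chain produced it.
        produced, needed = set(), []
        for module in self.modules:
            for key in module.in_keys:
                if key not in produced and key not in needed:
                    needed.append(key)
            produced.update(module.out_keys)
        return needed

    @property
    def out_keys(self):
        keys = []
        for module in self.modules:
            keys.extend(k for k in module.out_keys if k not in keys)
        return keys

    def select_subsequence(self, out_keys):
        """Keep only the modules required to compute the requested outputs."""
        wanted, kept = set(out_keys), []
        for module in reversed(self.modules):
            if wanted & set(module.out_keys):
                kept.append(module)
                wanted |= set(module.in_keys)
        return KeyedSequence(*reversed(kept))


encoder = KeyedModule(["src", "src_mask"], ["memory"])
decoder = KeyedModule(["tgt", "memory"], ["output"])
transformer = KeyedSequence(encoder, decoder)
print(transformer.in_keys, transformer.out_keys)
print(transformer.select_subsequence(out_keys=["memory"]).in_keys)
```

Note that `"memory"` does not appear among the inputs of the full chain, since the encoder produces it before the decoder reads it.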
Check TensorDict tutorials to learn more!
A common interface for environments which supports common libraries (OpenAI Gym, DeepMind Control Suite, etc.)(1) and stateless execution (e.g. model-based environments). The batched environment containers allow parallel execution(2). A common PyTorch-first tensor-specification class is also provided. TorchRL's environments API is simple but stringent and specific. Check the documentation and tutorial to learn more!
```python
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
assert tensordict.shape == [4, 20]  # 4 envs, 20 steps rollout
env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
```
multiprocess and distributed data collectors(2) that work synchronously or asynchronously. Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised learning (although the "dataloader" -- read data collector -- is modified on-the-fly):
```python
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
collector = MultiaSyncDataCollector(
    [env_make, env_make],
    policy=policy,
    devices=["cuda:0", "cuda:0"],
    total_frames=10000,
    frames_per_batch=50,
    ...
)
for i, tensordict_data in enumerate(collector):
    loss = loss_module(tensordict_data)
    loss.backward()
    optim.step()
    optim.zero_grad()
    collector.update_policy_weights_()
```
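The "dataloader" analogy can be made concrete with a generator. The sketch below is a single-process, synchronous toy, so it does not reflect how `MultiaSyncDataCollector` works internally. The names `collect`, `toy_env_step` and `random_policy` are hypothetical, and only the roles of `frames_per_batch` and `total_frames` are taken from the snippet above.

```python
import random


def collect(env_step, policy, frames_per_batch, total_frames, seed=0):
    """Yield batches of transitions until total_frames have been collected."""
    rng = random.Random(seed)
    obs, collected = 0.0, 0
    while collected < total_frames:
        batch = []
        for _ in range(min(frames_per_batch, total_frames - collected)):
            action = policy(obs, rng)
            next_obs, reward = env_step(obs, action)
            batch.append({"observation": obs, "action": action, "reward": reward})
            obs = next_obs
        collected += len(batch)
        yield batch


def toy_env_step(obs, action):
    # Hypothetical dynamics: move by the action, reward staying close to zero.
    next_obs = obs + action
    return next_obs, -abs(next_obs)


def random_policy(obs, rng):
    return rng.choice([-1.0, 1.0])


batches = list(collect(toy_env_step, random_policy, frames_per_batch=50, total_frames=10_000))
print(len(batches), len(batches[0]))
```

Since the policy is called inside the generator, data yielded after a weight update is produced by the updated policy, which is the sense in which the "dataloader" changes on-the-fly.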
Check our distributed collector examples to learn more about ultra-fast data collection with TorchRL.
efficient(2) and generic(1) replay buffers with modularized storage:
```python
storage = LazyMemmapStorage(  # memory-mapped (physical) storage
    cfg.buffer_size,
    scratch_dir="/tmp/"
)
buffer = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.5,
    collate_fn=lambda x: x,
    pin_memory=device != torch.device("cpu"),
    prefetch=10,  # multi-threaded sampling
    storage=storage
)
```
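For readers unfamiliar with the `alpha` and `beta` arguments, the sketch below shows the standard proportional prioritization formulas in plain Python. It assumes the usual convention in which `alpha` shapes the sampling distribution and `beta` controls the importance-sampling correction. The helper names and the rescaling by the largest weight are assumptions of this example, not a description of TorchRL's internals.

```python
def sampling_probabilities(priorities, alpha):
    """P(i) = p_i**alpha / sum_k p_k**alpha (alpha=0 gives uniform sampling)."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probabilities, beta):
    """w_i = (N * P(i))**(-beta), rescaled so that the largest weight is 1."""
    n = len(probabilities)
    raw = [(n * p) ** (-beta) for p in probabilities]
    largest = max(raw)
    return [w / largest for w in raw]


priorities = [1.0, 2.0, 4.0, 1.0]
probs = sampling_probabilities(priorities, alpha=0.7)
weights = importance_weights(probs, beta=0.5)
print([round(p, 3) for p in probs])
print([round(w, 3) for w in weights])
```

Items with a high priority are sampled more often and receive a smaller weight, which compensates for their over-representation in the loss.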
Replay buffers are also offered as wrappers around common datasets for offline RL:
```python
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
data = D4RLExperienceReplay(
    "maze2d-open-v0",
    split_trajs=True,
    batch_size=128,
    sampler=SamplerWithoutReplacement(drop_last=True),
)
for sample in data:  # or alternatively sample = data.sample()
    fun(sample)
```
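Iterating over a dataset with a sampler that draws without replacement amounts to one shuffled pass over the data. The following plain-Python sketch illustrates that behavior and the effect of `drop_last`. The generator `batches_without_replacement` is a hypothetical helper written for this example and works on indices only.

```python
import random


def batches_without_replacement(num_items, batch_size, drop_last, seed=0):
    """Yield index batches that visit each item at most once per pass."""
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)
    for start in range(0, num_items, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            return  # discard the final, incomplete batch
        yield batch


for batch in batches_without_replacement(num_items=10, batch_size=4, drop_last=True):
    print(batch)
```

With 10 items and a batch size of 4, `drop_last=True` yields two full batches and skips the remaining two items for that pass, which keeps every batch the same size.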
cross-library environment transforms(1), executed on device and in a vectorized fashion(2), which process and prepare the data coming out of the environments to be used by the agent:
```python
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
env = TransformedEnv(
    env_base,
    Compose(
        ToTensorImage(),
        ObservationNorm(loc=0.5, scale=1.0)),  # executes the transforms once and on device
)
tensordict = env.reset()
assert tensordict.device == torch.device("cuda:0")
```
Other transforms include: reward scaling (`RewardScaling`), shape operations (concatenation of tensors, unsqueezing etc.), concatenation of successive observations (`CatFrames`), resizing (`Resize`) and many more.
Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it easy to add and remove them at will:
```python
env.insert_transform(0, NoopResetEnv())  # inserts the NoopResetEnv transform at the index 0
```
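The difference between a flat list of transforms and nested wrappers can be sketched as follows. `TransformList` and the toy transforms are hypothetical and operate on plain floats. The sketch only illustrates why inserting or removing a step is cheap when the steps are stored in a list.

```python
class TransformList:
    """Hypothetical flat container applying transforms in order."""

    def __init__(self, *transforms):
        self.transforms = list(transforms)

    def insert(self, index, transform):
        self.transforms.insert(index, transform)

    def __call__(self, observation):
        for transform in self.transforms:
            observation = transform(observation)
        return observation


def scale_to_unit(x):
    return x / 255.0


def center(x):
    return x - 0.5


pipeline = TransformList(scale_to_unit, center)
print(pipeline(255.0))
pipeline.insert(0, lambda x: min(x, 128.0))  # added in place, nothing is re-wrapped
print(pipeline(255.0))
```

With nested wrappers, adding a step at the front would require unwrapping and rebuilding every outer layer. Here it is a single list insertion.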
Nevertheless, transforms can access and execute operations on the parent environment:
```python
transform = env.transform[1]  # gathers the second transform of the list
parent_env = transform.parent  # returns the base environment of the second transform, i.e. the base env + the first transform
```
various tools for distributed learning (e.g. memory mapped tensors)(2);
various architectures and models (e.g. actor-critic)(1):
```python
# create an nn.Module
common_module = ConvNet(
    bias_last_layer=True,
    depth=None,
    num_cells=[32, 64, 64],
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
)
# Wrap it in a SafeModule, indicating what key to read in and where to
# write out the output
common_module = SafeModule(
    common_module,
    in_keys=["pixels"],
    out_keys=["hidden"],
)
# Wrap the policy module in NormalParamsWrapper, such that the output
# tensor is split in loc and scale, and scale is mapped onto a positive space
policy_module = SafeModule(
    NormalParamsWrapper(
        MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
    ),
    in_keys=["hidden"],
    out_keys=["loc", "scale"],
)
# Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
# SafeProbabilisticModule, indicating how to build the
# torch.distributions.Distribution object and what to do with it
policy_module = SafeProbabilisticTensorDictSequential(  # stochastic policy
    policy_module,
    SafeProbabilisticModule(
        in_keys=["loc", "scale"],
        out_keys="action",
        distribution_class=TanhNormal,
    ),
)
value_module = MLP(
    num_cells=[64, 64],
    out_features=1,
    activation=nn.ELU,
)
# Wrap the policy and value function in a common module
actor_value = ActorValueOperator(common_module, policy_module, value_module)
# standalone policy from this
standalone_policy = actor_value.get_policy_operator()
```
exploration wrappers and modules to easily swap between exploration and exploitation(1):
```python
policy_explore = EGreedyWrapper(policy)
with set_exploration_type(ExplorationType.RANDOM):
    tensordict = policy_explore(tensordict)  # will use eps-greedy
with set_exploration_type(ExplorationType.DETERMINISTIC):
    tensordict = policy_explore(tensordict)  # will not use eps-greedy
```
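As a rough picture of what such a wrapper does, here is a plain-Python epsilon-greedy sketch. The class `EpsilonGreedy` and the toy `greedy_policy` are hypothetical, and an explicit `explore` flag stands in for the exploration-type context manager used by TorchRL.

```python
import random


class EpsilonGreedy:
    """Hypothetical wrapper: random action with probability eps while exploring."""

    def __init__(self, greedy_policy, num_actions, eps=0.1, seed=0):
        self.greedy_policy = greedy_policy
        self.num_actions = num_actions
        self.eps = eps
        self.rng = random.Random(seed)

    def __call__(self, observation, explore=True):
        if explore and self.rng.random() < self.eps:
            return self.rng.randrange(self.num_actions)
        return self.greedy_policy(observation)


def greedy_policy(observation):
    # Toy rule standing in for an argmax over Q-values.
    return 0 if observation < 0 else 1


policy_explore = EpsilonGreedy(greedy_policy, num_actions=4, eps=0.5)
exploring = [policy_explore(1.0, explore=True) for _ in range(10)]
exploiting = [policy_explore(1.0, explore=False) for _ in range(10)]
print(exploring)
print(exploiting)
```

The wrapped policy is left untouched, so the same object serves for data collection (exploration on) and for evaluation (exploration off).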
A series of efficient loss modules and highly vectorized functional return and advantage computation.
```python
from torchrl.objectives import DQNLoss
loss_module = DQNLoss(value_network=value_network, gamma=0.99)
tensordict = replay_buffer.sample(batch_size)
loss = loss_module(tensordict)
```
```python
from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done, terminated)
```
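The quantity estimated by a TD(lambda) return can be written as a short backward recursion. The sketch below is a plain-Python, single-trajectory version for illustration. The function `td_lambda_returns` is hypothetical, uses a single `dones` flag rather than separate `done` and `terminated` signals, and is not the vectorized routine shipped with the library.

```python
def td_lambda_returns(rewards, next_values, dones, gamma, lmbda):
    """Lambda-returns for one trajectory, computed backwards in time."""
    returns = [0.0] * len(rewards)
    # Beyond the last step we can only bootstrap from the value estimate.
    next_return = next_values[-1]
    for t in reversed(range(len(rewards))):
        if dones[t]:
            returns[t] = rewards[t]  # nothing to bootstrap from after a terminal step
        else:
            # Blend the one-step bootstrap with the longer lambda-return.
            blended = (1.0 - lmbda) * next_values[t] + lmbda * next_return
            returns[t] = rewards[t] + gamma * blended
        next_return = returns[t]
    return returns


returns = td_lambda_returns(
    rewards=[1.0, 1.0, 1.0],
    next_values=[0.5, 0.5, 0.5],
    dones=[False, False, True],
    gamma=0.99,
    lmbda=0.95,
)
print(returns)
```

Setting `lmbda=0` gives the one-step TD target `r + gamma * V(s')`, while `lmbda=1` gives the discounted sum of rewards up to the end of the trajectory.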
a generic trainer class(1) that executes the aforementioned training loop. Through a hooking mechanism, it also supports any logging or data transformation operation at any given time.
various recipes to build models that correspond to the environment being deployed.
LLM API: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends, conversation management with automatic chat template detection, tool integration (Python execution, function calling), specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning, and tool-augmented training scenarios.
```python
from tensordict import TensorDict
from torchrl.envs.llm import ChatEnv
from torchrl.modules.llm import TransformersWrapper
from torchrl.envs.llm.transforms import PythonInterpreter

# `tokenizer` and `model` are assumed to be defined beforehand.

# Create environment with tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You can execute Python code.",
    batch_size=[1]
).append_transform(PythonInterpreter())

# Wrap language model for training
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history"
)

# Multi-turn conversation with tool use
obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
llm_output = llm(obs)  # Generates response
obs = env.step(llm_output)  # Environment processes response
```
If you feel a feature is missing from the library, please submit an issue! If you would like to contribute to new features, check our call for contributions and our contribution page.
A series of state-of-the-art implementations are provided for illustrative purposes:
| Algorithm | Compile Support** | Tensordict-free API | Modular Losses | Continuous and Discrete |
|---|---|---|---|---|
| DQN | 1.9x | + | NA | + (through ActionDiscretizer transform) |
| DDPG | 1.87x | + | + | - (continuous only) |
| IQL | 3.22x | + | + | + |
| CQL | 2.68x | + | + | + |
| TD3 | 2.27x | + | + | - (continuous only) |
| TD3+BC | untested | + | + | - (continuous only) |
| A2C | 2.67x | + | - | + |
| PPO | 2.42x | + | - | + |
| SAC | 2.62x | + | - | + |
| REDQ | 2.28x | + | - | - (continuous only) |
| Dreamer v1 | untested | + | + (different classes) | - (continuous only) |
| Decision Transformers | untested | + | NA | - (continuous only) |
| CrossQ | untested | + | + | - (continuous only) |
| Gail | untested | + | NA | + |
| Impala | untested | + | - | + |
| IQL (MARL) | untested | + | + | + |
| DDPG (MARL) | untested | + | + | - (continuous only) |
| PPO (MARL) | untested | + | - | + |
| QMIX-VDN (MARL) | untested | + | NA | + |
| SAC (MARL) | untested | + | - | + |
| RLHF | NA | + | NA | NA |
| LLM API (GRPO) | NA | + | + | NA |
** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on architecture and device.
and many more to come!
Code examples displaying toy code snippets and training scripts are also available.
Check the examples directory for more details about handling the various configuration settings.
We also provide tutorials and demos that give a sense of what the library can do.
If you're using TorchRL, please refer to this BibTeX entry to cite this work:
```
@misc{bou2023torchrl,
  title={TorchRL: A data-driven decision-making library for PyTorch},
  author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
  year={2023},
  eprint={2306.00577},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```
```bash
python -m venv torchrl
source torchrl/bin/activate  # On Windows use: torchrl\Scripts\activate
```
Or create a conda environment where the packages will be installed.
```
conda create --name torchrl python=3.9
conda activate torchrl
```
Depending on the use of torchrl that you want to make, you may want to install the latest (nightly) PyTorch release or the latest stable version of PyTorch. See here for a detailed list of commands, including `pip3` or other special installation instructions.
TorchRL offers a few pre-defined dependencies such as `"torchrl[tests]"`, `"torchrl[atari]"` etc.
You can install the latest stable release by using
```bash
pip3 install torchrl
```
This should work on Linux (including AArch64 machines), Windows 10 and macOS (Metal chips only). On certain Windows machines (Windows 11), one should build the library locally. This can be done in two ways:
```bash
# Install and build locally v0.8.1 of the library without cloning
pip3 install git+https://github.com/pytorch/rl@v0.8.1
# Clone the library and build it locally
git clone https://github.com/pytorch/tensordict
git clone https://github.com/pytorch/rl
pip install -e tensordict
pip install -e rl
```
Note that the tensordict local build requires `cmake` to be installed via homebrew (macOS) or another package manager such as `apt`, `apt-get`, `conda` or `yum` but NOT `pip`, as well as `pip install "pybind11[global]"`.
One can also build the wheels to distribute to co-workers using
```bash
python setup.py bdist_wheel
```
Your wheels will be stored in `./dist/torchrl<name>.whl` and can be installed via
```bash
pip install torchrl<name>.whl
```
The nightly build can be installed via
```bash
pip3 install tensordict-nightly torchrl-nightly
```
which we currently only ship for Linux machines. Importantly, the nightly builds require the nightly builds of PyTorch too. Also, a local build of torchrl with the nightly build of tensordict may fail - install both nightlies or both local builds but do not mix them.
Disclaimer: As of today, TorchRL is roughly compatible with any pytorch version >= 2.1 and installing it will not directly require a newer version of pytorch to be installed. Indirectly though, tensordict still requires the latest PyTorch to be installed and we are working hard to loosen that requirement. The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above. Some features (e.g., working with nested jagged tensors) may also be limited with older versions of pytorch. It is recommended to use the latest TorchRL with the latest PyTorch version unless there is a strong reason not to do so.
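To check where an environment stands with respect to the thresholds quoted in this disclaimer (2.1 and 2.7.0), a small script along the following lines can help. It is only a sketch: `version_tuple` and `describe_support` are hypothetical helpers, the version parsing is deliberately naive, and the messages simply restate the disclaimer rather than probe the installed binaries.

```python
from importlib import metadata


def version_tuple(version):
    """Turn '2.7.0+cu126' or '2.8.0.dev20250601' into a comparable (major, minor) pair."""
    major, minor = version.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def describe_support(torch_version):
    """Map a PyTorch version string onto the cases described in the disclaimer."""
    parsed = version_tuple(torch_version)
    if parsed < (2, 1):
        return "below the minimum version mentioned above"
    if parsed < (2, 7):
        return "usable, but the C++ binaries are not expected to work"
    return "recent enough for the C++ binaries"


try:
    installed = metadata.version("torch")
    print(installed, "->", describe_support(installed))
except metadata.PackageNotFoundError:
    print("torch is not installed in this environment")
```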
Optional dependencies
The following libraries can be installed depending on the usage one wants to make of torchrl:
```
# diverse
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher

# rendering
pip3 install "moviepy<2.0.0"

# deepmind control suite
pip3 install dm_control

# gym, atari games
pip3 install "gym[atari]" "gym[accept-rom-license]" pygame

# tests
pip3 install pytest pyyaml pytest-instafail

# tensorboard
pip3 install tensorboard

# wandb
pip3 install wandb
```
Versioning issues can cause error messages such as `undefined symbol`. For these, refer to the versioning issues document for a complete explanation and proposed workarounds.
If you spot a bug in the library, please raise an issue in this repo.
If you have a more generic question regarding RL in PyTorch, post it on the PyTorch forum.
Internal collaborations to torchrl are welcome! Feel free to fork, submit issues and PRs. You can check out the detailed contribution guide here. As mentioned above, a list of open contributions can be found here.
Contributors are recommended to install pre-commit hooks (using `pre-commit install`). pre-commit will check for linting-related issues when the code is committed locally. You can disable the check by appending `-n` to your commit command: `git commit -m <commit message> -n`
This library is released as a PyTorch beta feature. BC-breaking changes are likely to happen but they will be introduced with a deprecation warning after a few release cycles.
TorchRL is licensed under the MIT License. See LICENSE for details.