functorch is a prototype of JAX-like composable function transforms for PyTorch.

functorch

This library is currently under heavy development. If you have suggestions on the API or use-cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.

functorch is a prototype of JAX-like composable FUNCtion transforms for pyTORCH.

It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance. Because this project requires some investment, we'd love to hear from and work with early adopters to shape the design. Please reach out on the issue tracker if you're interested in using this for your project.

In addition, there is experimental functionality to trace through these transformations using FX in order to capture the results of these transforms ahead of time. This would allow us to compile the results of vmap or grad to improve performance.

Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

  • computing per-sample-gradients (or other per-sample quantities)
  • running ensembles of models on a single machine
  • efficiently batching together tasks in the inner-loop of MAML
  • efficiently computing Jacobians and Hessians
  • efficiently computing batched Jacobians and Hessians

Composing vmap, grad, and vjp transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.

Install

There are two ways to install functorch:

  1. functorch main
  2. functorch preview with PyTorch 1.10

We recommend installing the functorch main development branch for the latest and greatest. This requires an installation of the latest PyTorch nightly.

If you're looking for an older version of functorch that works with a stable version of PyTorch (1.10), please install the functorch preview. More stable releases of functorch, paired with future versions of PyTorch, are on the roadmap.

Installing functorch main

Using Colab

Follow the instructions in this Colab notebook

Locally

First, set up an environment. We will be installing a nightly PyTorch binary as well as functorch. If you're using conda, create a conda environment:

conda create --name functorch
conda activate functorch

If you wish to use venv instead:

python -m venv functorch-env
source functorch-env/bin/activate

Next, install one of the following PyTorch nightly binaries.

# For CUDA 10.2
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html --upgrade
# For CUDA 11.1
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html --upgrade
# For CPU-only build
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade

The --upgrade flag in those commands upgrades an existing PyTorch nightly installation if you already have one (recommended!).

Install functorch:

pip install ninja  # Makes the build go faster
pip install --user "git+https://github.com/pytorch/functorch.git"

Run a quick sanity check in python:

>>> import torch
>>> from functorch import vmap
>>> x = torch.randn(3)
>>> y = vmap(torch.sin)(x)
>>> assert torch.allclose(y, x.sin())

From Source

functorch is a PyTorch C++ Extension module. To install,

  • Install PyTorch from source. functorch usually runs on the latest development version of PyTorch.
  • Run python setup.py install. You can use DEBUG=1 to compile in debug mode.

Then, try to run some tests to make sure all is OK:

pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v

Installing functorch preview with PyTorch 1.10

Using Colab

Follow the instructions here

Locally

Prerequisite: Install PyTorch 1.10

Next, run the following.

pip install ninja  # Makes the build go faster
pip install --user "git+https://github.com/pytorch/functorch.git@releases/torch_1.10_preview"

Finally, run a quick sanity check in python:

>>> import torch
>>> from functorch import vmap
>>> x = torch.randn(3)
>>> y = vmap(torch.sin)(x)
>>> assert torch.allclose(y, x.sin())

What are the transforms?

Right now, we support the following transforms:

  • grad, vjp, jacrev
  • vmap

Furthermore, we have some utilities for working with PyTorch modules.

  • make_functional(model)
  • make_functional_with_buffers(model)

vmap

Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.

vmap(func)(*inputs) is a transform that adds a dimension to all Tensor operations in func. vmap(func) returns a new function that maps func over some dimension (default: 0) of each Tensor in inputs.

vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func), leading to a simpler modeling experience:

>>> import torch
>>> from functorch import vmap
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = vmap(model)(examples)
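
As a further sketch of the mapped dimension, the snippet below uses the in_dims argument to map over dim 0 of examples while sharing weights across the batch, and checks the result against an explicit Python loop. The two-argument model is illustrative only, and the try/except import is an assumption about the environment: newer PyTorch releases expose the same transforms under torch.func.

```python
import torch

# Assumption: newer PyTorch ships the transforms as torch.func;
# older setups import them from functorch as in this README.
try:
    from torch.func import vmap
except ImportError:
    from functorch import vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)

def model(feature_vec, w):
    # Written for a single example: a 1-D feature vector.
    return feature_vec.dot(w).relu()

# Map over dim 0 of `examples`; `None` means `weights` is not mapped.
batched = vmap(model, in_dims=(0, None))(examples, weights)

# Reference result computed one example at a time.
looped = torch.stack([model(ex, weights) for ex in examples])

assert batched.shape == (batch_size,)
assert torch.allclose(batched, looped, atol=1e-6)
```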

grad

grad(func)(*inputs) assumes func returns a single-element Tensor. It computes the gradient of the output of func with respect to inputs[0].

>>> from functorch import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
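
By default the gradient is taken with respect to the first argument. A minimal sketch of selecting other arguments with the argnums parameter follows; the function f is purely illustrative, and the import fallback assumes a newer PyTorch where the transforms live in torch.func.

```python
import torch

# Assumption: fall back to functorch on older installs.
try:
    from torch.func import grad
except ImportError:
    from functorch import grad

def f(x, y):
    # Scalar-valued function of two scalar tensors.
    return torch.sin(x) * y

x = torch.randn([])
y = torch.randn([])

df_dx = grad(f)(x, y)                     # w.r.t. x (argument 0, the default)
df_dy = grad(f, argnums=1)(x, y)          # w.r.t. y
df_both = grad(f, argnums=(0, 1))(x, y)   # tuple of both gradients

assert torch.allclose(df_dx, torch.cos(x) * y)
assert torch.allclose(df_dy, torch.sin(x))
assert torch.allclose(df_both[0], df_dx) and torch.allclose(df_both[1], df_dy)
```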

When composed with vmap, grad can be used to compute per-sample-gradients:

>>> from functorch import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
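
To make the meaning of "per-sample-gradients" concrete, here is a sketch that compares the vmap(grad(...)) result with a plain loop that differentiates one example at a time. It reuses the model and loss from above; the import fallback is an assumption for newer PyTorch versions.

```python
import torch

# Assumption: fall back to functorch on older installs.
try:
    from torch.func import grad, vmap
except ImportError:
    from functorch import grad, vmap

batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# One gradient row per example, computed in a single batched call.
per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: differentiate each example separately and stack the rows.
looped = torch.stack([
    grad(compute_loss)(weights, examples[i], targets[i])
    for i in range(batch_size)
])

assert per_sample.shape == (batch_size, feature_size)
assert torch.allclose(per_sample, looped, atol=1e-6)
```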

vjp and jacrev

The vjp transform applies func to inputs and returns a new function that computes vjps (vector-Jacobian products) given some cotangent Tensors.

>>> from functorch import vjp
>>> outputs, vjp_fn = vjp(func, *inputs); vjps = vjp_fn(*cotangents)
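
A concrete sketch of that pattern, using torch.sin as the function and a cotangent of ones (both chosen only for illustration). With an elementwise function, the vjp against a ones cotangent is simply the elementwise derivative. The import fallback assumes a newer PyTorch with torch.func.

```python
import torch

# Assumption: fall back to functorch on older installs.
try:
    from torch.func import vjp
except ImportError:
    from functorch import vjp

x = torch.randn(5)

# vjp runs the function once and returns its output plus a pullback function.
outputs, vjp_fn = vjp(torch.sin, x)

# The cotangent has the same shape as the output.
cotangent = torch.ones_like(outputs)

# The pullback returns one result per input, as a tuple.
vjp_x = vjp_fn(cotangent)[0]

assert torch.allclose(outputs, x.sin())
assert torch.allclose(vjp_x, x.cos())
```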

The jacrev transform returns a new function that takes in x and returns the Jacobian of a function with respect to x. For example, with torch.sin:

>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)

Use jacrev to compute the Jacobian. This can be composed with vmap to produce batched Jacobians:

>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
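
Beyond the shape, the values can be checked too. In this sketch each batch entry should be a diagonal matrix holding cos of that row, built here with torch.diag_embed. The import fallback assumes a newer PyTorch with torch.func.

```python
import torch

# Assumption: fall back to functorch on older installs.
try:
    from torch.func import jacrev, vmap
except ImportError:
    from functorch import jacrev, vmap

x = torch.randn(64, 5)

# One 5x5 Jacobian per row of x.
jacobian = vmap(jacrev(torch.sin))(x)

# sin is elementwise, so each Jacobian is diag(cos(x[i])).
expected = torch.diag_embed(torch.cos(x))

assert jacobian.shape == (64, 5, 5)
assert torch.allclose(jacobian, expected, atol=1e-6)
```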

jacrev can be composed with itself to produce Hessians:

>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
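
For this particular f the Hessian has a closed form, diag(-sin(x)), which gives a quick correctness check. The sketch also tries jacfwd(jacrev(f)), a forward-over-reverse composition that should give the same matrix. The import fallback assumes a newer PyTorch with torch.func.

```python
import torch

# Assumption: fall back to functorch on older installs.
try:
    from torch.func import jacfwd, jacrev
except ImportError:
    from functorch import jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)

# Reverse-over-reverse, as in the README snippet above.
hessian_rr = jacrev(jacrev(f))(x)

# Forward-over-reverse, an alternative composition.
hessian_fr = jacfwd(jacrev(f))(x)

# d/dx_i d/dx_j sum(sin(x)) is -sin(x_i) on the diagonal and 0 elsewhere.
expected = torch.diag(-x.sin())

assert torch.allclose(hessian_rr, expected, atol=1e-6)
assert torch.allclose(hessian_fr, expected, atol=1e-6)
```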

Tracing through the transformations

We can also trace through these transformations in order to capture the results as new code using make_fx. There is also experimental integration with the NNC compiler (only works on CPU for now!).

>>> import torch
>>> from functorch import make_fx, grad
>>> def f(x):
>>>     return torch.sin(x).sum()
>>> x = torch.randn(100)
>>> grad_f = make_fx(grad(f))(x)
>>> print(grad_f.code)

def forward(self, x_1):
    sin = torch.ops.aten.sin(x_1)
    sum_1 = torch.ops.aten.sum(sin, None);  sin = None
    cos = torch.ops.aten.cos(x_1);  x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul = torch.ops.aten.mul(_tensor_constant0, cos);  _tensor_constant0 = cos = None
    return mul

Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters and/or buffers of an nn.Module. This can happen for example in:

  • model ensembling, where all of your weights and buffers have an additional dimension
  • per-sample-gradient computation where you want to compute per-sample-grads of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a stateless version of it that can be called like a function.

  • make_functional(model) returns a functional version of model and the model.parameters()
  • make_functional_with_buffers(model) returns a functional version of model and the model.parameters() and model.buffers().

Here's an example where we compute per-sample-gradients using an nn.Linear layer:

import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)

If you're making an ensemble of models, you may find combine_state_for_ensemble useful.

Documentation

For more documentation, see our docs website.

Debugging

  • functorch._C.dump_tensor: Dumps dispatch keys on stack.
  • functorch._C._set_vmap_fallback_warning_enabled(False): Use this if the vmap warning spam bothers you.

Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the design details. To figure out the details, we need your help -- please send us your use cases by starting a conversation in the issue tracker or try out the prototype.

License

Functorch has a BSD-style license, as found in the LICENSE file.

Citing functorch

If you use functorch in your publication, please cite it by using the following BibTeX entry.

@Misc{functorch2021,
  author =       {Horace He, Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year =         {2021}
}
Comments
  • ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv

    Hi All,

    I was running an older version of PyTorch (built from source) with FuncTorch (built from source), and somehow I've broken the older version of functorch. When I import functorch I get the following error:

    import functorch
    #returns ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv
    

    The version I had of functorch was 0.2.0a0+9d6ee76. Is there a way to re-install it to fix this ImportError? I do have the latest version of PyTorch/FuncTorch in a separate conda environment, but I wanted to check how it compares to the older version in this 'older' conda environment, where PyTorch and FuncTorch were versions 1.12.0a0+git7c2103a and 0.2.0a0+9d6ee76 respectively.

    Is there a way to download a specific version of functorch with https://github.com/pytorch/functorch.git ? Or another way to fix this issue?

  • Hessian (w.r.t inputs) calculation in PyTorch differs from FuncTorch

    Hi All,

    I've been trying to calculate the Hessian of the output of my network with respect to its inputs within FuncTorch. I had a version within PyTorch that supports batches, however, they seem to disagree with each other and I have no idea why they don't give the same results. Something is clearly wrong, I know my PyTorch version is right so either there's an issue in my version of FuncTorch or I've implemented it wrong in FuncTorch.

    Also, how can I use the has_aux flag in jacrev to return the jacobian from the first jacrev so I don't have to repeat the jacobian calculation?

    The only problem with my example is that it uses torch.linalg.slogdet and from what I remember FuncTorch can't vmap over .item(). I do have my own fork of pytorch where I edited the backward to remove the .item() call so it works with vmap. Although, it's not the greatest implementation as I just set it to the default nonsingular_case_backward like so,

    Tensor slogdet_backward(const Tensor& grad_logabsdet,
                            const Tensor& self,
                            const Tensor& signdet, const Tensor& logabsdet) {
      auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
        Tensor u, sigma, vh;
        std::tie(u, sigma, vh) = at::linalg_svd(self, false);
        Tensor v = vh.mH();
        // sigma has all non-negative entries (also with at least one zero entry)
        // so logabsdet = \sum log(abs(sigma))
        // but det = 0, so backward logabsdet = \sum log(sigma)
        auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma);
        return svd_backward({}, gsigma, {}, u, sigma, vh);
      };
    
      auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
        // TODO: replace self.inverse with linalg_inverse
        return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().mH();
      };
    
      auto nonsingular = nonsingular_case_backward(grad_logabsdet, self);
      return nonsingular;
    }
    

    My 'minimal' reproducible script is below with the output shown below that. It computes the Laplacian via a PyTorch method and via FuncTorch for a single sample of size [A,1] where A is the number of input nodes to the network.

    import torch
    import torch.nn as nn
    from torch import Tensor
    import functorch
    from functorch import jacrev, jacfwd, hessian, make_functional, vmap
    import time 
    
    _ = torch.manual_seed(0)
    
    print("PyTorch version:   ", torch.__version__)
    print("CUDA version:      ", torch.version.cuda)
    print("FuncTorch version: ", functorch.__version__)
    
    def sync_time() -> float:
      torch.cuda.synchronize()
      return time.perf_counter()
    
    B=1 #batch
    A=3 #input nodes
    
    device=torch.device("cuda")
    
    class model(nn.Module):
    
      def __init__(self, num_inputs, num_hidden):
        super(model, self).__init__()
        
        self.num_inputs=num_inputs
        self.func = nn.Tanh()
        
        self.fc1 = nn.Linear(2, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_inputs)
      
      def forward(self, x):
        """
        Takes x in [B,A,1] and maps it to sign/logabsdet value in Tuple([B,], [B,])
        """
        
        idx=len(x.shape)
        rep=[1 for _ in range(idx)]
        rep[-2] = self.num_inputs
        g = x.mean(dim=(idx-2), keepdim=True).repeat(*rep)
        f = torch.cat((x,g), dim=-1)
    
        h = self.func(self.fc1(f))
        
        mat = self.fc2(h)
        sgn, logabs = torch.linalg.slogdet(mat)
        return sgn, logabs
    
    net = model(A, 64)
    net = net.to(device)
    
    fnet, params = make_functional(net)
    
    def logabs(params, x):
      _, logabs = fnet(params, x)
      #print("functorch logabs: ",logabs)
      return logabs
    
    
    def kinetic_pytorch(xs: Tensor) -> Tensor:
      """Method to calculate the local kinetic energy values of a netork function, f, for samples, x.
      The values calculated here are 1/f d2f/dx2 which is equivalent to d2log(|f|)/dx2 + (dlog(|f|)/dx)^2
      within the log-domain (rather than the linear-domain).
    
      :param xs: The input positions of the many-body particles
      :type xs: class: `torch.Tensor`
      """
      xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
      xs_flat = torch.stack(xis, dim=1)
    
      _, ys = net(xs_flat.view_as(xs))
      #print("pytorch logabs: ",ys)
      ones = torch.ones_like(ys)
    
      #df_dx calculation
      (dy_dxs, ) = torch.autograd.grad(ys, xs_flat, ones, retain_graph=True, create_graph=True)
    
    
      #d2f_dx2 calculation (diagonal only)
      lay_ys = sum(torch.autograd.grad(dy_dxi, xi, ones, retain_graph=True, create_graph=False)[0] \
                    for xi, dy_dxi in zip(xis, (dy_dxs[..., i] for i in range(len(xis))))
      )
      #print("(PyTorch): ",lay_ys, dy_dxs)
      
      ek_local_per_walker = -0.5 * (lay_ys + dy_dxs.pow(2).sum(-1)) #move const out of loop?
      return ek_local_per_walker
      
    jacjaclogabs = jacrev(jacrev(logabs, argnums=1), argnums=1)
    jaclogabs = jacrev(logabs, argnums=1)
      
    def kinetic_functorch(params, x):
      d2f_dx2 = vmap(jacjaclogabs, in_dims=(None, 0))(params, x)
      df_dx = vmap(jaclogabs, in_dims=(None, 0))(params, x)
      #print("(FuncTorch): ", d2f_dx2.squeeze(-3).squeeze(-1).diagonal(-2,-1).sum(-1), df_dx)
      #remove the trailing 1's so it's an A by A matrix 
      return -0.5 * d2f_dx2.squeeze(-3).squeeze(-1).diagonal(-2,-1).sum(-1) + df_dx.squeeze(-1).pow(2).sum(-1)
    
    x = torch.randn(B,A,1,device=device) #input Tensor 
    
    print("\nd2f/dx2, df/dx: ")
    t1=sync_time()
    kin_pt = kinetic_pytorch(x)
    t2=sync_time()
    t3=sync_time()
    kin_ft = kinetic_functorch(params, x)
    t4=sync_time()
    
    print("\nWalltime: ")
    print("PyTorch:   ",t2-t1)
    print("FuncTorch: ",t4-t3, "\n")
    
    print("Results: ")
    print("PyTorch: ",kin_pt)
    print("FuncTorch: ",kin_ft)
    

    This script returns

    PyTorch version:    1.12.0a0+git7c2103a
    CUDA version:       11.6
    FuncTorch version:  0.2.0a0+9d6ee76
    
    d2f/dx2, df/dx: 
    
    Walltime: 
    PyTorch:    0.4822753759999614
    FuncTorch:  0.004898710998531897 
    
    Results: 
    PyTorch:  tensor([1.3737], device='cuda:0', grad_fn=<MulBackward0>)    # should be the same values
    FuncTorch:  tensor([7.8411], device='cuda:0', grad_fn=<AddBackward0>) # the jacobian matches, but hessian does not
    

    Thanks for the help in advance! :)
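
One thing that may be worth checking in the script above (an observation about the posted code, not a confirmed diagnosis): kinetic_pytorch computes -0.5 * (lay_ys + dy_dxs.pow(2).sum(-1)), while the return line of kinetic_functorch has no parentheses around the sum, so -0.5 multiplies only the Laplacian term. The sketch below uses made-up scalar values (laplacian and grad_sq are hypothetical stand-ins) to show how the two groupings differ:

```python
# Hypothetical scalar stand-ins for the Laplacian term and the
# squared-gradient term; the values are made up for illustration.
laplacian = 2.0
grad_sq = 3.0

# Grouping used in kinetic_pytorch: the factor applies to the whole sum.
with_parens = -0.5 * (laplacian + grad_sq)

# Grouping in the kinetic_functorch return line: the factor applies
# to the first term only.
without_parens = -0.5 * laplacian + grad_sq

print(with_parens)     # -2.5
print(without_parens)  # 2.0
```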

  • add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL

    Companion core PR: https://github.com/pytorch/pytorch/pull/77716

    The above PR makes block_diag composite compliant, and this PR adds a batching rule for it.

    Those two changes together should let us fully remove the DECOMPOSE_FUNCTIONAL macro, which was preventing me from moving the Functionalize dispatch key below FuncTorchBatched (which I want to do as part of XX, in order to properly get functionalization working with LTC/XLA).

  • svd-related op regression in functorch

    https://github.com/pytorch/pytorch/pull/69827 and https://github.com/pytorch/pytorch/pull/70253 caused svd-related tests in functorch to fail:

    • https://app.circleci.com/pipelines/github/pytorch/functorch/1277/workflows/5aaf2c43-6c6a-4ab1-94f7-e0493b8049ff/jobs/7659

    The main problem seems to be that the backward pass uses in-place operations that are incompatible with vmap (aka Composite Compliance problems). There are some other failures that seem to be because some other operations are not Composite Compliant but somehow these weren't a problem previously.

  • functorch doesn't work in debug mode

    It's that autograd assert that we run into often:

    import torch
    from functorch import make_fx
    from functorch.compile import nnc_jit
    
    
    def f(x, y):
        return torch.broadcast_tensors(x, y)
    
    
    inp1 = torch.rand(())
    inp2 = torch.rand(3)
    
    print(f(inp1, inp2))  # without nnc compile everything works fine
    
    print(make_fx(f)(inp1, inp2))  # fails
    print(nnc_jit(f)(inp1, inp2))
    # RuntimeError: self__storage_saved.value().is_alias_of(result.storage())INTERNAL ASSERT FAILED at "autograd/generated/VariableType_3.cpp":3899, please report a bug to PyTorch.
    

    cc @albanD @soulitzer what's the chance we can add an option to turn these off? They've been more harmful (e.g. prevent debugging in debug mode) than useful for us.

  • Index put vmap internal assert

    import torch
    from functorch import vmap
    self = torch.randn(4, 1, 1).cuda()
    idx = (torch.tensor([0]).cuda(),)
    value = torch.randn(1, 1).cuda()
    
    def foo(x):
        return x.index_put_(idx, value, accumulate=True)
    
    vmap(foo)(self)
    
    RuntimeError: linearIndex.numel()*sliceSize*nElemBefore == value.numel()INTERNAL ASSERT FAILED at "/raid/rzou/pt/debug-cuda/aten/src/ATen/native/cuda/Indexing.cu":249, please report a bug to PyTorch. number of flattened indices did not match number of elements in the value tensor41
    
  • Batching rule not implemented for aten::item.

    Hey, I would like to use functorch.vmap in a custom PyTorch activation function (the gradients are not needed, because the backward-pass is calculated differently). During the computation of the activation function, I do a lookup in a tensor X using a tensor Y.item() call, similar to the small dummy code below.

    Unfortunately I get the error message: RuntimeError: Batching rule not implemented for aten::item. We could not generate a fallback.

    Is it not possible to do an item() call in a vmap function or is something else wrong? Thanks a lot!

    import torch
    from functorch import vmap
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    sum = torch.zeros([100, 10], dtype=torch.int32).to(device)
    lookup = torch.randint(100, (20, 1000, 10)).to(device)
    input_tensor = torch.randint(1000, (100, 20)).to(device)
    
    def test_fun(sum, input_tensor):
      for j in range(20):
        for i in range(10):
          sum[i] += lookup[j, input_tensor[j].item(), i]
      return sum
    
    # non-vectorized version
    for i in range(100):
      test_fun(sum[i], input_tensor[i])
    
    # vectorized version throws error
    test_fun_vec = vmap(test_fun)
    test_fun_vec(sum, input_tensor)
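
A possible way around the .item() call, sketched under the assumption that the lookup can be expressed with tensor indexing: index lookup with a tensor of row numbers and the index tensor itself, then sum over the lookup dimension. The names loop_version and gather_version are hypothetical, the example runs on CPU with int64 accumulators for simplicity, and the import fallback assumes a newer PyTorch with torch.func.

```python
import torch

# Assumption: fall back to functorch on older installs.
try:
    from torch.func import vmap
except ImportError:
    from functorch import vmap

lookup = torch.randint(100, (20, 1000, 10))
input_tensor = torch.randint(1000, (100, 20))

def loop_version(idx_row):
    # Reference for one sample: 20 lookups via .item(), accumulated.
    total = torch.zeros(10, dtype=torch.int64)
    for j in range(20):
        total += lookup[j, idx_row[j].item()]
    return total

def gather_version(idx_row):
    # Same result for one sample without .item():
    # lookup[rows, idx_row] has shape (20, 10); sum over the 20 lookups.
    rows = torch.arange(20)
    return lookup[rows, idx_row].sum(dim=0)

# Non-vectorized reference over all 100 samples.
expected = torch.stack([loop_version(row) for row in input_tensor])

# Vectorized over the batch dimension with vmap.
result = vmap(gather_version)(input_tensor)

assert result.shape == (100, 10)
assert torch.equal(result, expected)
```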
    
  • torch.atleast_1d batching rule implementation

    Hi functorch devs! I'm filing this issue because my code prints the following warning:

    UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::atleast_1d. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at  /tmp/pip-req-build-ytawxmfk/functorch/csrc/BatchedFallback.cpp:106.)
    

    Why Am I Using atleast_1d ?

    I'm subclassing torch.Tensor because my code needs to attach some extra data, named _block_variable, to that class. (I'm integrating PyTorch's AD system with another AD system to be able to call torch functions from inside a PDE solve, which is why I also inherit from a class called OverloadedType.) The subclass looks like:

    class MyTensor(torch.Tensor, OverloadedType):
        _block_variable = None
    
        @staticmethod
        def __new__(cls, x, *args, **kwargs):
            return super().__new__(cls, x, *args, **kwargs)
    
        def __init__(self, x, block_var=None):
            super(OverloadedType, self).__init__()
            self._block_variable = block_var or BlockVariable(self)
            
    
        def to(self, *args, **kwargs):
            new = Tensor([])
            tmp = super(torch.Tensor, self).to(*args, **kwargs)
            new.data = tmp.data
            new.requires_grad = tmp.requires_grad
            new._block_variable = self._block_variable
            return new
    
         ... #some subclass-specific methods etc
    

    This causes problems when I have code that does stuff like torch.tensor([torch.trace(x), torch.trace(x @ x)]) where x is a square MyTensor; the torch.tensor() call raises an exception related to taking the __len__ of a 0-dimensional tensor (the scalar traces). So instead, I do torch.cat([torch.atleast_1d(torch.trace(x)), torch.atleast_1d(torch.trace(x @ x))]), which works. However, this function is functorch.vmap-ed, which triggers the performance warning. It would be great if I could either get the naive implementation (using torch.tensor instead of torch.cat) to work, or if a batching rule for atleast_1d() were to be implemented.
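
One alternative that might sidestep atleast_1d entirely, sketched here with plain tensors: torch.stack accepts 0-dimensional tensors and adds the new dimension itself. Whether this behaves the same with the MyTensor subclass, and whether it avoids the fallback warning in a given functorch version, is untested here. The function name invariants is hypothetical, and the import fallback assumes a newer PyTorch with torch.func.

```python
import torch

# Assumption: fall back to functorch on older installs.
try:
    from torch.func import vmap
except ImportError:
    from functorch import vmap

def invariants(x):
    # torch.stack turns the two 0-dim traces into a 1-D tensor of length 2,
    # so neither torch.tensor([...]) nor atleast_1d is needed.
    return torch.stack([torch.trace(x), torch.trace(x @ x)])

xs = torch.randn(4, 3, 3)

# Batched call over the leading dimension.
batched = vmap(invariants)(xs)

# Reference computed one matrix at a time.
looped = torch.stack([invariants(x) for x in xs])

assert batched.shape == (4, 2)
assert torch.allclose(batched, looped, atol=1e-5)
```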

    Thank you for any help you can provide!

  • Top 25 OpInfos for functorch

    We'd love help on these.

    The check box indicates whether the OpInfo has been added to PyTorch core. The ultimate goal is for all of these OpInfos to exist in PyTorch core. The OpInfo is bolded if we have a poor man's version* of the OpInfo in the functorch repo (see https://github.com/facebookresearch/functorch/blob/main/test/functorch_additional_op_db.py).

    Exists

    • [x] torch.nn.functional.softmax (https://github.com/pytorch/pytorch/pull/62077)
    • [x] torch.nn.functional.relu (https://github.com/pytorch/pytorch/pull/62076)
    • [x] torch.nn.functional.interpolate (https://github.com/pytorch/pytorch/pull/61956)
    • [x] torch.nn.functional.pad (https://github.com/pytorch/pytorch/pull/62814)
    • [x] torch.nn.functional.normalize (https://github.com/pytorch/pytorch/pull/62635)
    • [x] torch.nn.functional.cross_entropy (https://github.com/pytorch/pytorch/pull/63547)
    • [x] torch.nn.functional.grid_sample (https://github.com/pytorch/pytorch/pull/62311)
    • [x] torch.nn.functional.one_hot (https://github.com/pytorch/pytorch/pull/62253)
    • [x] torch.nn.functional.mse_loss
    • [x] torch.nn.functional.conv2d (https://github.com/pytorch/pytorch/pull/63517)
    • [x] torch.nn.functional.dropout (https://github.com/pytorch/pytorch/pull/62315)
    • [x] torch.nn.functional.softplus (https://github.com/pytorch/pytorch/pull/62317)
    • [x] torch.nn.functional.linear (https://github.com/pytorch/pytorch/pull/61971)
    • [x] torch.nn.functional.avg_pool2d (https://github.com/pytorch/pytorch/pull/62455)
    • [x] torch.nn.functional.max_pool2d (https://github.com/pytorch/pytorch/pull/63530)
    • [x] torch.nn.functional.nll_loss (https://github.com/pytorch/pytorch/pull/64203)
    • [x] torch.nn.functional.embedding (https://github.com/pytorch/pytorch/pull/63633)
    • [x] torch.nn.functional.adaptive_avg_pool2d (https://github.com/pytorch/pytorch/pull/62704)
    • [x] torch.nn.functional.cosine_similarity (https://github.com/pytorch/pytorch/pull/62959)
    • [x] torch.nn.functional.unfold (https://github.com/pytorch/pytorch/pull/62705)
    • [x] torch.nn.functional.batch_norm (https://github.com/pytorch/pytorch/pull/63218)
    • [x] torch.nn.functional.conv_transpose2d (https://github.com/pytorch/pytorch/pull/62882)
    • [x] torch.nn.functional.layer_norm (https://github.com/pytorch/pytorch/pull/63276)

    *Why do we have poor man's versions of these OpInfos? It's because right now we only care about float32 sample inputs on CPU and CUDA, and OpInfos have a lot of flags that take some time to tweak.

  • Memory Leak

    Hello! I am thrilled with the functorch package, and have been playing with it lately.

    With @soumik12345 we found a memory leak after training a NN. We documented our findings here:

    http://wandb.me/functorch-intro

    We are probably doing something wrong, but the memory increases after each epoch.

    As the GPU is pretty monstrous we didn't notice this straight away, but it clearly fills up progressively. The stateful PyTorch training loop does not produce this.

  • RuntimeError: CUDA error: no kernel image is available for execution on the device

    Hi, I have CUDA 11.7 on my system and I am trying to install functorch. Since a stable version of PyTorch for CUDA 11.7 is not available here, I just ran pip install functorch, which also installs a compatible version of PyTorch.

    But when I run my code that uses the GPU, I get the following error :

    RuntimeError: CUDA error: no kernel image is available for execution on the device

    Is it possible to use functorch in my case?

  • Speed up functorch tests in CI

    Here's roughly how the vmap testing works (pseudocode):

    all_permutations = create_all_batched_inputs()
    for batched_inputs in all_permutations:
      result = vmap(op)(batched_inputs)
      expected = for_loop_over_op(op, batched_inputs)
      assert torch.allclose(result, expected)
    

    There are two things that could be optimized:

    1. for_loop_over_op(op, batched_inputs). Instead of running a for-loop over the op with slices of batched_inputs, we can just run the op on the original input and expand the result.
    2. expected is the SAME across all iterations of the for loop in the example above. So we only need to compute it once.

    The end-state should look something like:

    all_permutations = create_all_batched_inputs()
    expected = expand(op(inputs))
    for batched_inputs in all_permutations:
      result = vmap(op)(batched_inputs)
      assert torch.allclose(result, expected)
    

    This will probably save us something like 50% of runtime on our vmap tests.
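
A runnable sketch of the proposed end-state, under the assumption that every batched input is the same original tensor expanded along a new batch dimension. Here create_all_batched_inputs is a hypothetical stand-in for the real helper, op is just torch.sin, and the import fallback assumes a newer PyTorch with torch.func.

```python
import torch

# Assumption: fall back to functorch on older installs.
try:
    from torch.func import vmap
except ImportError:
    from functorch import vmap

def op(x):
    return torch.sin(x)

def create_all_batched_inputs(x, batch_size):
    # Hypothetical helper: place an expanded batch dim at every position.
    for bdim in range(x.dim() + 1):
        shape = (*x.shape[:bdim], batch_size, *x.shape[bdim:])
        yield bdim, x.unsqueeze(bdim).expand(shape)

batch_size = 3
inputs = torch.randn(2, 5)

# Computed once: run the op on the original input and expand the result.
expected = op(inputs).unsqueeze(0).expand(batch_size, *inputs.shape)

checked = 0
for bdim, batched_inputs in create_all_batched_inputs(inputs, batch_size):
    # The output batch dim defaults to 0, matching `expected`.
    result = vmap(op, in_dims=bdim)(batched_inputs)
    assert torch.allclose(result, expected)
    checked += 1
```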

  • functorch doesn't work with saved variable hooks

    I'm not sure why, cc @soulitzer @albanD.

    Repro:

    import torch
    from torch.autograd.graph import save_on_cpu
    from functorch import grad
    
    x = torch.randn([], device='cuda', requires_grad=True)
    
    def f(x):
        return x.sin().sin()
    
    with save_on_cpu():
        y = f(x)
        gx, = torch.autograd.grad(y, x, create_graph=True)
    assert gx.requires_grad
    
    with save_on_cpu():
        gx = grad(f)(x)
    
    # Fails
    assert gx.requires_grad
    
  • Printing unwrapped tensors doesn't work under grad-type transform

    Repro

    import torch
    from functorch import vjp

    x = torch.randn(3)
    def f(y):
      print(x)
      return y
    vjp(f, torch.randn(4)) # or equivalent with grad or jvp
    

    Gets error: RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

    Full Stacktrace
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/Users/samdow/Documents/jacfwd_fix/functorch/functorch/_src/eager_transforms.py", line 270, in vjp
        primals_out = func(*diff_primals)
      File "<stdin>", line 2, in f
      File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor.py", line 423, in __repr__
        return torch._tensor_str._str(self, tensor_contents=tensor_contents)
      File "/Users/samdow/Documents/jacfwd_fix/functorch/functorch/_src/monkey_patching.py", line 23, in _functorch_str
        return _old_str(tensor)
      File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 594, in _str
        return _str_intern(self, tensor_contents=tensor_contents)
      File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 557, in _str_intern
        tensor_str = _tensor_str(self, indent)
      File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 320, in _tensor_str
        return _tensor_str_with_formatter(self, indent, summarize, formatter)
      File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 249, in _tensor_str_with_formatter
        return _vector_str(self, indent, summarize, formatter1, formatter2)
      File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 230, in _vector_str
        data = [_val_formatter(val) for val in self.tolist()]
    RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
    

    What we think is happening

    functorch monkey-patches tensor printing to make it work under transforms. We think that, because the transform is still active while printing, the functions that the normal printing path calls return tensor wrappers that have no storage.

    Potential Solutions

    One option is to turn off all transforms while printing (or only when printing an unwrapped tensor).

  • Change vmap to work around "vmap-incompatible in-place errors"

    e.g. https://pytorch.org/functorch/stable/ux_limitations.html#mutation-in-place-pytorch-operations

    There are a couple of options here:

    • Always expand_copy factory function calls. This comes with its own tradeoffs: a program that works today would stop working, and there would be performance differences (if the expanded tensor didn't need to be expanded, that's a perf hit).
    • Have torch.zeros return something lazy that later determines whether it needs to be expanded. This needs design, and I'm not sure it can avoid all the edge cases.
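
    For context, the limitation linked above can be sketched as follows. This is a minimal illustration based on the linked UX-limitations page, not a proposed fix: the function names `diag_embed_zeros` and `diag_embed_new_zeros` are made up for this example, and the import falls back to `functorch` on the assumption that `torch.func` may not be available in older installs.

```python
import torch

try:
    from torch.func import vmap  # location in newer PyTorch releases
except ImportError:
    from functorch import vmap   # standalone functorch package


def diag_embed_zeros(vec):
    # torch.zeros creates a tensor that vmap does not treat as batched,
    # so the in-place copy_ would have to write batched data into it.
    result = torch.zeros(vec.shape[0], vec.shape[0])
    result.diagonal().copy_(vec)
    return result


def diag_embed_new_zeros(vec):
    # vec.new_zeros derives the result from the batched input,
    # so vmap can track the batch dimension on it.
    result = vec.new_zeros(vec.shape[0], vec.shape[0])
    result.diagonal().copy_(vec)
    return result


vecs = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])

# The torch.zeros variant is expected to hit the in-place limitation.
try:
    vmap(diag_embed_zeros)(vecs)
    zeros_failed = False
except RuntimeError as err:
    zeros_failed = True
    print("torch.zeros variant failed:", type(err).__name__)

# The new_zeros variant is the workaround the linked page suggests.
batched = vmap(diag_embed_new_zeros)(vecs)
print(batched.shape)
```

    Both options listed above aim to make the `torch.zeros` variant behave like the `new_zeros` variant without the user having to rewrite the function.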
  • Get .item() error without calling .item()

    Hello, I'm new to this package and I want to compute a batched Jacobian of a self-implemented vector function, but I get the following error when I do so.

    RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

    Here is my code. I don't understand where the .item() call comes from. Is the slicing operation q_current[0:3] wrong? How can I fix this?

    import torch
    from functorch import jacrev,vmap
    
    #batch * len
    q_current = torch.randn((4,4*3-1),requires_grad=True)
    
    
    def geoCompute(q_current):
        k1 = q_current[0:3]
        return k1
    
    
    jacobian = vmap(jacrev(geoCompute))(q_current)
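
    As a point of comparison, the result the snippet above is after can be computed one sample at a time with `torch.autograd.functional.jacobian`. This is only a reference sketch for checking shapes and values; it does not diagnose the `.item()` error, and the names `geo_compute`, `q_batch`, and `reference` are made up for this example. The input is created without `requires_grad=True`, which the functional Jacobian API does not need.

```python
import torch
from torch.autograd.functional import jacobian


def geo_compute(q):
    # Same slicing as in the question: keep the first three entries.
    return q[0:3]


# batch * len, matching the shape used in the question
q_batch = torch.randn(4, 4 * 3 - 1)

# Compute one (3, 11) Jacobian per sample, then stack along the batch dim.
per_sample = [jacobian(geo_compute, q) for q in q_batch]
reference = torch.stack(per_sample)

print(reference.shape)
```

    Because the function only selects the first three inputs, each per-sample Jacobian is a 3x3 identity followed by zero columns, so a working `vmap(jacrev(...))` call should produce the same values.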
    