more cleanup

This commit is contained in:
patil-suraj 2022-06-30 12:01:27 +02:00
parent 8830af1168
commit b897008122
1 changed files with 0 additions and 14 deletions

View File

@ -909,17 +909,3 @@ def _setup_kernel(k):
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
def contract_inner(x, y):
"""tensordot(x, y, 1)."""
x_chars = list(string.ascii_lowercase[: len(x.shape)])
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
out_chars = x_chars[:-1] + y_chars[1:]
return _einsum(x_chars, y_chars, out_chars, x, y)
def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)