more cleanup
This commit is contained in:
parent
8830af1168
commit
b897008122
|
@ -909,17 +909,3 @@ def _setup_kernel(k):
|
||||||
assert k.ndim == 2
|
assert k.ndim == 2
|
||||||
assert k.shape[0] == k.shape[1]
|
assert k.shape[0] == k.shape[1]
|
||||||
return k
|
return k
|
||||||
|
|
||||||
|
|
||||||
def contract_inner(x, y):
|
|
||||||
"""tensordot(x, y, 1)."""
|
|
||||||
x_chars = list(string.ascii_lowercase[: len(x.shape)])
|
|
||||||
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
|
|
||||||
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
|
|
||||||
out_chars = x_chars[:-1] + y_chars[1:]
|
|
||||||
return _einsum(x_chars, y_chars, out_chars, x, y)
|
|
||||||
|
|
||||||
|
|
||||||
def _einsum(a, b, c, x, y):
|
|
||||||
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
|
|
||||||
return torch.einsum(einsum_str, x, y)
|
|
||||||
|
|
Loading…
Reference in New Issue