diff --git a/min_dalle.ipynb b/min_dalle.ipynb index c029e63..8b8a41d 100644 --- a/min_dalle.ipynb +++ b/min_dalle.ipynb @@ -130,7 +130,6 @@ "source": [ "dtype = \"float16\" #@param [\"float32\", \"float16\", \"bfloat16\"]\n", "from IPython.display import display, update_display\n", - "from math import log2\n", "import torch\n", "from min_dalle import MinDalle\n", "\n", @@ -204,8 +203,8 @@ " seed=-1,\n", " grid_size=grid_size,\n", " log2_mid_count=log2_mid_count,\n", - " log2_k=int(log2(top_k)),\n", - " log2_supercondition_factor=log2(supercondition_factor)\n", + " top_k=int(top_k),\n", + " supercondition_factor=float(supercondition_factor)\n", ")\n", "\n", "is_first = True\n",