|
|
|
@ -130,7 +130,6 @@ |
|
|
|
|
"source": [ |
|
|
|
|
"dtype = \"float16\" #@param [\"float32\", \"float16\", \"bfloat16\"]\n", |
|
|
|
|
"from IPython.display import display, update_display\n", |
|
|
|
|
"from math import log2\n", |
|
|
|
|
"import torch\n", |
|
|
|
|
"from min_dalle import MinDalle\n", |
|
|
|
|
"\n", |
|
|
|
@ -204,8 +203,8 @@ |
|
|
|
|
" seed=-1,\n", |
|
|
|
|
" grid_size=grid_size,\n", |
|
|
|
|
" log2_mid_count=log2_mid_count,\n", |
|
|
|
|
" log2_k=int(log2(top_k)),\n", |
|
|
|
|
" log2_supercondition_factor=log2(supercondition_factor)\n", |
|
|
|
|
" top_k=int(top_k),\n", |
|
|
|
|
" supercondition_factor=float(supercondition_factor)\n", |
|
|
|
|
")\n", |
|
|
|
|
"\n", |
|
|
|
|
"is_first = True\n", |
|
|
|
|