From 67151c73f751739638f91472c18b2af5f3789894 Mon Sep 17 00:00:00 2001
From: bishe <123456789@163.com>
Date: Sun, 23 Feb 2025 23:15:25 +0800
Subject: [PATCH] without cnt

---
 checkpoints/ROMA_UNSB_001/loss_log.txt      | 22 ++++++
 checkpoints/ROMA_UNSB_001/train_opt.txt     |  8 +-
 models/__pycache__/networks.cpython-39.pyc  | Bin 51012 -> 51061 bytes
 .../roma_unsb_model.cpython-39.pyc          | Bin 18663 -> 19104 bytes
 models/networks.py                          |  2 +
 models/roma_unsb_model.py                   | 73 ++++++++++--------
 .../__pycache__/base_options.cpython-39.pyc | Bin 7598 -> 7610 bytes
 options/base_options.py                     |  2 +-
 scripts/train.sh                            |  2 +-
 9 files changed, 73 insertions(+), 36 deletions(-)

diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt
index d27cb3a..fd8dd2f 100644
--- a/checkpoints/ROMA_UNSB_001/loss_log.txt
+++ b/checkpoints/ROMA_UNSB_001/loss_log.txt
@@ -46,3 +46,25 @@
 ================ Training Loss (Sun Feb 23 22:33:48 2025) ================
 ================ Training Loss (Sun Feb 23 22:39:16 2025) ================
 ================ Training Loss (Sun Feb 23 22:39:48 2025) ================
+================ Training Loss (Sun Feb 23 22:41:34 2025) ================
+================ Training Loss (Sun Feb 23 22:42:01 2025) ================
+================ Training Loss (Sun Feb 23 22:44:17 2025) ================
+================ Training Loss (Sun Feb 23 22:45:53 2025) ================
+================ Training Loss (Sun Feb 23 22:46:48 2025) ================
+================ Training Loss (Sun Feb 23 22:47:42 2025) ================
+================ Training Loss (Sun Feb 23 22:49:44 2025) ================
+================ Training Loss (Sun Feb 23 22:50:29 2025) ================
+================ Training Loss (Sun Feb 23 22:51:47 2025) ================
+================ Training Loss (Sun Feb 23 22:55:56 2025) ================
+================ Training Loss (Sun Feb 23 22:56:19 2025) ================
+================ Training Loss (Sun Feb 23 22:57:58 2025) ================
+================ Training Loss (Sun Feb 23 22:59:09 2025) ================
+================ Training Loss (Sun Feb 23 23:02:36 2025) ================
+================ Training Loss (Sun Feb 23 23:03:56 2025) ================
+================ Training Loss (Sun Feb 23 23:09:21 2025) ================
+================ Training Loss (Sun Feb 23 23:10:05 2025) ================
+================ Training Loss (Sun Feb 23 23:11:43 2025) ================
+================ Training Loss (Sun Feb 23 23:12:41 2025) ================
+================ Training Loss (Sun Feb 23 23:13:05 2025) ================
+================ Training Loss (Sun Feb 23 23:13:59 2025) ================
+================ Training Loss (Sun Feb 23 23:14:59 2025) ================
diff --git a/checkpoints/ROMA_UNSB_001/train_opt.txt b/checkpoints/ROMA_UNSB_001/train_opt.txt
index d7766e4..4d2cd07 100644
--- a/checkpoints/ROMA_UNSB_001/train_opt.txt
+++ b/checkpoints/ROMA_UNSB_001/train_opt.txt
@@ -1,5 +1,5 @@
 ----------------- Options ---------------
-                 atten_layers: 1,3,5
+                 atten_layers: 5
                    batch_size: 1
                         beta1: 0.5
                         beta2: 0.999
@@ -28,10 +28,12 @@
                     init_type: xavier
                      input_nc: 3
                       isTrain: True                          	[default: None]
+                 lambda_D_ViT: 1.0
                    lambda_GAN: 8.0                           	[default: 1.0]
                    lambda_NCE: 8.0                           	[default: 1.0]
                     lambda_SB: 0.1
                    lambda_ctn: 1.0
+                lambda_global: 1.0
                    lambda_inc: 1.0
                        lmda_1: 0.1
                     load_size: 286
@@ -50,7 +52,7 @@
 nce_includes_all_negatives_from_minibatch: False
                    nce_layers: 0,4,8,12,16
                           ndf: 64
-                         netD: basic
+                         netD: basic_cond
                          netF: mlp_sample
                       netF_nc: 256
                         netG: resnet_9blocks_cond
@@ -78,7 +80,7 @@
               serial_batches: False
 stylegan2_G_num_downsampling: 1
                       suffix:
-                          tau: 0.1                           	[default: 0.01]
+                          tau: 0.01
             update_html_freq: 1000
                      use_idt: False
                      verbose: False
diff --git a/models/__pycache__/networks.cpython-39.pyc b/models/__pycache__/networks.cpython-39.pyc
index 40fbb5c3111eb70879d74788f829d845264fd328..91353a9c2e5846b90c8a57791cb979eb0cf5682c 100644
Binary files a/models/__pycache__/networks.cpython-39.pyc and b/models/__pycache__/networks.cpython-39.pyc differ
diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc
index bd981a861591db2901c5931be472be7dc126039d..ac7f924b59dc6de3192c8da5e0c2f482b13a1437 100644
Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ
diff --git a/models/networks.py b/models/networks.py
index 74343e6..7519e80 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -331,6 +331,8 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
         net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
     elif 'stylegan2' in netD:
         net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
+    elif netD == 'basic_cond':  # more options
+        net = NLayerDiscriminator_ncsn(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias)
     else:
         raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
     return init_net(net, init_type, init_gain, gpu_ids,
diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py
index 2c275a7..8dbc273 100644
--- a/models/roma_unsb_model.py
+++ b/models/roma_unsb_model.py
@@ -200,6 +200,8 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
         parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
         parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
+        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
+        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
         parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False,
                             help='use NCE loss for identity mapping: NCE(G(Y), Y))')
         parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
@@ -220,7 +222,7 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
         parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
 
-        parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
+        parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
         parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
         parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
 
@@ -258,7 +260,7 @@ class RomaUnsbModel(BaseModel):
             self.visual_names += ['idt_B']
 
         if self.isTrain:
-            self.model_names = ['G', 'D', 'E']
+            self.model_names = ['G', 'D_ViT', 'E']
 
         else:
 
@@ -269,23 +271,25 @@ class RomaUnsbModel(BaseModel):
         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
 
-        if self.isTrain:
-            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
+        if self.isTrain:
             self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
             self.resize = tfs.Resize(size=(384,384), antialias=True)
+            self.netD_ViT = networks.MLPDiscriminator().to(self.device)
+
             # load the pre-trained ViT
             self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
             # define the loss functions
+            self.criterionL1 = torch.nn.L1Loss().to(self.device)
             self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
             self.criterionNCE = []
             for nce_layer in self.nce_layers:
                 self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
             self.criterionIdt = torch.nn.L1Loss().to(self.device)
             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
-            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+            self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
             self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
             self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]
 
@@ -320,10 +324,10 @@
 
         self.netG.train()
         self.netE.train()
-        self.netD.train()
+        self.netD_ViT.train()
 
         # update D
-        self.set_requires_grad(self.netD, True)
+        self.set_requires_grad(self.netD_ViT, True)
         self.optimizer_D.zero_grad()
         self.loss_D = self.compute_D_loss()
         self.loss_D.backward()
@@ -337,7 +341,7 @@
         self.optimizer_E.step()
 
         # update G
-        self.set_requires_grad(self.netD, False)
+        self.set_requires_grad(self.netD_ViT, False)
         self.set_requires_grad(self.netE, False)
 
         self.optimizer_G.zero_grad()
@@ -443,7 +447,7 @@
 
         # ============ Step 3: assemble the inputs and run the generator =============
         bs = self.real_A0.size(0)
-        z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.real_A0.device)
+        z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
         z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
         # concatenate real_A and real_B (when nce_idt=True), and handle real_A_noisy and XtB the same way
         self.real = self.real_A0
@@ -455,9 +459,10 @@
             self.real = torch.flip(self.real, [3])
             self.realt = torch.flip(self.realt, [3])
 
-
+        print(f'real_A0: {self.real_A0.shape}, real_A1: {self.real_A1.shape}')
         self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
         self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
+        print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')
 
         if self.opt.phase == 'train':
             real_A0 = self.real_A0
@@ -507,23 +512,35 @@
 
         #self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
 
-    def compute_D_loss(self):
-        """Calculate the GAN loss for the discriminator"""
-
-        fake = self.cat_results(self.fake_B.detach())
-        pred_fake = self.netD(fake, self.time)
-        self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
-
-        self.pred_real = self.netD(self.real_B0, self.time)
-        loss_D_real = self.criterionGAN(self.pred_real, True)
-        self.loss_D_real = loss_D_real.mean()
-
-        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
-        return self.loss_D
+    def compute_D_loss(self):  # the discriminator itself has not been changed yet
+        """Calculate GAN loss for the discriminator"""
+
+        lambda_D_ViT = self.opt.lambda_D_ViT
+        fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
+        fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
+
+        real_B0_tokens = self.mutil_real_B0_tokens[0]
+        real_B1_tokens = self.mutil_real_B1_tokens[0]
+
+
+        pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
+        pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
+
+        self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
+
+        pred_real0_ViT = self.netD_ViT(real_B0_tokens)
+        pred_real1_ViT = self.netD_ViT(real_B1_tokens)
+        self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
+
+        self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
+
+
+        return self.loss_D_ViT
 
     def compute_E_loss(self):
         """Calculate the loss for discriminator E"""
+        print(f'real_A_noisy: {self.real_A_noisy.shape} \n fake_B0: {self.fake_B0.shape}')
         XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
         XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
         temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
@@ -534,14 +551,8 @@
 
     def compute_G_loss(self):
         """Calculate the GAN loss for the generator"""
-        bs = self.real_A0.size(0)
-        tau = self.opt.tau
-
-        fake = self.fake_B0
-        std = torch.rand(size=[1]).item() * self.opt.std
-
         if self.opt.lambda_GAN > 0.0:
-            pred_fake = self.netD(fake, self.time)
+            pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
             self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
         else:
             self.loss_G_GAN = 0.0
@@ -555,7 +566,7 @@
 
             # eq.9
             ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
             self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
-            self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)
+            self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
 
         if self.opt.lambda_global > 0.0:
             loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
diff --git a/options/__pycache__/base_options.cpython-39.pyc b/options/__pycache__/base_options.cpython-39.pyc
index 55ab7e1a65961dcda63124aa2f74d66d46317a27..7658be9386c7c6039866b5ccc650f594579220ce 100644
Binary files a/options/__pycache__/base_options.cpython-39.pyc and b/options/__pycache__/base_options.cpython-39.pyc differ
diff --git a/options/base_options.py b/options/base_options.py
index f9de39b..b20e1b4 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -35,7 +35,7 @@ class BaseOptions():
         parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
         parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
         parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
-        parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
+        parser.add_argument('--netD', type=str, default='basic_cond', choices=['basic_cond', 'basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
         parser.add_argument('--netG', type=str, default='resnet_9blocks_cond', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
         parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
         parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
diff --git a/scripts/train.sh b/scripts/train.sh
index 93a5f96..dea6a51 100755
--- a/scripts/train.sh
+++ b/scripts/train.sh
@@ -28,6 +28,6 @@ python train.py \
     --num_patches 256 \
     --flip_equivariance False \
     --eta_ratio 0.1 \
-    --tau 0.1 \
+    --tau 0.01 \
     --num_timesteps 10 \
     --input_nc 3